In [1]:
# Mount Google Drive at /content/gdrive so the project files stored on Drive
# are reachable from this runtime (a remount is forced on every run).
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
# Install the PyTorch packages, torchtext and pytorch-lightning-bolts into the kernel's environment.
# NOTE(review): no versions are pinned, so a fresh runtime may resolve different releases — consider pinning.
%pip install torch torchvision torchaudio
%pip install torchtext
%pip install pytorch-lightning-bolts

In [2]:
# Enable autoreload (mode 2) so edits to the local .py modules imported below
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# List the current working directory to confirm the Drive mount is visible.
!ls

gdrive	sample_data


In [4]:
# List the project folder on the mounted Drive (path is relative to the current directory).
!ls "gdrive/MyDrive/685-project/CoLLT/"

augmentors.py	    data       main.ipynb  __pycache__
contrast_models.py  losses.py  models.py   README.md


In [5]:
%cd "gdrive/MyDrive/685-project/CoLLT/"

/content/gdrive/MyDrive/685-project/CoLLT


In [6]:
import torch
import os.path as osp
import losses as L
import augmentors as A
import models as M
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim import Adam, AdamW
import datasets
from contrast_models import WithinEmbedContrast
from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
import numpy as np
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# runtime without CUDA instead of failing at the first `.to(device)`.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [7]:
# Load the IMDB train/test splits, caching the download under ./data/.
train_data, test_data = datasets.load_dataset('imdb', split=['train', 'test'],
                                              cache_dir='./data/')

# Draw a small development subset from each split.
# Sampling is seeded (reproducible across kernel restarts) and done WITHOUT
# replacement, so no review appears twice in a subset.
num_dev = 1000
dev_seed = 42
dev_rng = np.random.default_rng(dev_seed)
train_dev_indices = dev_rng.choice(len(train_data), size=min(num_dev, len(train_data)), replace=False)
test_dev_indices = dev_rng.choice(len(test_data), size=min(num_dev, len(test_data)), replace=False)
# .tolist() converts the numpy integers to plain Python ints before indexing.
train_data_dev = train_data.select(train_dev_indices.tolist())
test_data_dev = test_data.select(test_dev_indices.tolist())

Reusing dataset imdb (./data/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
from collections import Counter
# Count how many examples of each label ended up in the training dev subset.
Counter(train_data_dev['label'])

Counter({0: 482, 1: 518})

In [9]:
# Build the 2-class model and its tokenizer through the project helper in models.py.
# model_name is reused later as an attribute name (getattr(model, model_name)).
model_name='distilbert'
model, tokenizer = M.get_encoder(num_classes=2, model=model_name, device=device_name)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [10]:
print ('Start tokenization')


def tokenization(batched_text):
    """Tokenize the 'text' column of a batch with the module-level tokenizer.

    Every sequence is padded or truncated to exactly 512 tokens.
    """
    texts = batched_text['text']
    return tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=512,
    )


# Each split is tokenized as a single batch (batch_size equals the split size).
train_data_dev = train_data_dev.map(
    tokenization,
    batched=True,
    batch_size=len(train_data_dev),
)
test_data_dev = test_data_dev.map(
    tokenization,
    batched=True,
    batch_size=len(test_data_dev),
)


Start tokenization


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [13]:
class Encoder(torch.nn.Module):
    """Wrap a text encoder with a pair of augmentors to produce two views of a batch.

    Args:
        encoder: Module called as ``encoder(ids, mask)``.
        augmentor: Pair ``(aug1, aug2)``. Each element is called as
            ``aug(x, device_name)`` and must return ``(ids, mask)``.
        input_dim, hidden_dim, output_dim: Accepted for interface compatibility
            only; they are unused because no projection head is built.
    """

    def __init__(self, encoder, augmentor, input_dim=768, hidden_dim=768, output_dim=1536):
        super().__init__()
        self.encoder = encoder
        self.augmentor = augmentor

    def forward(self, x):
        """Encode two augmented views of ``x`` and return both encoder outputs."""
        aug1, aug2 = self.augmentor
        ids1, mask1 = aug1(x, device_name)
        ids2, mask2 = aug2(x, device_name)
        z1 = self.encoder(ids1, mask1)
        z2 = self.encoder(ids2, mask2)
        return z1, z2

    def predict(self, x):
        """Encode ``x`` through a single identity-augmented view.

        The augmentor must be *instantiated* and called with the same
        ``(x, device_name)`` arguments that ``forward`` uses; calling the
        ``A.Identity`` class directly would construct an object instead of
        returning ``(ids, mask)``.
        """
        aug = A.Identity()
        ids, mask = aug(x, device_name)
        return self.encoder(ids, mask)


In [14]:
# Two augmentors, one per view. Both are A.Identity here
# (presumably a no-op augmentation — confirm in augmentors.py).
aug1 = A.Identity()
aug2 = A.Identity()

# getattr(model, model_name) selects the `distilbert` sub-module of the classifier,
# so contrastive pre-training only sees the backbone, not the classification head.
encoder_model = Encoder(encoder=getattr(model, model_name), augmentor=(aug1, aug2)).to(device)
contrast_model = WithinEmbedContrast(loss=L.BarlowTwins()).to(device)

optimizer = Adam(encoder_model.parameters(), lr=5e-4)
# NOTE(review): the pre-training loop steps this scheduler once per processed batch and
# currently runs far fewer than 400 steps, so training stays inside the warm-up phase —
# confirm warmup_epochs/max_epochs are intended.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer=optimizer,
    warmup_epochs=400,
    max_epochs=4000)

In [15]:
def train(encoder_model, contrast_model, data, optimizer):
    """Run one contrastive optimisation step on a single batch.

    Args:
        encoder_model: Module that returns two outputs ``(z1, z2)`` for ``data``;
            each output must expose ``last_hidden_state``.
        contrast_model: Callable that maps the two view embeddings to a scalar loss.
        data: Batch passed unchanged to ``encoder_model``.
        optimizer: Optimizer holding the parameters to update.

    Returns:
        The loss as a Python float, or ``None`` when the loss is NaN (in that
        case 'ERROR' is printed and no backward pass / optimizer step is done).
    """
    encoder_model.train()
    optimizer.zero_grad()
    # Call the module itself rather than .forward() so registered hooks run.
    z1, z2 = encoder_model(data)
    # The hidden state at sequence position 0 of each view is used as its embedding.
    loss = contrast_model(z1.last_hidden_state[:, 0, :], z2.last_hidden_state[:, 0, :])
    if torch.isnan(loss):
        print ('ERROR')
        return None
    loss.backward()
    optimizer.step()
    return loss.item()

In [16]:
epoch = 10
batch_size = 5
# Ceil division: no trailing empty batch when the dataset size is a multiple of batch_size.
num_batches = (len(train_data_dev) + batch_size - 1) // batch_size
loss = None
with tqdm(total=epoch, desc='(T)') as pbar:
    # A separate loop variable keeps `epoch` (the total count) from being overwritten.
    for epoch_i in range(1, epoch + 1):
        # For each batch of training data...
        for i in range(num_batches):
            start_index = i * batch_size
            end_index = min(start_index + batch_size, len(train_data_dev))

            batch = train_data_dev[start_index:end_index]

            # The slice is indexed by column name, so len(batch) would count
            # columns; check the number of rows instead.
            if len(batch['text']) == 0: continue

            loss = train(encoder_model, contrast_model, batch, optimizer)
            # train() returns None when it skipped the optimizer step (NaN loss);
            # only advance the LR schedule after a real step.
            if loss is not None:
                scheduler.step()
            # NOTE(review): this break stops after the first batch, so every epoch
            # trains on the same first batch only. It looks like a debugging
            # shortcut — remove it to pre-train on the whole subset.
            break
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T): 100%|██████████| 10/10 [00:10<00:00,  1.03s/it, loss=264]


In [17]:
## TODO: get cross correlation matrix
## TODO: compute embeddings and then perform fine-tune

In [18]:
# Freeze the backbone: switch off gradient tracking for every parameter of the
# sub-module selected by model_name.
for frozen_param in getattr(model, model_name).parameters():
    frozen_param.requires_grad_(False)

In [46]:
# Print every parameter name with its requires_grad flag to check what is frozen.
for name, param in model.named_parameters():
    print (name, param.requires_grad)

distilbert.embeddings.word_embeddings.weight False
distilbert.embeddings.position_embeddings.weight False
distilbert.embeddings.LayerNorm.weight False
distilbert.embeddings.LayerNorm.bias False
distilbert.transformer.layer.0.attention.q_lin.weight False
distilbert.transformer.layer.0.attention.q_lin.bias False
distilbert.transformer.layer.0.attention.k_lin.weight False
distilbert.transformer.layer.0.attention.k_lin.bias False
distilbert.transformer.layer.0.attention.v_lin.weight False
distilbert.transformer.layer.0.attention.v_lin.bias False
distilbert.transformer.layer.0.attention.out_lin.weight False
distilbert.transformer.layer.0.attention.out_lin.bias False
distilbert.transformer.layer.0.sa_layer_norm.weight False
distilbert.transformer.layer.0.sa_layer_norm.bias False
distilbert.transformer.layer.0.ffn.lin1.weight False
distilbert.transformer.layer.0.ffn.lin1.bias False
distilbert.transformer.layer.0.ffn.lin2.weight False
distilbert.transformer.layer.0.ffn.lin2.bias False
distilbe

In [20]:
# Fine-tuning hyper-parameters.
# Note: batch_size and optimizer rebind the names used by the pre-training cells,
# so re-running those cells after this one picks up these values.
batch_size = 50
# The optimizer is handed every parameter of `model`, including the ones whose
# requires_grad flag was switched off above.
optimizer = AdamW(model.parameters(),
                lr = 5e-5, # args.learning_rate - default is 5e-5
                eps = 1e-8 # args.adam_epsilon  - default is 1e-8
                )
epochs = 20

In [44]:
def get_validation_performance(model, val_set, batch_size):
    """Compute the classification accuracy of ``model`` on ``val_set``.

    Args:
        model: Called as ``model(input_ids, attention_mask=..., labels=...)``;
            its output must expose ``.logits``. It is switched to eval mode
            and left in that mode.
        val_set: Sliceable dataset whose slices can be indexed by the column
            names 'input_ids', 'attention_mask' and 'label'.
        batch_size: Number of rows evaluated per forward pass.

    Returns:
        The fraction of correctly classified rows, or 0.0 for an empty
        ``val_set`` (instead of dividing by zero).
    """
    # Put the model in evaluation mode
    model.eval()

    num_examples = len(val_set)
    if num_examples == 0:
        return 0.0

    # Ceil division so no trailing empty batch is generated.
    num_batches = (num_examples + batch_size - 1) // batch_size

    total_correct = 0
    total = 0
    # The bar counts validation batches (not the unrelated global `epoch`).
    with tqdm(total=num_batches, desc='(V)') as pbar:
        for i in range(num_batches):
            start_index = i * batch_size
            end_index = min(start_index + batch_size, num_examples)
            batch = val_set[start_index:end_index]

            # Build the tensors and move them to the target device
            b_input_ids = torch.tensor(batch['input_ids']).to(device)
            b_input_mask = torch.tensor(batch['attention_mask']).to(device)
            b_labels = torch.tensor(batch['label']).to(device)

            # No graph is needed for evaluation, only the forward pass.
            with torch.no_grad():
                outputs = model(b_input_ids,
                                attention_mask=b_input_mask,
                                labels=b_labels)

            # Predicted class = index of the largest logit in each row.
            predictions = outputs.logits.argmax(dim=1)
            total_correct += (predictions == b_labels).sum().item()
            total += b_labels.numel()

            pbar.set_postfix({'val_accuracy': total_correct / total})
            pbar.update()

    # Report the final accuracy for this validation run.
    return total_correct / num_examples



In [45]:
# Ceil division: with the previous int(n / batch_size) + 1 an exact multiple
# produced one extra empty batch, which was skipped without updating the bar
# (so it stopped at 20/21). Computed once because it does not change per epoch.
num_batches = (len(train_data_dev) + batch_size - 1) // batch_size

for epoch_i in range(0, epochs):
    # Perform one full pass over the training set.

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    # Reset the total loss for this epoch.
    total_train_loss = 0

    # Put the model into training mode (validation below switches it to eval).
    model.train()

    with tqdm(total=num_batches, desc='(T)') as pbar:
        # For each batch of training data...
        for i in range(num_batches):
            start_index = i * batch_size
            end_index = min(start_index + batch_size, len(train_data_dev))

            batch = train_data_dev[start_index:end_index]

            # Build the tensors and move them to the target device
            b_input_ids = torch.tensor(batch['input_ids']).to(device)
            b_input_mask = torch.tensor(batch['attention_mask']).to(device)
            b_labels = torch.tensor(batch['label']).to(device)

            # Clear the previously calculated gradient
            model.zero_grad()

            # Perform a forward pass (evaluate the model on this training batch).
            # Passing labels makes the model return the loss alongside the logits.
            outputs = model(b_input_ids,
                            attention_mask=b_input_mask,
                            labels=b_labels)
            loss = outputs.loss

            total_train_loss += loss.item()

            # Perform a backward pass to calculate the gradients.
            loss.backward()

            # Update parameters and take a step using the computed gradient.
            optimizer.step()

            pbar.set_postfix({'loss': loss.item()})
            pbar.update()
    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set. Implement this function in the cell above.
    print(f"Total loss: {total_train_loss}")
    val_acc = get_validation_performance(model, val_set=test_data_dev, batch_size=batch_size*2)
    print(f"Validation accuracy: {val_acc}")

print("")
print("Training complete!")



Training...


(T):  95%|█████████▌| 20/21 [00:34<00:01,  1.75s/it, loss=0.358]


Total loss: 8.081258922815323


(V): 100%|██████████| 10/10 [00:32<00:00,  3.29s/it, val_accuracy=0.834]


Validation accuracy: 0.834

Training...


(T):  95%|█████████▌| 20/21 [00:34<00:01,  1.74s/it, loss=0.352]


Total loss: 7.967085689306259


(V): 100%|██████████| 10/10 [00:32<00:00,  3.30s/it, val_accuracy=0.831]


Validation accuracy: 0.831

Training...


(T):  38%|███▊      | 8/21 [00:15<00:25,  1.96s/it, loss=0.351]


KeyboardInterrupt: ignored

In [21]:

## TODO: evaluate model
# Evaluate the model in its current state on the test dev subset.
# NOTE(review): the label says "Best", but no best checkpoint is tracked in this
# notebook; the number is the accuracy of whatever weights `model` holds when
# this cell runs — confirm the wording.
val_acc = get_validation_performance(model, val_set=test_data_dev, batch_size=batch_size)
print ('Best Validation accuracy: ', val_acc)


Best Validation accuracy:  1.0
