## Libraries

In [1]:
! pip -q install sentence_transformers
! pip -q install torchsummaryX
! pip -q install wandb

In [2]:
import os
import glob
import math
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import pandas as pd
import csv
import gzip
import random

from sentence_transformers import models, losses
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from transformers import BertModel, BertTokenizer, BertConfig
from torchsummaryX import summary
import random
from tqdm import tqdm
from transformers.models.bert.modeling_bert import BertEmbeddings
import wandb
import json
from typing import Optional

from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
from scipy.stats import pearsonr, spearmanr

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

Device:  cuda


In [3]:
config = {
    'model_name': 'bert-base-uncased',
    'augmentation_type_1': 'shuffle',
    'augmentation_type_2': 'cutoff',
    'BATCH_SIZE': 96,
    'models': os.getcwd() + '/model',
    'data_path': './downstream',
    'cutoff_direction': 'column',
    'cutoff_rate': 0.2,
    'lr': 2e-5,
    'weight_decay': 1e-4,
    'temperature': 0.05,
    'LARGE_NUM': 1e9,
    'hidden_norm': True,
    'total_epochs': 20,
    'max_length': 64,
    'concatenated_out_len': 2304,
    'num_labels': 3,
    'mixup_rate': 0.4
}

## Data

In [4]:
! git clone https://github.com/yym6472/ConSERT.git

Cloning into 'ConSERT'...
remote: Enumerating objects: 475, done.[K
remote: Counting objects: 100% (475/475), done.[K
remote: Compressing objects: 100% (333/333), done.[K
remote: Total 475 (delta 166), reused 444 (delta 139), pack-reused 0[K
Receiving objects: 100% (475/475), 1.13 MiB | 5.65 MiB/s, done.
Resolving deltas: 100% (166/166), done.


In [None]:
! bash /content/ConSERT/data/get_transfer_data.bash

In [6]:
# !zip -q -r downstream.zip /content/downstream/

In [7]:
nli_dataset_path = 'datasets/AllNLI.tsv.gz'
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
if not os.path.exists(nli_dataset_path):
    util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)
if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

  0%|          | 0.00/40.8M [00:00<?, ?B/s]

  0%|          | 0.00/392k [00:00<?, ?B/s]

In [8]:
try:
    os.mkdir(os.getcwd() + '/model')
except:
    pass

## Read STS

In [4]:
def load_sts12(need_label = False, use_all_unsupervised_texts=True, no_pair=True):
    dataset_names = ["MSRpar", "MSRvid", "SMTeuroparl", "surprise.OnWN", "surprise.SMTnews"]
    return load_sts(need_label, "12", dataset_names, no_pair=no_pair)
    
def load_sts13(need_label = False, use_all_unsupervised_texts=True, no_pair=True):
    dataset_names = ["headlines", "OnWN", "FNWN"]
    return load_sts(need_label, "13", dataset_names, no_pair=no_pair)

def load_sts14(need_label = False, use_all_unsupervised_texts=True, no_pair=True):
    dataset_names = ["images", "OnWN", "tweet-news", "deft-news", "deft-forum", "headlines"]
    return load_sts(need_label, "14", dataset_names, no_pair=no_pair)

def load_sts15(need_label = False, use_all_unsupervised_texts=True, no_pair=True):
    dataset_names = ["answers-forums", "answers-students", "belief", "headlines", "images"]
    return load_sts(need_label, "15", dataset_names, no_pair=no_pair)

def load_sts16(need_label = False, use_all_unsupervised_texts=True, no_pair=True):
    dataset_names = ["answer-answer", "headlines", "plagiarism", "postediting", "question-question"]
    return load_sts(need_label, "16", dataset_names, no_pair=no_pair)

def load_sts(need_label, year, dataset_names, no_pair=False):
    
    all_samples = []
    sts_data_path = f"{config['data_path']}/STS/STS{year}-en-test"
    
    for dataset_name in dataset_names:
        input_file = os.path.join(sts_data_path, f"STS.input.{dataset_name}.txt")
        label_file = os.path.join(sts_data_path, f"STS.gs.{dataset_name}.txt")
        sub_samples = load_paired_samples(need_label, input_file, label_file, no_pair=no_pair)
        all_samples.extend(sub_samples)
    
    return all_samples

def load_paired_samples(need_label, input_file, label_file, scale=5.0, no_pair=False):

    samples = []

    with open(input_file, "r") as f:
        input_lines = f.readlines()

    label_lines = [None]*len(input_lines)
    if label_file:
        with open(label_file, 'r') as labels:
            label_lines = labels.readlines()

    # Filtering out lines without labels
    if need_label:
        new_input_lines, new_label_lines = [], []
        for idx in range(len(label_lines)):
            label_lines[idx] = label_lines[idx].strip()
            if label_lines[idx]:
                new_input_lines.append(input_lines[idx])
                new_label_lines.append(label_lines[idx].strip())
        input_lines = new_input_lines
        label_lines = new_label_lines

    # Parsing text file for sentence and label
    for input_line, label_line in zip(input_lines, label_lines):
        sentences = input_line.split("\t")
            
        if len(sentences)==2:
            sent1, sent2 = sentences
        else:
            sent1, sent2 = sentences[0], None

        if need_label:
            samples.append(InputExample(texts=[sent1, sent2], label=float(label_line)/scale))

        else:
            if no_pair:
                samples.append(InputExample(texts=[sent1]))
                if sent2:
                    samples.append(InputExample(texts=[sent2]))
            else:
                samples.append(InputExample(texts=[sent1, sent2]))
    return samples


def load_stsbenchmark(need_label=False, use_all_unsupervised_texts=True, no_pair=True):

    all_samples = []
    if use_all_unsupervised_texts:
        splits = ["train", "dev", "test"]
    else:
        splits = ["test"]
    
    for split in splits:
        sts_benchmark_data_path = f"{config['data_path']}/STS/STSBenchmark/sts-{split}.csv"
        
        samples = []
        with open(sts_benchmark_data_path, "r") as f:
            lines = f.readlines()
        
            for line in lines:
                line = line.strip()
                _, _, _, _, label, sent1, sent2 = line.split("\t")
                if need_label:
                    samples.append(InputExample(texts=[sent1, sent2], label=float(label) / 5.0))
                else:
                    if no_pair:
                        samples.append(InputExample(texts=[sent1]))
                        samples.append(InputExample(texts=[sent2]))
                    else:
                        samples.append(InputExample(texts=[sent1, sent2]))
        all_samples.extend(samples)
    
    return all_samples

def load_sickr(need_label=False, use_all_unsupervised_texts=True, no_pair=True):
    
    all_samples = []
    if use_all_unsupervised_texts:
        splits = ["train", "trial", "test_annotated"]
    else:
        splits = ["test_annotated"]

    for split in splits:
        samples = []
        sick_data_path = f"{config['data_path']}/SICK/SICK_{split}.txt"
        
        with open(sick_data_path, "r") as f:
            lines = f.readlines()
        
        for line in lines[1:]:
            line = line.strip()
            _, sent1, sent2, label, _ = line.split("\t")
            
            if need_label:
                samples.append(InputExample(texts=[sent1, sent2], label=float(label) / 5.0))
            else:
                if no_pair:
                    samples.append(InputExample(texts=[sent1]))
                    samples.append(InputExample(texts=[sent2]))
                else:
                    samples.append(InputExample(texts=[sent1, sent2]))
        all_samples.extend(samples)
    
    return all_samples

In [5]:
def eval_sts(year, dataset_names):
    print(f"Evaluation on STS{year} dataset")
    sts_data_path = f"./downstream/STS/STS{year}-en-test"
    
    all_samples = []
    for dataset_name in dataset_names:
        input_file = os.path.join(sts_data_path, f"STS.input.{dataset_name}.txt")
        label_file = os.path.join(sts_data_path, f"STS.gs.{dataset_name}.txt")
        sub_samples = load_paired_samples(True, input_file, label_file)
        all_samples.extend(sub_samples)
    print(f"Loaded examples from STS{year} dataset, total {len(all_samples)} examples")
    return all_samples

def eval_sts12():
    dataset_names = ["MSRpar", "MSRvid", "SMTeuroparl", "surprise.OnWN", "surprise.SMTnews"]
    return eval_sts("12", dataset_names)
    
def eval_sts13():
    dataset_names = ["headlines", "OnWN", "FNWN"]
    return eval_sts("13", dataset_names)

def eval_sts14():
    dataset_names = ["images", "OnWN", "tweet-news", "deft-news", "deft-forum", "headlines"]
    return eval_sts("14", dataset_names)

def eval_sts15():
    dataset_names = ["answers-forums", "answers-students", "belief", "headlines", "images"]
    return eval_sts("15", dataset_names)

def eval_sts16():
    dataset_names = ["answer-answer", "headlines", "plagiarism", "postediting", "question-question"]
    return eval_sts("16", dataset_names)

def eval_stsbenchmark():
    print("Evaluation on STSBenchmark dataset")
    sts_benchmark_data_path = "./downstream/STS/STSBenchmark/sts-test.csv"
    with open(sts_benchmark_data_path, "r") as f:
        lines = [line.strip() for line in f if line.strip()]
    samples = []
    for line in lines:
        _, _, _, _, label, sent1, sent2 = line.split("\t")
        samples.append(InputExample(texts=[sent1, sent2], label=float(label) / 5.0))
    print(f"Loaded examples from STSBenchmark dataset, total {len(samples)} examples")
    return samples

def eval_sickr():
    print("Evaluation on SICK (relatedness) dataset")
    sick_data_path = "./downstream/SICK/SICK_test_annotated.txt"
    with open(sick_data_path, "r") as f:
        lines = [line.strip() for line in f if line.strip()]
    samples = []
    for line in lines[1:]:
        _, sent1, sent2, label, _ = line.split("\t")
        samples.append(InputExample(texts=[sent1, sent2], label=float(label) / 5.0))
    print(f"Loaded examples from SICK dataset, total {len(samples)} examples")
    
    return samples

In [6]:
encoder = BertModel.from_pretrained('bert-base-uncased', hidden_dropout_prob=0, attention_probs_dropout_prob=0)
encoder.model_max_len=config['max_length']
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', model_max_length=config['max_length'])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Dataset

In [7]:
class STSDatasetUnsupervised(Dataset):
    def __init__(self):
        super(STSDatasetUnsupervised, self).__init__()

        sts_data = []
        sts_data_12 = load_sts12()
        sts_data_13 = load_sts13()
        sts_data_14 = load_sts14()
        sts_data_15 = load_sts15()
        sts_data_16 = load_sts16()
        stsb = load_stsbenchmark()
        sickr = load_sickr()
        
        sts_data.extend(sts_data_12)
        sts_data.extend(sts_data_13)
        sts_data.extend(sts_data_14)
        sts_data.extend(sts_data_15)
        sts_data.extend(sts_data_16)
        sts_data.extend(stsb)
        sts_data.extend(sickr)

        self.dataset = sts_data
        self.length = len(self.dataset)

    def __len__(self):
        return self.length
        
    def __getitem__(self, index):
        return self.dataset[index]

    def collate(batch):
        num_texts = len(batch[0].texts)
        texts = []

        for example in batch:
            texts.append(example.texts[0])
        return tokenizer.batch_encode_plus(texts, padding='max_length', return_tensors='pt', truncation=True)

In [8]:
class STSDatasetUnsupervisedVal(Dataset):
    def __init__(self):
        super(STSDatasetUnsupervisedVal, self).__init__()
        self.dataset = []
        sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
        with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if row['split'] == 'dev':
                    score = float(row['score']) / 5.0 #Normalize score to range 0 ... 1
                    self.dataset.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))
        
        self.length = len(self.dataset)

    def __len__(self):
        return self.length
        
    def __getitem__(self, index):
        return self.dataset[index]

    def collate(batch):
        num_texts = len(batch[0].texts)
        texts1, texts2 = [], []
        labels = []

        for example in batch:
            texts1.append(example.texts[0])
            texts2.append(example.texts[1])
            labels.append(example.label)

        return (tokenizer.batch_encode_plus(texts1, padding='max_length', return_tensors='pt', truncation=True),
        tokenizer.batch_encode_plus(texts2, padding='max_length', return_tensors='pt', truncation=True),
        labels)

In [9]:
class STSDatasetUnsupervisedTest(Dataset):
    def __init__(self):
        super(STSDatasetUnsupervisedTest, self).__init__()

        sts_data = []
        sts_data_12 = eval_sts12()
        sts_data_13 = eval_sts13()
        sts_data_14 = eval_sts14()
        sts_data_15 = eval_sts15()
        sts_data_16 = eval_sts16()
        stsb = eval_stsbenchmark()
        sickr = eval_sickr()
        
        sts_data.extend(sts_data_12)
        sts_data.extend(sts_data_13)
        sts_data.extend(sts_data_14)
        sts_data.extend(sts_data_15)
        sts_data.extend(sts_data_16)
        sts_data.extend(stsb)
        sts_data.extend(sickr)

        self.dataset = sts_data
        self.length = len(self.dataset)

    def __len__(self):
        return self.length
        
    def __getitem__(self, index):
        return self.dataset[index]

    def collate(batch):
        num_texts = len(batch[0].texts)
        texts1, texts2 = [], []
        labels = []

        for example in batch:
            texts1.append(example.texts[0])
            texts2.append(example.texts[1])
            labels.append(example.label)

        return (tokenizer.batch_encode_plus(texts1, padding='max_length', return_tensors='pt', truncation=True),
        tokenizer.batch_encode_plus(texts2, padding='max_length', return_tensors='pt', truncation=True),
        labels)

In [10]:
unsupervised_dataset = STSDatasetUnsupervised()
unsupervised_dataset_val = STSDatasetUnsupervisedVal()
unsupervised_dataset_test = STSDatasetUnsupervisedTest()
train_dataloader = DataLoader(unsupervised_dataset, shuffle = True, batch_size=config['BATCH_SIZE'], collate_fn=STSDatasetUnsupervised.collate)
val_dataloader = DataLoader(unsupervised_dataset_val, shuffle = False, batch_size=config['BATCH_SIZE'], collate_fn=STSDatasetUnsupervisedVal.collate)
test_dataloader = DataLoader(unsupervised_dataset_test, shuffle = False, batch_size=config['BATCH_SIZE'], collate_fn=STSDatasetUnsupervisedTest.collate)

Evaluation on STS12 dataset
Loaded examples from STS12 dataset, total 3108 examples
Evaluation on STS13 dataset
Loaded examples from STS13 dataset, total 1500 examples
Evaluation on STS14 dataset
Loaded examples from STS14 dataset, total 3750 examples
Evaluation on STS15 dataset
Loaded examples from STS15 dataset, total 3000 examples
Evaluation on STS16 dataset
Loaded examples from STS16 dataset, total 1186 examples
Evaluation on STSBenchmark dataset
Loaded examples from STSBenchmark dataset, total 1379 examples
Evaluation on SICK (relatedness) dataset
Loaded examples from SICK dataset, total 4927 examples


In [11]:
print(unsupervised_dataset.length)
print(unsupervised_dataset_val.length)
print(unsupervised_dataset_test.length)

89192
1500
18850


In [12]:
for i, (x1, x2, y) in enumerate(val_dataloader):
    print(x1['input_ids'].shape, x2['input_ids'].shape, len(y))
    break

torch.Size([96, 64]) torch.Size([96, 64]) 96


In [13]:
for i, x in enumerate(train_dataloader):
    break

In [14]:
# import gc

# global feature_queue
# gc.collect()
# feature_queue = torch.nn.functional.normalize(torch.randn(size=(9600, 768), requires_grad=False), p=2.0, dim = 1).to(device)
# f_q_temp = torch.clone(feature_queue)
# def mean_pooling(model_output, attention_mask):
#     token_embeddings = model_output[0] #First element of model_output contains all token embeddings
#     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# def append_to_queue(x, i):
#     global feature_queue
#     temp_embeddings = encoder(input_ids=x['input_ids'].to(device), attention_mask=x['attention_mask'].to(device))
#     temp_embeddings = mean_pooling(temp_embeddings, x['attention_mask'].to(device))
#     temp_embeddings = torch.nn.functional.normalize(temp_embeddings, p=2, dim=-1)
#     feature_queue = torch.concat([feature_queue[config['BATCH_SIZE'] - 1:], temp_embeddings], axis=0)

# for i, x in enumerate(train_dataloader):
#     if i == 0:
#         to_find_embeddings = encoder(input_ids=x['input_ids'].to(device), attention_mask=x['attention_mask'].to(device))
#         to_find_embeddings = mean_pooling(to_find_embeddings, x['attention_mask'].to(device))
#         to_find_embeddings = torch.nn.functional.normalize(to_find_embeddings, p=2, dim=-1)
#         continue
#     if i == 100:
#         break
#     print(i)

# support_similarities = torch.matmul(to_find_embeddings, feature_queue.T)
# indices = torch.argmax(support_similarities, dim=-1)
# similarities = torch.max(support_similarities, dim=1)

In [15]:
# print(indices)

In [16]:
# print(similarities.values.sum())

In [17]:
# support_similarities = torch.matmul(to_find_embeddings, f_q_temp.T)
# indices = torch.argmax(support_similarities, dim=-1)
# similarities = torch.max(support_similarities, dim=1)

In [18]:
# print(similarities.values.sum())

In [19]:
# print(unsupervised_dataset[20000], unsupervised_dataset[20001], unsupervised_dataset[20002], unsupervised_dataset[20003])

In [20]:
# print(unsupervised_dataset[3116])

## Loss Function

In [21]:
class NTXENT(nn.Module):
    def __init__(self, temperature, LARGE_NUM, hidden_norm):
        self.temperature = temperature
        self.LARGE_NUM = LARGE_NUM
        self.hidden_norm = hidden_norm
        self.batch_size = config['BATCH_SIZE']

    def __call__(self, out1, out2):
        if self.hidden_norm:
            out1 = torch.nn.functional.normalize(out1, p=2, dim=-1)
            out2 = torch.nn.functional.normalize(out2, p=2, dim=-1)

        out = torch.cat([out1, out2], dim=0)
        n_samples = len(out)

        # Full similarity matrix
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)

        # Negative similarity
        mask = ~torch.eye(n_samples, device=device).bool()
        neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)

        # Positive similarity :
        pos = torch.exp(torch.sum(out1 * out2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)

        loss = -torch.log(pos / neg).mean()

        return loss

## Data Augmentation

In [22]:
class Augmentation(nn.Module):
    def __init__(self, direction, rate):
        super(Augmentation, self).__init__()
        # self.bert_embeddings = BertEmbeddings(bert_config)
        self.direction = direction
        self.rate = rate

    def forward(self, sentence_feature, augmentation_type, embeddings):
        input_ids, token_type_ids, attention_mask = sentence_feature['input_ids'], sentence_feature['token_type_ids'], sentence_feature['attention_mask']
        bs, seq_len = input_ids.shape
        input_ids, token_type_ids, attention_mask = input_ids.to(device), token_type_ids.to(device), attention_mask.to(device)
        position_ids = torch.arange(seq_len).expand((bs, -1))[:, :seq_len].to(device=input_ids.device)

        if not self.training or augmentation_type is None:
            inputs_embeds = embeddings.word_embeddings(input_ids)
            token_type_embeddings = embeddings.token_type_embeddings(token_type_ids)
            position_embeddings = embeddings.position_embeddings(position_ids)

            embedding_output = inputs_embeds + position_embeddings + token_type_embeddings
            del inputs_embeds, token_type_embeddings, position_embeddings

            return embedding_output, attention_mask

        if augmentation_type == 'shuffle':
            # Shuffle
            position_ids = self._replace_position_ids(input_ids, position_ids, attention_mask)

            inputs_embeds = embeddings.word_embeddings(input_ids)
            token_type_embeddings = embeddings.token_type_embeddings(token_type_ids)
            position_embeddings = embeddings.position_embeddings(position_ids)

            embedding_output = inputs_embeds + position_embeddings + token_type_embeddings
            del inputs_embeds, token_type_embeddings, position_embeddings

            return embedding_output, attention_mask
            
        elif augmentation_type == 'cutoff':

            inputs_embeds = embeddings.word_embeddings(input_ids)
            token_type_embeddings = embeddings.token_type_embeddings(token_type_ids)
            position_embeddings = embeddings.position_embeddings(position_ids)

            embedding_output = inputs_embeds + position_embeddings + token_type_embeddings

            del inputs_embeds, token_type_embeddings, position_embeddings

            return self.apply_cutoff(embedding_output, attention_mask, self.direction, self.rate)

    def apply_cutoff(self, embedding_output, attention_mask, direction, rate):
        bs, seq_len, emb_size = embedding_output.shape
        cutoff_embeddings = []
        for batch_id in range(bs):
            sample_embedding = embedding_output[batch_id]
            sample_mask = attention_mask[batch_id]
            if direction == "row":
                num_dimensions = sample_mask.sum().int().item()  # number of tokens
                dim_index = 0
            elif direction == "column":
                num_dimensions = emb_size  # number of features
                dim_index = 1
            elif direction == "random":
                num_dimensions = sample_mask.sum().int().item() * emb_size
                dim_index = 0
            else:
                raise ValueError(f"direction should be either row or column, but got {direction}")

            num_cutoff_indexes = int(num_dimensions * rate)

            if num_cutoff_indexes < 0 or num_cutoff_indexes > num_dimensions:
                raise ValueError(f"number of cutoff dimensions should be in (0, {num_dimensions}), but got {num_cutoff_indexes}")
            
            indexes = list(range(num_dimensions))
            random.shuffle(indexes)
            cutoff_indexes = indexes[:num_cutoff_indexes]

            if direction == "random":
                sample_embedding = sample_embedding.reshape(-1)
            
            cutoff_embedding = torch.index_fill(sample_embedding, dim_index, torch.tensor(cutoff_indexes, dtype=torch.long).to(device), 0.0)
            
            if direction == "random":
                cutoff_embedding = cutoff_embedding.reshape(seq_len, emb_size)
                
            cutoff_embeddings.append(cutoff_embedding.unsqueeze(0))
        cutoff_embeddings = torch.cat(cutoff_embeddings, 0)

        assert cutoff_embeddings.shape == embedding_output.shape, (cutoff_embeddings.shape, embedding_output.shape)
        return cutoff_embeddings, attention_mask
    
    def _replace_position_ids(self, input_ids, position_ids, attention_mask):
        bs, seq_len = input_ids.shape
        
        shuffled_pid = []
        for batch_id in range(bs):
            sample_pid = position_ids[batch_id]
            sample_mask = attention_mask[batch_id]
            num_tokens = sample_mask.sum().int().item()
            indexes = list(range(num_tokens))
            random.shuffle(indexes)
            rest_indexes = list(range(num_tokens, seq_len))
            total_indexes = indexes + rest_indexes
            shuffled_pid.append(torch.index_select(sample_pid, 0, torch.tensor(total_indexes).to(device=input_ids.device)).unsqueeze(0))
        return torch.cat(shuffled_pid, 0)

In [23]:
encoder = BertModel.from_pretrained('bert-base-uncased')
encoder.model_max_len=config['max_length']

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Model

In [24]:
class Network(nn.Module):
    def __init__(self, input_dim_projection, hidden_dim_projection, output_dim_projection, hidden_dim_prediction, lamb, mixup):
        super(Network, self).__init__()

        ##########################CHANGE THIS WHEN AUGMENTATION IS APPLIED
        self.embedding = encoder.embeddings
        self.encoder_layers = encoder.encoder
        self.augmentation_module = Augmentation(config['cutoff_direction'], config['cutoff_rate'])
        self.feature_queue = torch.nn.functional.normalize(torch.randn(size=(9600, 768), requires_grad=False), p=2.0, dim = 1).to(device)
        self.lamb = lamb
        self.mixup = mixup
        

    def forward(self, x1, x2=None):
        if x2 is None:
            x2 = x1
        
        embedding_output1, attention_mask1 = self.augmentation_module(x1, config['augmentation_type_1'], self.embedding.to(device))
        embedding_output2, attention_mask2 = self.augmentation_module(x2, config['augmentation_type_2'], self.embedding.to(device))

        extended_attention_mask1 = attention_mask1[:, None, None, :]
        extended_attention_mask2 = attention_mask2[:, None, None, :]

        model_output_1 = self.encoder_layers(embedding_output1, extended_attention_mask1)
        model_output_2 = self.encoder_layers(embedding_output2, extended_attention_mask2)

        if self.mixup:
            model_output_1 = self.lamb * model_output_1  + (1-self.lamb) * model_output_2  
            attention_mask1 = torch.max(extended_attention_mask1, extended_attention_mask2)


        sentence_embedding_1 = self.mean_pooling(model_output_1, attention_mask1.to(device))
        sentence_embedding_2 = self.mean_pooling(model_output_2, attention_mask2.to(device))

        if not self.training:
            return sentence_embedding_1, sentence_embedding_2

        # sentence_projection_1 = self.projection_MLP(sentence_embedding_1)
        # sentence_projection_2 = self.projection_MLP(sentence_embedding_2)

        return sentence_embedding_1, sentence_embedding_2

    def update_queue(self, projections):
        projections = torch.nn.functional.normalize(projections, p=2, dim=-1)
        self.feature_queue = torch.concat([projections, self.feature_queue[:-config['BATCH_SIZE']]], axis=0)

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def find_nearest_neighbors(self, projections):
        projections = torch.nn.functional.normalize(projections, p=2, dim=-1)
        support_similarities = torch.matmul(projections, self.feature_queue.T)
        # nn_projections = torch.gather(self.feature_queue, dim=0, index=torch.argmax(support_similarities, dim=1).unsqueeze(-1))
        # return projections + (nn_projections - projections).detach()
        indices = torch.argmax(support_similarities, dim=-1)
        return projections + (self.feature_queue[indices] - projections).detach()

In [25]:
model = Network(768, 2048, 2048, 2048, config['mixup_rate'], False).to(device)

In [26]:
# summary(model, x1, None)
# del model

In [27]:
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
criterion = NTXENT(config['temperature'], config['LARGE_NUM'], config['hidden_norm'])

In [28]:
# import wandb

# wandb.login(key="1b4d95ae47c5d409e738db12aab18c42676f9367")

# run = wandb.init(
#     name = "setup_8", ### Wandb creates random run names if you skip this field, we recommend you give useful names
#     reinit=True, ### Allows reinitalizing runs when you re-run this cell
#     project="11711_hw4_NNCLR", ### Project should be created in your wandb account 
#     config=config ### Wandb Config for your run
# )

## Training

In [29]:
def train(epoch):
    batch_bar = tqdm(total=len(train_dataloader), dynamic_ncols=True, position=0, leave=False, desc='Train')
    total_loss = 0
    for i, x in enumerate(train_dataloader):
        model.train()
        optimizer.zero_grad()
        # projection1, projection2, prediction1, prediction2 = model(x, None)
        projection1, projection2 = model(x, None)
        if epoch == 1 and i <= 100:
            loss = criterion(projection1, projection2)
        else:
            neighbor1 = model.find_nearest_neighbors(projection1)
            neighbor2 = model.find_nearest_neighbors(projection2)

            loss = criterion(neighbor1, projection2) / 2 + criterion(neighbor2, projection1) / 2
        total_loss += loss.item()
        loss.backward()

        model.update_queue(projection1)
        model.update_queue(projection2)
        optimizer.step()

        batch_bar.set_postfix(
            epoch="{:d}".format(epoch),
            loss="{:.04f}".format(loss.item()))
        batch_bar.update()

        if i % 100 == 0 and i != 0:
            epc, esc, epd, esd = val()
            print("Epoch number: ", epoch)
            print("Pearson Cosine:", epc)
            print("Spearman Cosine:", esc)
            print("Pearson Dot:", epd)
            print("Spearman Dot:", esd)
            print()
    return total_loss / len(train_dataloader)

def val():
    model.eval()
    batch_bar = tqdm(total=len(val_dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val')
    total_loss = 0
    total_eval_pearson_cosine = 0
    total_eval_spearman_cosine = 0
    total_eval_pearson_dot = 0
    total_eval_spearman_dot = 0
    for i, (x1, x2, labels) in enumerate(val_dataloader):
        with torch.no_grad():
            embedding1, embedding2 = model(x1, x2)

        embedding1 = embedding1.detach().cpu().numpy()
        embedding2 = embedding2.detach().cpu().numpy()

        cosine_scores = [1-i for i in paired_cosine_distances(embedding1, embedding2)]
        dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embedding1, embedding2)]

        eval_pearson_cosine, _ = pearsonr(np.array(labels), cosine_scores)
        eval_spearman_cosine, _ = spearmanr(np.array(labels), cosine_scores)
        total_eval_pearson_cosine += eval_pearson_cosine
        total_eval_spearman_cosine += eval_spearman_cosine

        eval_pearson_dot, _ = pearsonr(np.array(labels), dot_products)
        eval_spearman_dot, _ = spearmanr(np.array(labels), dot_products)
        total_eval_pearson_dot += eval_pearson_dot
        total_eval_spearman_dot += eval_spearman_dot

        batch_bar.set_postfix(
            eval_pearson_cosine="{:2f}".format(eval_pearson_cosine),
            eval_spearman_cosine="{:2f}".format(eval_spearman_cosine),
            eval_pearson_dot="{:2f}".format(eval_pearson_dot),
            eval_spearman_dot="{:2f}".format(eval_spearman_dot)
        )
        batch_bar.update()
    return total_eval_pearson_cosine / len(val_dataloader), total_eval_spearman_cosine / len(val_dataloader), total_eval_pearson_dot / len(val_dataloader), total_eval_spearman_dot / len(val_dataloader)

In [None]:
num_epochs = config['total_epochs']
# scaler = scaler = torch.cuda.amp.GradScaler()
# scheduler = ReduceLROnPlateau(optimizer, patience = 5, factor = 0.5, threshold = 0.05, verbose=True)
for i in range(1, num_epochs + 1):
    train_loss = train(i)

    # wandb.log({"train_loss": train_loss})
    # wandb.save("checkpoint" + str(i) + ".pth")

    print("Train loss is ", train_loss)

# run.finish()



Epoch number:  1
Pearson Cosine: 0.7248180430962432
Spearman Cosine: 0.723587082060752
Pearson Dot: 0.6807508982086413
Spearman Dot: 0.6757662367191649





Epoch number:  1
Pearson Cosine: 0.6961125592016473
Spearman Cosine: 0.6865922784314225
Pearson Dot: 0.6739447808345315
Spearman Dot: 0.6655508416760413





Epoch number:  1
Pearson Cosine: 0.7021303508300724
Spearman Cosine: 0.6890953878864094
Pearson Dot: 0.6911477059627862
Spearman Dot: 0.6769499312161625





Epoch number:  1
Pearson Cosine: 0.6852182515091464
Spearman Cosine: 0.6731375747715064
Pearson Dot: 0.6778032311987471
Spearman Dot: 0.6644344058298682





Epoch number:  1
Pearson Cosine: 0.6866441574219833
Spearman Cosine: 0.6689547359832153
Pearson Dot: 0.6801351393679508
Spearman Dot: 0.6651205735362945



Train:  58%|█████▊    | 541/930 [19:17<13:10,  2.03s/it, epoch=1, loss=0.3239]

In [51]:
import gc
torch.cuda.empty_cache()
gc.collect()

168