In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import torch
import os
import transformers as ppb
from transformers import RobertaModel, AutoConfig
import warnings
from sentence_transformers import SentenceTransformer
from torch.utils.tensorboard import SummaryWriter
warnings.filterwarnings('ignore')
from xclib.data.data_utils import read_sparse_file
import copy
from transformers import get_scheduler,AdamW
from transformers import AutoTokenizer
import nlpaug.augmenter.word as naw
import nlpaug
import nltk
import xclib.evaluation.xc_metrics as xc_metrics
import xclib.data.data_utils as data_utils
import scipy.sparse as sp
from ngame.nns import exact_search
# nltk.download('averaged_perceptron_tagger')
# nltk.download('wordnet')

ModuleNotFoundError: No module named 'ngame'

In [2]:
datadir='/ecstorage/bert-opt/datasets'
# logdir = '/home/t-pbansal/logs'
# datadir='/home/t-pbansal/datasets'
datadir='%s/LF-AmazonTitles-131K'%datadir
tst_X_Y = read_sparse_file('%s/tst_X_Y.txt'%datadir).


In [5]:
np.array(tst_X_Y[:,1])

array(<134835x1 sparse matrix of type '<class 'numpy.float32'>'
	with 1 stored elements in Compressed Sparse Row format>, dtype=object)

In [2]:
class BertDataset(torch.utils.data.Dataset):

    def __init__(self, datadir,mapping, device,split):
        self.lbl_size = 131073
        self.datadir = datadir
        self.device=device
        self.size = 294805 if split =='trn' else 134835
        if split == 'trn':
            self.point_text_files = [x.strip() for x in open('%s/raw/trn_X.title.txt'%self.datadir).readlines()]
        else :
            self.point_text_files = [x.strip() for x in open('%s/raw/tst_X.title.txt'%self.datadir).readlines()]
        self.label_text_files = [x.strip() for x in open('%s/raw/Y.title.txt'%self.datadir).readlines()]
        self.mat_mapping = mapping
        self.mapping = mapping.nonzero()
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        self.maxsize=32
        self.aug = naw.SynonymAug(aug_src='wordnet',aug_max=5)
        # aug = naw.WordEmbsAug(model_type='word2vec',model_path = './wiki-news-300d-1M.vec')
        temp = np.fromfile('%s/tst_filter_labels.txt'%self.datadir, sep=' ').astype(int)
        temp = temp.reshape(-1, 2).T
        self.tst_filter_mat = sp.coo_matrix((np.ones(temp.shape[1]), (temp[0], temp[1])), tst_X_Y.shape).tocsr()

    def __getitem__(self,index,augment=True):
        point_data,label_data = self.convert_joint(self.point_text_files[self.mapping[0][index]],self.label_text_files[self.mapping[1][index]],augment=augment)
#         point_data = self.convert(self.point_text_files[self.mapping[0][index]],augment=augment)
#         label_data = self.convert(self.label_text_files[self.mapping[1][index]],augment=augment)
        return (torch.Tensor(point_data['input_ids'][0]).to(device),
               torch.Tensor(label_data['input_ids'][0]).to(device),
               torch.Tensor(point_data['attention_mask'][0]).to(device),
               torch.Tensor(label_data['attention_mask'][0]).to(device))
    
    def convert_joint(self,textp,textl,augment=True):
        if augment : 
            textp=self.aug.augment(textp,n=1)
            textl=self.aug.augment(textl,n=1)

        combined = textp.split(' ')+textl.split(' ')
        split_point = int(np.random.uniform(1,len(combined)))
        textp = ' '.join(combined[:split_point])
        textl = ' '.join(combined[split_point:])
        
        return (self.tokenizer(textp,add_special_tokens = True,
                         truncation=True,return_tensors = 'np',
                         return_attention_mask = True,padding = 'max_length',max_length=self.maxsize),
                self.tokenizer(textl,add_special_tokens = True,
                         truncation=True,return_tensors = 'np',
                         return_attention_mask = True,padding = 'max_length',max_length=self.maxsize))
    
    def convert(self,text,augment=True):
        if augment : 
            text=self.aug.augment(text,n=1)
        return self.tokenizer(text,add_special_tokens = True,
                         truncation=True,return_tensors = 'np',
                         return_attention_mask = True,padding = 'max_length',max_length=self.maxsize)
    
    def get_embeds(self,model,batch_size,device):
        labels = self.convert(self.label_text_files,augment=False)
        points = self.convert(self.point_text_files,augment=False)
        num_labels = labels['input_ids'].shape[0]
        num_points = points['input_ids'].shape[0]
        with torch.no_grad():
            label_embeds = []
            for i in range(0,num_labels,batch_size):
                label_embeds.append(model.get_embed({'input_ids':torch.LongTensor(labels['input_ids'][i:i+batch_size]).to(device),
                                          'attention_mask':torch.LongTensor(labels['attention_mask'][i:i+batch_size]).to(device)}).cpu())
            label_embeds = torch.cat(label_embeds,dim=0)

            point_embeds = []
            for i in range(0,num_points,batch_size):
                point_embeds.append(model.get_embed({'input_ids':torch.LongTensor(points['input_ids'][i:i+batch_size]).to(device),
                                          'attention_mask':torch.LongTensor(points['attention_mask'][i:i+batch_size]).to(device)}).cpu())
            point_embeds = torch.cat(point_embeds,dim=0)
        return point_embeds,label_embeds

    def __len__(self):
        return len(self.mapping[0])

def collate_fn(batch):
    point_ids,label_ids,point_masks,label_masks = zip(*batch)
    return {'input_ids':torch.stack(point_ids,dim=0).long(),'attention_mask':torch.stack(point_masks,dim=0).long()},{'input_ids':torch.stack(label_ids,dim=0).long(),'attention_mask':torch.stack(label_masks,dim=0).long()}


class PrecEvaluator():
    def __init__(self, dataset, device,batch_size):
        self.K = 5
        self.metric = "P"
        self.dataset = dataset
        self.device = device
        self.batch_size = batch_size
        self.filter_mat = dataset.tst_filter_mat
        self.best_score = -9999999

    def __call__(self,model):
        xembs,yembs = self.dataset.get_embeds(model,self.batch_size,self.device)
        torch.cuda.empty_cache()
        es = exact_search({'data': yembs.cpu().numpy(), 'query': xembs.cpu().numpy(), 'K': 100, 'device': self.device})
        score_mat = es.getnns_gpu()
        if self.filter_mat is not None:
            self._filter(score_mat)
        res = self.printacc(score_mat, X_Y=self.dataset.mat_mapping, K=self.K)
        recall = xc_metrics.recall(score_mat, self.dataset.mat_mapping, k=100)*100
        print(f'Recall@100: {"%.2f"%recall[99]}')        
        score = res[str(self.K)][self.metric]
        return score
    
    def _filter(self,score_mat):
        temp = self.filter_mat.tocoo()
        score_mat[temp.row, temp.col] = 0
        del temp
        score_mat = score_mat.tocsr()
        score_mat.eliminate_zeros()
        return score_mat

            
    def printacc(self,score_mat, K = 5, X_Y = None, disp = True):
        if X_Y is None: X_Y = tst_X_Y

        acc = xc_metrics.Metrics(X_Y.tocsr().astype(np.bool), None)
        metrics = np.array(acc.eval(score_mat, K))*100
        df = pd.DataFrame(metrics)

        df.index = ['P', 'nDCG']
        
        df.columns = [str(i+1) for i in range(K)]
        if disp: display(df.round(2))
        return df

In [3]:
class BERTModel(torch.nn.Module):
    def __init__(self,gamma):
        super().__init__()
        self.encoder = SentenceTransformer('msmarco-distilbert-base-v3')
        self.rep_dim = self.encoder.get_sentence_embedding_dimension()
        self.hidden_dim = 1024
        self.target_encoder = copy.deepcopy(self.encoder).requires_grad_(False)
        self.predictor = torch.nn.Sequential(torch.nn.Linear(self.rep_dim,self.hidden_dim),
                                             torch.nn.ReLU(),torch.nn.Linear(self.hidden_dim,self.rep_dim))
        self.gamma = gamma
        
    def forward(self,x,y):
        flip = np.random.binomial(1,0.5)
        if (flip == 0):
            x_embeds = torch.nn.functional.normalize(self.predictor(self.encoder(x)['sentence_embedding']),dim=-1)
            y_embeds = torch.nn.functional.normalize(self.target_encoder(y)['sentence_embedding'],dim=-1).clone().detach()
        else:
            y_embeds = torch.nn.functional.normalize(self.predictor(self.encoder(y)['sentence_embedding']),dim=-1)
            x_embeds = torch.nn.functional.normalize(self.target_encoder(x)['sentence_embedding'],dim=-1).clone().detach()
        scores = x_embeds@y_embeds.T
        positives = torch.diag(scores).sum()
        negatives = scores.sum()-positives
        positives /= scores.shape[0]
        negatives /= ((scores.shape[0]-1)*(scores.shape[0]))
        self.update_target()
        return positives,negatives
    
    def get_embed(self,x):
        return torch.nn.functional.normalize(self.encoder(x)['sentence_embedding'],dim=-1)
        
    
    def update_target(self):
        target_dict = self.target_encoder.state_dict()
        online_dict = self.encoder.state_dict()
        for key in online_dict.keys():
            target_dict[key] = target_dict[key]*self.gamma + online_dict[key]*(1-self.gamma)
        self.target_encoder.load_state_dict(target_dict)

In [4]:
expname = 'WordNetAugCut'
datadir='/home/t-pbansal/datasets/LF-AmazonTitles-131K'
trn_X_Y = read_sparse_file('%s/trn_X_Y.txt'%datadir)
tst_X_Y = read_sparse_file('%s/tst_X_Y.txt'%datadir)
device = torch.device('cuda:0')
batch_size = 1024
max_epoch = 500
gamma = 0.995
iteration = 0

model = BERTModel(gamma = gamma).to(device)
results_dir = '/home/t-pbansal/logs/tensorboard/%s_%fgamma_%dbs'%(expname,gamma,batch_size)
writer = SummaryWriter(log_dir=results_dir)


train_dataloader = torch.utils.data.DataLoader(BertDataset(datadir=datadir,mapping=trn_X_Y,split='trn',device=device),batch_size=batch_size,drop_last=True,shuffle=True,collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(BertDataset(datadir=datadir,mapping=tst_X_Y,split='tst',device=device),batch_size=batch_size,drop_last=True,shuffle=True,collate_fn=collate_fn)

optimizer = AdamW(model.parameters(), lr=1e-4,weight_decay=1e-6)
num_training_steps = max_epoch * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=num_training_steps
)

evaluator = PrecEvaluator(test_dataloader.dataset,device,100)
best_score = float('inf')


In [None]:
for epoch in range(max_epoch):
    print ("starting epoch %d at iteration %d"%(epoch,iteration))
    for _,(x,y) in enumerate(train_dataloader):
        positives,negatives = model(x,y)
        loss = -positives
        exit()
        if (iteration %100  == 0):
            writer.add_scalar('train/positives',positives,iteration)
            writer.add_scalar('train/negatives',negatives,iteration)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        iteration += 1
    if(epoch %1 == 0):
        positives,negatives,count = 0,0,0
        with torch.no_grad():
            for _,(x,y) in enumerate(test_dataloader):
                positives_,negatives_ = model(x,y)
                positives += positives_
                negatives += negatives_
                count += 1
        score = evaluator.__call__(model)
        writer.add_scalar('val/P@5',score,iteration)
        writer.add_scalar('val/positives',positives/count,iteration)
        writer.add_scalar('val/negatives',negatives/count,iteration)
        if (score < best_score):
            best_score = score
            torch.save(model.state_dict(),os.path.join(results_dir,'checkpoint.pt'))
            print ('saved checkpoint for epoch %d'%epoch)

starting epoch 0 at iteration 0


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.60s, 0.0786 ms/pt


Unnamed: 0,1,2,3,4,5
P,7.39,5.48,4.36,3.58,3.01
nDCG,7.39,6.78,6.73,6.85,6.96


Recall@100: 12.45
saved checkpoint for epoch 0
starting epoch 1 at iteration 659


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.48s, 0.0778 ms/pt


Unnamed: 0,1,2,3,4,5
P,5.71,4.46,3.42,2.78,2.42
nDCG,5.71,5.43,5.29,5.36,5.52


Recall@100: 10.54
saved checkpoint for epoch 1
starting epoch 2 at iteration 1318


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.37s, 0.0769 ms/pt


Unnamed: 0,1,2,3,4,5
P,5.14,4.01,3.3,2.71,2.33
nDCG,5.14,4.82,4.91,5.01,5.14


Recall@100: 10.65
saved checkpoint for epoch 2
starting epoch 3 at iteration 1977


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.18s, 0.0755 ms/pt


Unnamed: 0,1,2,3,4,5
P,4.93,3.89,3.2,2.68,2.33
nDCG,4.93,4.64,4.69,4.83,4.99


Recall@100: 12.07
starting epoch 4 at iteration 2636


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.11s, 0.0750 ms/pt


Unnamed: 0,1,2,3,4,5
P,5.57,4.4,3.61,3.04,2.66
nDCG,5.57,5.26,5.31,5.48,5.67


Recall@100: 14.33
starting epoch 5 at iteration 3295


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 10.02s, 0.0743 ms/pt


Unnamed: 0,1,2,3,4,5
P,6.5,5.19,4.23,3.59,3.15
nDCG,6.5,6.21,6.25,6.47,6.72


Recall@100: 17.15
starting epoch 6 at iteration 3954


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.95s, 0.0738 ms/pt


Unnamed: 0,1,2,3,4,5
P,7.14,5.63,4.73,4.03,3.55
nDCG,7.14,6.79,6.95,7.2,7.48


Recall@100: 20.03
starting epoch 7 at iteration 4613


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.93s, 0.0736 ms/pt


Unnamed: 0,1,2,3,4,5
P,8.47,6.69,5.58,4.78,4.21
nDCG,8.47,8.07,8.25,8.59,8.93


Recall@100: 23.17
starting epoch 8 at iteration 5272


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.86s, 0.0732 ms/pt


Unnamed: 0,1,2,3,4,5
P,9.35,7.57,6.24,5.37,4.72
nDCG,9.35,9.11,9.24,9.63,10.0


Recall@100: 25.96
starting epoch 9 at iteration 5931


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.83s, 0.0729 ms/pt


Unnamed: 0,1,2,3,4,5
P,10.54,8.4,7.02,6.0,5.28
nDCG,10.54,10.21,10.47,10.87,11.28


Recall@100: 28.57
starting epoch 10 at iteration 6590


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.83s, 0.0729 ms/pt


Unnamed: 0,1,2,3,4,5
P,11.37,8.89,7.42,6.35,5.58
nDCG,11.37,10.87,11.12,11.54,11.97


Recall@100: 30.48
starting epoch 11 at iteration 7249


  0%|          | 0/264 [00:00<?, ?it/s]

Total time, time per point : 9.78s, 0.0725 ms/pt


Unnamed: 0,1,2,3,4,5
P,11.91,9.45,7.87,6.74,5.94
nDCG,11.91,11.5,11.76,12.21,12.68


Recall@100: 32.06
starting epoch 12 at iteration 7908
