In [None]:
import math
import numpy as np
import os
import pandas as pd
import random
import time
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from barbar import Bar
from datetime import datetime
from tensorboardX import SummaryWriter
from transformers import AutoTokenizer
from transformers import AutoModel
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler

In [None]:
# TensorBoard log directory taken from the environment; None when TENSORBOARD_DIR is unset.
log_dir = os.environ.get('TENSORBOARD_DIR')

In [None]:
# Run configuration. Everything except 'data_dir' is logged as TensorBoard
# hparams (via filter_parameters).
# NOTE(review): the run name says "no training", but SiameseNet as written does
# not freeze DistilBERT, so the label may be stale — confirm.
parameters = dict(
    name='DistilBERT no training',    # run label; also used as the SummaryWriter sub-directory
    n_epochs=5,
    batch_size=6,
    val_split=0.15,                   # fraction of triplets held out for validation
    device='cuda',
    data_dir='/jupyter/data/news/',   # keep the trailing '/': 'news.tsv' is appended by string concatenation
    learning_rate=5e-5,
    margin=3.0,                       # TripletLoss margin
    class_size=200,                   # max items kept per category when building triplets
    item_triplets=50,                 # triplets generated per anchor item
)

def filter_parameters(parameters, exclude=('data_dir',), now=None):
    """Return a copy of ``parameters`` suitable for hparams logging.

    Keys listed in ``exclude`` are dropped if present (missing keys are
    ignored instead of raising ``KeyError``). A ``'start_time'`` entry is
    then added, formatted as ``"%d-%m-%Y %H-%M"``.

    Args:
        parameters: hyper-parameter dict; it is not modified.
        exclude: keys to leave out (default: the machine-specific ``data_dir``).
        now: ``datetime`` to record; defaults to ``datetime.now()``.

    Returns:
        A new dict in the original key order, plus ``'start_time'``.
    """
    res = {key: value for key, value in parameters.items() if key not in exclude}
    if now is None:
        now = datetime.now()
    res['start_time'] = now.strftime("%d-%m-%Y %H-%M")
    return res

In [None]:
def extract_tokens(row, tkn: 'transformers.AutoTokenizer', maxlen: int = 40):
    """Tokenize a news row's title into one stacked ids/mask tensor.

    Args:
        row: mapping-like row (e.g. a pandas Series) with a ``'title'`` entry.
        tkn: tokenizer exposing ``encode_plus``.
        maxlen: pad/truncate length (default 40, the previously hard-coded value).

    Returns:
        ``input_ids`` and ``attention_mask`` concatenated along dim 0. Assuming
        ``encode_plus(..., return_tensors='pt')`` returns ``(1, maxlen)`` tensors,
        the result is ``(2, maxlen)``: row 0 holds the ids and row 1 the mask.
    """
    title_tokens = tkn.encode_plus(
        row['title'],
        add_special_tokens=True,
        truncation=True,
        max_length=maxlen,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    return torch.cat((title_tokens['input_ids'], title_tokens['attention_mask']), dim=0)

In [None]:
# Variant of extract_tokens that also encodes the abstract ('abs' column).
def extract_tokens_abstracts(row, tkn: 'transformers.AutoTokenizer', maxlen: int = 40):
    """Tokenize a news row's title and abstract into one stacked tensor.

    Args:
        row: mapping-like row (e.g. a pandas Series) with ``'title'`` and ``'abs'``.
        tkn: tokenizer exposing ``encode_plus``.
        maxlen: pad/truncate length for both texts (default 40).

    Returns:
        Title ids, title mask, abstract ids and abstract mask concatenated
        along dim 0. Assuming ``encode_plus`` returns ``(1, maxlen)`` tensors,
        the result is ``(4, maxlen)``. When the abstract is not a string (e.g.
        the NaN that pandas produces for an empty TSV cell, or None), the
        title encoding is reused for the abstract rows.
    """
    def encode(text):
        # ids and mask stacked along dim 0 for a single text
        enc = tkn.encode_plus(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=maxlen,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return torch.cat((enc['input_ids'], enc['attention_mask']), dim=0)

    title_enc = encode(row['title'])
    abstract = row['abs']
    # isinstance(str) covers NaN (a float) as well as None
    abs_enc = encode(abstract) if isinstance(abstract, str) else title_enc
    return torch.cat((title_enc, abs_enc))

In [None]:
def extract_triplets(news_df: pd.DataFrame, tkn: 'transformers.AutoTokenizer',
    limit: int, conn_num: int, token_fn=None, seed: int = 42):
    """Build (anchor, negative, positive) triplets from a news frame.

    Rows are grouped by their ``'cats'`` column. At most ``limit`` rows per
    category are converted with ``token_fn(row, tkn)``. Then, for every
    anchor item:

    * if its category has at most ``conn_num`` other items, one triplet is
      emitted per other item of the category (each used once as positive);
    * otherwise ``conn_num`` triplets with distinct, randomly drawn
      (positive, negative) combinations are emitted.

    Negatives are always drawn from a different category. A category that has
    no other category to contrast with yields no triplets, instead of looping
    forever.

    Args:
        news_df: frame with a ``'cats'`` column plus whatever ``token_fn`` reads.
        tkn: tokenizer forwarded to ``token_fn``.
        limit: maximum number of items kept per category.
        conn_num: maximum number of triplets generated per anchor.
        token_fn: callable ``(row, tkn) -> item``; defaults to ``extract_tokens``.
        seed: seed for a private ``random.Random``, so results are reproducible
            without reseeding the global ``random`` module.

    Returns:
        List of ``(anchor, negative, positive)`` tuples of ``token_fn`` outputs.
    """
    if token_fn is None:
        token_fn = extract_tokens
    rng = random.Random(seed)

    # Group converted rows by category, keeping at most `limit` per category.
    cats_dict = {}
    for _, row in news_df.iterrows():
        bucket = cats_dict.setdefault(row['cats'], [])
        if len(bucket) < limit:
            bucket.append(token_fn(row, tkn))
    categories = [cat for cat, items in cats_dict.items() if items]

    res = []
    hits = set()  # combinations already emitted; a set keeps lookups O(1)

    def add_triplet(class_key, n_idx, pos_idx, other_cats):
        """Draw a negative from `other_cats`; emit the triplet unless already seen."""
        neg_cat = rng.choice(other_cats)
        neg_items = cats_dict[neg_cat]
        neg_idx = rng.randrange(len(neg_items))
        comb = ((n_idx, class_key), (pos_idx, class_key), (neg_idx, neg_cat))
        if comb in hits:
            return False
        hits.add(comb)
        anchor_items = cats_dict[class_key]
        res.append((anchor_items[n_idx], neg_items[neg_idx], anchor_items[pos_idx]))
        return True

    for class_key in categories:
        cat_items = cats_dict[class_key]
        # Draw negatives only from *other* categories, so no positive is
        # silently skipped and the sampling loop cannot spin forever.
        other_cats = [cat for cat in categories if cat != class_key]
        if not other_cats:
            continue
        n_items = len(cat_items)
        for n_idx in range(n_items):
            if n_items - 1 <= conn_num:
                # Few enough positives: use every other item of the category once.
                for pos_idx in range(n_items):
                    if pos_idx != n_idx:
                        add_triplet(class_key, n_idx, pos_idx, other_cats)
            else:
                # More than conn_num positives exist, so conn_num distinct combos
                # are always reachable and this loop terminates.
                added = 0
                while added < conn_num:
                    pos_idx = rng.randrange(n_items - 1)
                    if pos_idx >= n_idx:
                        pos_idx += 1  # skip the anchor itself
                    if add_triplet(class_key, n_idx, pos_idx, other_cats):
                        added += 1
    return res

In [None]:
class NewsDataset(Dataset):
    """Map-style dataset of (anchor, negative, positive) triplets.

    Reads a headerless tab-separated news file, loads a tokenizer from
    ``tokenizer_path`` and builds the triplets with ``extract_triplets``.
    """

    # Column names for the headerless TSV (assigned positionally by read_csv).
    COLUMNS = ['id', 'cats', 'subcat', 'title', 'abs', 'link', 'title-ent', 'abs-ent']

    def __init__(self, path, class_size, item_triplets,
                 tokenizer_path='/jupyter/models/distilbert/'):
        """
        Args:
            path: path of the tab-separated news file.
            class_size: max items kept per ``'cats'`` value (``limit`` of extract_triplets).
            item_triplets: triplets per anchor item (``conn_num`` of extract_triplets).
            tokenizer_path: location passed to ``AutoTokenizer.from_pretrained``
                (defaults to the previously hard-coded path).
        """
        news_df = pd.read_csv(path, sep='\t', header=None, names=self.COLUMNS)
        tkn = AutoTokenizer.from_pretrained(tokenizer_path)
        # Called `pairs` for backward compatibility; every entry is a triplet.
        self.pairs = extract_triplets(news_df, tkn, class_size, item_triplets)
        self.df = news_df

    def __len__(self):
        """Number of triplets."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``idx``-th (anchor, negative, positive) triplet."""
        return self.pairs[idx]

In [None]:
# Build the triplet dataset from the tab-separated news file and record its size
# alongside the other hyper-parameters.
pairs_ds = NewsDataset(
    path=parameters['data_dir'] + 'news.tsv',
    class_size=parameters['class_size'],
    item_triplets=parameters['item_triplets'],
)
parameters['ds_size'] = len(pairs_ds)

In [None]:
# Train/validation split over triplet indices. The permutation is seeded so the
# same triplets land in the validation set on every run (the previous unseeded
# np.random.shuffle made the split irreproducible).
SPLIT_SEED = 42
ds_size = len(pairs_ds)
indices = np.random.default_rng(SPLIT_SEED).permutation(ds_size).tolist()
split = int(np.floor(parameters['val_split'] * ds_size))
train_indices, val_indices = indices[split:], indices[:split]
# Trim each split to a multiple of batch_size so every batch is full.
train_indices = train_indices[:len(train_indices) - len(train_indices) % parameters['batch_size']]
val_indices = val_indices[:len(val_indices) - len(val_indices) % parameters['batch_size']]
# Fail here with a clear message rather than with a division by zero mid-training.
assert train_indices and val_indices, 'dataset too small for the requested val_split/batch_size'

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(pairs_ds, batch_size=parameters['batch_size'], sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(pairs_ds, batch_size=parameters['batch_size'], sampler=val_sampler)

In [None]:
class SiameseNet(nn.Module):
    """Title encoder: a pretrained transformer followed by an MLP projection head.

    Args:
        freeze_bert: if True, the transformer's parameters get
            ``requires_grad = False`` so only the head is trained. This
            replaces the previously commented-out freezing code; the default
            False keeps the old behavior.
        model_path: location passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, freeze_bert=False, model_path='/jupyter/models/distilbert'):
        super().__init__()
        # With torchscript=True the output is indexed positionally (see forward).
        self.tf_layer = AutoModel.from_pretrained(model_path, torchscript=True)
        if freeze_bert:
            for p in self.tf_layer.parameters():
                p.requires_grad = False
        # Assumes a 768-wide hidden state (DistilBERT-base) — confirm for other models.
        self.bert_proc = nn.Sequential(
            nn.Linear(768, 512),
            nn.GELU(),
            nn.Linear(512, 256),
        )

    def forward(self, tokens):
        """Embed a batch of stacked (ids, mask) token tensors.

        ``tokens`` is expected to be ``(batch, 2, seq_len)``, the layout
        produced by ``extract_tokens`` after batching: ``tokens[:, 0]`` holds
        the input ids and ``tokens[:, 1]`` the attention mask. Returns a
        ``(batch, 256)`` embedding.
        """
        bert_title = self.tf_layer(tokens[:, 0], tokens[:, 1])
        # Hidden state at sequence position 0 (presumably the [CLS] token,
        # since add_special_tokens=True) serves as the sentence vector.
        out = self.bert_proc(bert_title[0][:, 0, :])
        return out

In [None]:
# Variant of SiameseNet that encodes both the title and the abstract.
class SiameseNetAbstracts(nn.Module):
    """Title+abstract encoder: a shared transformer plus an MLP over both vectors.

    Args:
        freeze_bert: if True, the transformer's parameters get
            ``requires_grad = False`` so only the head is trained.
        model_path: location passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, freeze_bert=False, model_path='/jupyter/models/distilbert'):
        # Was super(SiameseNet, self): this class does not derive from
        # SiameseNet, so instantiation raised TypeError.
        super().__init__()
        self.tf_layer = AutoModel.from_pretrained(model_path, torchscript=True)
        if freeze_bert:
            for p in self.tf_layer.parameters():
                p.requires_grad = False
        # 1536 = two concatenated 768-wide vectors (assumes DistilBERT-base).
        self.bert_proc = nn.Sequential(
            nn.Linear(1536, 1024),
            nn.GELU(),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Linear(512, 256),
        )

    def forward(self, tokens):
        """Embed a batch of stacked title/abstract token tensors.

        ``tokens`` is expected to be ``(batch, 4, seq_len)``, the layout of
        ``extract_tokens_abstracts`` after batching: title ids, title mask,
        abstract ids, abstract mask. Returns a ``(batch, 256)`` embedding.
        """
        bert_title = self.tf_layer(tokens[:, 0], tokens[:, 1])
        bert_abs = self.tf_layer(tokens[:, 2], tokens[:, 3])
        concat = torch.cat((bert_title[0][:, 0, :], bert_abs[0][:, 0, :]), dim=-1)
        # Previously the head was re-applied to the title vector alone, which
        # discarded `concat` and fed 768 features into a 1536-wide layer.
        return self.bert_proc(concat)

In [None]:
class NewsSiamese():
    """Train and validate a triplet-loss encoder, logging losses to a writer.

    The loss, optimizer and writer (anything with ``add_scalar`` and ``flush``,
    e.g. a tensorboardX ``SummaryWriter``) are attached via ``setup``.
    """

    def __init__(self, device, report_step, model=None):
        """
        Args:
            device: torch device that the model and every batch are moved to.
            report_step: log the running train loss every ``report_step`` batches.
            model: encoder module to train; defaults to a new ``SiameseNet()``.
        """
        if model is None:
            model = SiameseNet()
        self.model = model.to(device)
        self.device = device
        self.report_step = report_step
        # 1-based global steps for the running, epoch-train and validation series
        self.run_counter = 1
        self.train_counter = 1
        self.val_counter = 1

    def setup(self, crit, opt, writer):
        """Attach the loss ``crit(anchor_out, negative_out, positive_out)``,
        the optimizer and the scalar writer."""
        self.crit = crit
        self.opt = opt
        self.writer = writer

    def _batch_loss(self, triplets):
        """Move one (anchor, negative, positive) batch to the device and return its loss."""
        anchor, negative, positive = triplets
        anchor = anchor.to(self.device)
        negative = negative.to(self.device)
        positive = positive.to(self.device)
        anchor_out = self.model(anchor)
        negative_out = self.model(negative)
        positive_out = self.model(positive)
        return self.crit(anchor_out, negative_out, positive_out)

    def train_step(self, loader):
        """Run one training epoch over ``loader``.

        Logs "Train/Running loss" every ``report_step`` batches and the epoch
        mean as "Train/Total loss". Returns that mean, which is directly
        comparable with ``val_step``'s result (previously the sum was
        returned). An empty loader logs nothing and returns NaN.
        """
        self.model.train()
        running_loss = 0.0
        counter = 0
        total_loss = 0.0
        total_counter = 0
        for idx, triplets in enumerate(loader):
            self.opt.zero_grad()
            loss = self._batch_loss(triplets)
            loss.backward()
            self.opt.step()
            loss_value = loss.item()
            running_loss += loss_value
            counter += 1
            total_loss += loss_value
            total_counter += 1
            if idx % self.report_step == 0:
                self.writer.add_scalar("Train/Running loss", running_loss/counter, self.run_counter)
                self.writer.flush()
                self.run_counter += 1
                running_loss = 0.0
                counter = 0
        if total_counter == 0:
            return float('nan')
        avg_loss = total_loss/total_counter
        self.writer.add_scalar("Train/Total loss", avg_loss, self.train_counter)
        self.writer.flush()
        self.train_counter += 1
        return avg_loss

    def val_step(self, loader):
        """Evaluate the mean loss over ``loader`` without gradients.

        Logs it as "Validation/Loss" and returns it. An empty loader logs
        nothing and returns NaN instead of raising ZeroDivisionError.
        """
        self.model.eval()
        total_loss = 0.0
        counter = 0
        with torch.no_grad():
            for triplets in loader:
                total_loss += self._batch_loss(triplets).item()
                counter += 1
        if counter == 0:
            return float('nan')
        avg_loss = total_loss/counter
        self.writer.add_scalar("Validation/Loss", avg_loss, self.val_counter)
        self.writer.flush()
        self.val_counter += 1
        return avg_loss

In [None]:
import torch.nn.functional as F

class TripletLoss(torch.nn.Module):
    """Hinge triplet loss on pairwise Euclidean distances.

    loss = mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0))

    Note the argument order of ``forward``: (anchor, negative, positive).
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, negative, positive):
        d_pos = F.pairwise_distance(anchor, positive, keepdim=True)
        d_neg = F.pairwise_distance(anchor, negative, keepdim=True)
        hinge = torch.clamp(d_pos - d_neg + self.margin, min=0.0)
        return hinge.mean()

In [None]:
# Trainer, loss, optimizer and TensorBoard writer for this run.
REPORT_STEP = 10  # batches between "Train/Running loss" points
dev = torch.device(parameters['device'])
net = NewsSiamese(dev, REPORT_STEP)
crit = TripletLoss(margin=parameters['margin'])
opt = optim.AdamW(net.model.parameters(), lr=parameters['learning_rate'])
# Honour $TENSORBOARD_DIR (read into `log_dir` earlier, previously unused) and
# fall back to the old hard-coded location when it is unset.
writer = SummaryWriter(os.path.join(log_dir or '/jupyter/runs', parameters['name']))
net.setup(crit, opt, writer)

In [None]:
# Snapshot the hparams once, so 'start_time' records when training started
# rather than when each epoch finished (it used to be recomputed every epoch).
run_hparams = filter_parameters(parameters)
for epoch in range(1, parameters['n_epochs'] + 1):
    print('\nEpoch {}'.format(epoch))
    print('Training...')
    train_loss = net.train_step(Bar(train_loader))
    print('\nValidating...')
    val_loss = net.val_step(Bar(validation_loader))
    net.writer.add_hparams(run_hparams,
        {'hparams/train_loss': train_loss, 'hparams/validation_loss': val_loss}, 'parameters', epoch)
    net.writer.flush()
    print('\n---------')