All run logs are available in the Weights & Biases (W&B) [project](https://wandb.ai/wosadeh/coursework).

# Common

In [None]:
# Install the extra packages this notebook imports.
# NOTE(review): the imports below use `torchtext.legacy`, which newer torchtext
# releases no longer ship — pin a compatible torchtext version if that import fails.
! pip install -q sentencepiece transformers wandb persist-queue parsel

In [None]:
import pandas as pd
import numpy as np
import torch
import wandb
from torch import nn
import torch.nn.functional as F
from torchtext.legacy import data as text_data
from transformers import AutoModel, AutoTokenizer, BertModel, DistilBertModel, BertForSequenceClassification, DistilBertForSequenceClassification
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from tqdm.auto import tqdm
from persistqueue import SQLiteQueue
from parsel import Selector

import pickle
import warnings
from os import path, remove, devnull
from shutil import copyfile, rmtree
from typing import Callable, Optional, List, Tuple, Collection, Union, MutableMapping
from copy import copy, deepcopy
from itertools import chain
import random
import sys

In [None]:
# Run on CUDA when a GPU is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Random seeds for the repeated trials: the trial loop runs one run_trial per seed.
SEEDS = [42, 10, 173, 164, 34]

In [None]:
class HuggingFaceField(text_data.Field):
    """Sequential torchtext field that tokenizes with a HuggingFace tokenizer.

    The pad / init special tokens come from the tokenizer. The closing token
    is the tokenizer's EOS token, or its SEP token when it has no EOS token.
    torchtext's own vocabulary is disabled: ids come from the tokenizer.
    """

    def __init__(self, tokenizer):
        closing_token = (
            tokenizer.sep_token if tokenizer.eos_token is None else tokenizer.eos_token
        )
        super().__init__(
            tokenize=tokenizer.tokenize,
            use_vocab=False,
            pad_token=tokenizer.pad_token,
            init_token=tokenizer.cls_token,
            eos_token=closing_token,
            batch_first=True,
        )
        self.tokenizer = tokenizer

    def numericalize(self, arr, device):
        """Convert a batch of token lists into a LongTensor of ids on *device*."""
        token_ids = list(map(self.tokenizer.convert_tokens_to_ids, arr))
        return torch.tensor(token_ids, dtype=torch.long, device=device)


class SeqPairField(text_data.Field):
    """Non-sequential field holding ``(first, second)`` text pairs.

    Each batch is encoded jointly by the tokenizer as sentence pairs. This
    yields one ``input_ids`` tensor padded to the longest pair, plus
    ``token_type_ids`` when the tokenizer emits them.
    """

    def __init__(self, tokenizer):
        super().__init__(sequential=False, use_vocab=False, batch_first=True)
        self.tokenizer = tokenizer

    def process(self, batch, device=None):
        """Encode a list of pairs; return ``(input_ids, token_type_ids or None)``."""
        firsts = [pair[0] for pair in batch]
        seconds = [pair[1] for pair in batch]
        encoded = self.tokenizer(
            firsts,
            seconds,
            add_special_tokens=True,
            padding='longest',
            return_tensors='pt',
            return_attention_mask=False,
        )
        input_ids = encoded['input_ids']
        type_ids = encoded.get('token_type_ids', None)
        if device is None:
            return (input_ids, type_ids)
        moved_types = None if type_ids is None else type_ids.to(device)
        return (input_ids.to(device), moved_types)


class PairField(text_data.Field):
    """Non-sequential field holding ``(first, second)`` text pairs.

    Unlike SeqPairField, the two sides are tokenized independently. Each side
    is padded to the longest sequence on that side of the batch, and no
    token type ids or attention masks are returned.
    """

    def __init__(self, tokenizer):
        super().__init__(sequential=False, use_vocab=False, batch_first=True)
        self.tokenizer = tokenizer

    def _encode_side(self, texts):
        """Tokenize one side of the pairs into a padded ``input_ids`` tensor."""
        encoded = self.tokenizer(
            texts,
            add_special_tokens=True,
            padding='longest',
            return_tensors='pt',
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        return encoded['input_ids']

    def process(self, batch, device=None):
        """Encode a list of pairs; return ``(first_input_ids, second_input_ids)``."""
        first_ids = self._encode_side([pair[0] for pair in batch])
        second_ids = self._encode_side([pair[1] for pair in batch])
        if device is None:
            return (first_ids, second_ids)
        return (first_ids.to(device), second_ids.to(device))


class LabelField(text_data.Field):
    """Non-sequential, vocabulary-free field for integer class labels."""

    def __init__(self, is_target: bool = True):
        super().__init__(
            sequential=False,
            use_vocab=False,
            tokenize=lambda value: value,
            batch_first=True,
            is_target=is_target,
        )

    def numericalize(self, arr, device):
        """Cast every label to ``int`` and pack them into a 1-D LongTensor on *device*."""
        return torch.tensor(list(map(int, arr)), dtype=torch.long, device=device)

In [None]:
class TransformerWrapper(nn.Module):
    """
    HuggingFace BERT model wrapper,
    which produces [CLS] embedding from several last transformer layers

    Args:
        model_name: model id / path passed to ``AutoModel.from_pretrained``.
        freeze: ``True`` freezes every transformer layer, an int ``k`` freezes
            ``layers[:k]``, and ``False`` / ``0`` freezes no transformer layer.
            NOTE(review): the embedding layer is frozen regardless of this
            argument — confirm that is intended.
        aggregate_n_last_hidden_layers: number (>= 1) of final hidden states
            whose [CLS] vectors are combined. With 1, a BertModel returns its
            ``pooler_output`` and other models the [CLS] vector of
            ``last_hidden_state``.
        aggregate_mode: 'mean' / 'sum' (output width = hidden size) or
            'cat' / 'concat' / 'concatenate' (width = n * hidden size).
        revision: optional model revision forwarded to ``from_pretrained``.

    Raises:
        TypeError: if layers must be frozen but the model exposes neither
            ``encoder.layer`` nor ``transformer.layer``.
    """

    # aggregate_mode spellings that concatenate the per-layer [CLS] vectors
    _CONCAT_MODES = ('cat', 'concat', 'concatenate')

    def __init__(self, model_name: str,
                 freeze: Union[bool, int] = False,
                 aggregate_n_last_hidden_layers: int = 1,
                 aggregate_mode: str = 'mean',
                 revision = None):

        assert aggregate_n_last_hidden_layers >= 1
        aggregate_mode = aggregate_mode.lower()
        assert aggregate_mode in ('mean', 'sum') + self._CONCAT_MODES
        super(TransformerWrapper, self).__init__()

        if revision is None:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            self.model = AutoModel.from_pretrained(model_name, revision=revision)
        self.hidden_size = self.model.config.hidden_size

        if aggregate_mode in self._CONCAT_MODES:
            self.output_size = aggregate_n_last_hidden_layers * self.hidden_size
        else:
            self.output_size = self.hidden_size

        # freeze=True means "all layers": 10**5 exceeds any real model depth
        if not freeze:
            n_frozen = 0
        elif isinstance(freeze, bool):
            n_frozen = 10**5
        else:
            n_frozen = freeze

        for param in self.model.embeddings.parameters():
            param.requires_grad = False

        # Only look the layer stack up when something must be frozen, so
        # freeze=False also works for architectures without a known layout.
        if n_frozen:
            for layer in self._transformer_layers(self.model)[:n_frozen]:
                for param in layer.parameters():
                    param.requires_grad = False

        self.n_agg = aggregate_n_last_hidden_layers
        self.agg_mode = aggregate_mode

    @staticmethod
    def _transformer_layers(model: nn.Module) -> nn.ModuleList:
        """Return the stack of transformer blocks of a HuggingFace encoder.

        BERT-style models keep it at ``model.encoder.layer`` and DistilBERT at
        ``model.transformer.layer``. Any other layout raises ``TypeError``
        instead of failing later with an unbound local variable.
        """
        for container_name in ('encoder', 'transformer'):
            layers = getattr(getattr(model, container_name, None), 'layer', None)
            if layers is not None:
                return layers
        raise TypeError(
            f'Cannot locate transformer layers of {type(model).__name__}: '
            'expected an `encoder.layer` or `transformer.layer` attribute'
        )

    def forward(self, x: torch.LongTensor,
                mask: Union[torch.BoolTensor, torch.Tensor, None] = None,
                token_type_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """Return the (aggregated) [CLS] embedding for token ids *x*.

        *mask* defaults to all ones (attend to every token) and is cast to
        float before being passed as ``attention_mask``.
        """
        if mask is None:
            mask = torch.ones_like(x, dtype=torch.float, device=x.device).detach()
        mask = mask.float()

        outputs = self.model(
            x,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            output_hidden_states=self.n_agg > 1,
            return_dict=True
            )
        if self.n_agg == 1:
            if isinstance(self.model, BertModel):
                # BertModel: use its pooler output rather than the raw [CLS] state
                return outputs['pooler_output']
            return outputs['last_hidden_state'][:, 0, :]

        # [CLS] vector (position 0) of each of the last n_agg hidden states
        hidden = outputs['hidden_states']
        first = len(hidden) - self.n_agg
        cls_hidden = [layer_out[:, 0, :] for layer_out in hidden[first:]]

        if self.agg_mode == 'mean':
            return torch.stack(cls_hidden, dim=2).mean(dim=2)
        if self.agg_mode == 'sum':
            return torch.stack(cls_hidden, dim=2).sum(dim=2)
        # concatenation modes (the only remaining values allowed by __init__)
        return torch.cat(cls_hidden, dim=1)

In [None]:
class TransformerCls(TransformerWrapper):
    """Classifier on top of the TransformerWrapper [CLS] embedding.

    Head: Linear(output_size, 256) -> LeakyReLU(0.01) -> Dropout(dropout) ->
    Linear(256, num_classes). Both linear weights are Xavier-normal
    initialised, the first with the LeakyReLU(0.01) gain.
    """

    def __init__(self, model_name: str,
                 num_classes: int,
                 dropout: float = 0.,
                 freeze: Union[bool, int] = False,
                 aggregate_n_last_hidden_layers: int = 1,
                 aggregate_mode: str = 'mean',
                 revision=None):
        super(TransformerCls, self).__init__(
            model_name, freeze, aggregate_n_last_hidden_layers, aggregate_mode, revision
        )
        # construction / init order matches the Sequential positions 0 and 3,
        # so seeded runs draw random numbers in the same sequence
        hidden_proj = nn.Linear(self.output_size, 256)
        out_proj = nn.Linear(256, num_classes)
        self.clf = nn.Sequential(
            hidden_proj,
            nn.LeakyReLU(0.01),
            nn.Dropout(dropout),
            out_proj,
        )
        torch.nn.init.xavier_normal_(
            hidden_proj.weight, nn.init.calculate_gain('leaky_relu', 0.01)
        )
        torch.nn.init.xavier_normal_(out_proj.weight)

    def forward(self, x: torch.LongTensor,
                mask: Union[torch.BoolTensor, torch.Tensor, None] = None,
                token_type_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """Return class logits for token ids *x*."""
        embedding = super(TransformerCls, self).forward(x, mask, token_type_ids=token_type_ids)
        return self.clf(embedding)

In [None]:
class SiameseClf(TransformerWrapper):
    """Siamese pair classifier sharing one encoder for both inputs.

    Both texts are encoded separately (no token type ids). Their embeddings
    are concatenated and fed to Linear(2 * output_size, 256) ->
    LeakyReLU(0.01) -> Dropout(dropout) -> Linear(256, num_classes).
    """

    def __init__(self, model_name: str,
                 num_classes: int,
                 dropout: float = 0.,
                 freeze: Union[bool, int] = False,
                 aggregate_n_last_hidden_layers: int = 1,
                 aggregate_mode: str = 'mean',
                 revision=None):
        super(SiameseClf, self).__init__(
            model_name, freeze, aggregate_n_last_hidden_layers, aggregate_mode, revision
        )
        hidden_proj = nn.Linear(2 * self.output_size, 256)
        out_proj = nn.Linear(256, num_classes)
        self.clf = nn.Sequential(
            hidden_proj,
            nn.LeakyReLU(0.01),
            nn.Dropout(dropout),
            out_proj,
        )
        torch.nn.init.xavier_normal_(
            hidden_proj.weight, nn.init.calculate_gain('leaky_relu', 0.01)
        )
        torch.nn.init.xavier_normal_(out_proj.weight)

    def forward(self,
                x1: torch.LongTensor, x2: torch.LongTensor,
                mask1: Union[torch.BoolTensor, torch.Tensor, None] = None,
                mask2: Union[torch.BoolTensor, torch.Tensor, None] = None,
                return_dict: bool = False) -> Union[torch.Tensor, MutableMapping[str, torch.Tensor]]:
        """Return pair logits, or ``{'logits', 'embeddings'}`` when *return_dict* is set.

        ``embeddings`` stacks the two [CLS] embeddings along dim 1.
        """
        encoder = super(SiameseClf, self)
        first = encoder.forward(x1, mask1, token_type_ids=None)
        second = encoder.forward(x2, mask2, token_type_ids=None)

        emb = torch.stack([first, second], dim=1)
        logits = self.clf(emb.flatten(start_dim=1))
        if not return_dict:
            return logits
        return {'logits': logits, 'embeddings': emb}

In [None]:
class Trainer:
    """Training loop for classifiers, with optional tqdm bars and W&B logging.

    Batches must expose ``batch.text == (tokens, token_type_ids)`` and
    ``batch.label``. The model is called as ``model(tokens, mask, token_type_ids)``
    and its output is argmax-ed over the last dimension for accuracy.

    Epoch-level metrics are averaged over examples: each batch is weighted by
    ``batch.label.size(0)``, so a short final batch does not skew the result.
    """
    # log per-step training metrics to W&B every `log_every` batches
    log_every = 30
    # metrics aggregated over an epoch, in the order they are reported
    _EPOCH_METRICS = ('loss', 'accuracy')

    def __init__(self, pad_index: Optional[int] = None, silent: bool = False) -> None:
        self.global_step = 0        # training examples seen during the current `train` call
        self.pad_index = pad_index  # token id excluded via the mask; None -> mask=None
        self.cur_epoch = None       # current epoch index while `train` runs, else None
        self.silent = silent        # True disables the tqdm progress bars

    def _forward(self, model: nn.Module, batch) -> torch.Tensor:
        """Run *model* on ``batch.text``, masking ``pad_index`` tokens when it is set."""
        tokens, token_type_ids = batch.text
        mask = None if self.pad_index is None else (tokens != self.pad_index).float()
        return model(tokens, mask, token_type_ids)

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:
        """One optimisation step; returns the batch loss and accuracy."""
        output = self._forward(model, batch)
        loss = criterion(output, batch.label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(output, dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()
        return {
            'loss': loss.item(),
            'accuracy': acc.item()
        }

    def val_step(
        self,
        model: nn.Module, batch,
        criterion = None,
        it: Optional[int] = None) -> MutableMapping[str, Optional[float]]:
        """Evaluate one batch; the loss is included only when *criterion* is given."""
        output = self._forward(model, batch)

        preds = torch.argmax(output, dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()

        step_log = {'accuracy': acc.item()}
        if criterion is not None:
            step_log['loss'] = criterion(output, batch.label).item()
        return step_log

    @staticmethod
    def _format_metrics(header: str, log: MutableMapping[str, float]) -> str:
        """Render *log* as tab-indented ``name: value`` lines after *header*."""
        return header + ''.join(f'\t{name}: {value:.4f}\n' for name, value in log.items())

    @classmethod
    def _accumulate(cls, totals, weights, step_log, n_examples: int) -> None:
        """Add *step_log*'s tracked metrics, weighted by the batch size, to the running sums."""
        for name in cls._EPOCH_METRICS:
            if name in step_log:
                totals[name] += step_log[name] * n_examples
                weights[name] += n_examples

    @classmethod
    def _epoch_average(cls, totals, weights) -> MutableMapping[str, Optional[float]]:
        """Per-example averages of the metrics that were reported at least once."""
        return {
            name: totals[name] / weights[name]
            for name in cls._EPOCH_METRICS
            if weights[name] > 0
        }

    def train(self, model: nn.Module, train_iterator, val_iterator, criterion, optimizer, total_epochs):
        """Train for *total_epochs* epochs, keeping the best-validation-accuracy weights.

        When *val_iterator* is given, the model is validated after each epoch
        and the metrics are logged to W&B with a ``val_`` prefix. At the end,
        the weights with the highest validation accuracy are loaded back into
        *model*, and that accuracy is stored in the run summary.
        """
        best_acc = -float('inf')
        best_model_wts = None

        self.global_step = 0
        epochs = range(total_epochs)
        if not self.silent:
            epochs = tqdm(
                epochs,
                unit='Epoch', desc='Total progress',
                position=0, leave=True
            )
        for epoch in epochs:
            self.cur_epoch = epoch
            epoch_log = self.train_epoch(model, train_iterator, criterion, optimizer)
            print(self._format_metrics(f'Epoch {epoch} is finished.\nTraining metrics:\n', epoch_log))

            if val_iterator is None:
                continue
            val_log = self.validate(model, val_iterator, criterion)
            print(self._format_metrics('Validation metrics:\n', val_log))
            wandb.log({f'val_{name}': value for name, value in val_log.items()})

            epoch_acc = val_log.get('accuracy', -float('inf'))
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = deepcopy(model.state_dict())

        self.cur_epoch = None
        if best_model_wts is not None:
            wandb.run.summary["val_accuracy"] = best_acc
            model.load_state_dict(best_model_wts)

    @torch.no_grad()
    def validate(
        self,
        model: nn.Module, iterator,
        criterion = None
        ) -> MutableMapping[str, Optional[float]]:
        """Evaluate *model* on *iterator*; returns per-example mean loss/accuracy."""
        totals = dict.fromkeys(self._EPOCH_METRICS, 0.)
        weights = dict.fromkeys(self._EPOCH_METRICS, 0)

        model.eval()
        batches = iterator
        if not self.silent:
            batches = tqdm(
                iterator, unit='batch',
                position=1, leave=False,
                desc='Validation phase'
            )
        for it, batch in enumerate(batches):
            it_log = self.val_step(model, batch, criterion, it)
            self._accumulate(totals, weights, it_log, batch.label.size(0))
        return self._epoch_average(totals, weights)

    def train_epoch(self,
                    model: nn.Module, iterator,
                    criterion, optimizer
                    ) -> MutableMapping[str, Optional[float]]:
        """Run one training epoch; returns per-example mean loss/accuracy.

        Steps that return an empty log (e.g. skipped batches) do not contribute
        to the averages. Every ``log_every`` batches the step log is sent to W&B
        at ``step=self.global_step``.
        """
        totals = dict.fromkeys(self._EPOCH_METRICS, 0.)
        weights = dict.fromkeys(self._EPOCH_METRICS, 0)

        model.train()
        batches = iterator
        if not self.silent:
            batches = tqdm(
                iterator,
                unit='batch',
                desc='Training phase',
                position=1, leave=False
            )
        for it, batch in enumerate(batches):
            it_log = self.train_step(model, batch, criterion, optimizer, it)
            batch_size = batch.label.size(0)
            self._accumulate(totals, weights, it_log, batch_size)

            self.global_step += batch_size
            if it % self.log_every == self.log_every - 1:
                it_log['epoch'] = self.cur_epoch
                wandb.log(it_log, step=self.global_step)

        return self._epoch_average(totals, weights)

In [None]:
LABEL = LabelField()


def get_val_dataset(tokenizer, path: str = 'dev.tsv', shuffle: bool = True, field_cls=SeqPairField):
    """Load a tab-separated file of text pairs into a torchtext Dataset.

    The file must have ``left_text``, ``right_text`` and ``class`` columns.
    Rows with any missing value are dropped. Pairs go through a fresh
    ``field_cls(tokenizer)`` field. Labels go through the module-level
    ``LABEL`` field, which is looked up at call time.

    Args:
        tokenizer: HuggingFace tokenizer handed to ``field_cls``.
        path: TSV file to read.
        shuffle: shuffle the rows (``DataFrame.sample(frac=1.)``) first.
        field_cls: field class used for the text pair.
    """
    pair_field = field_cls(tokenizer)

    frame = pd.read_csv(path, sep='\t')
    if shuffle:
        frame = frame.sample(frac=1.)

    examples = [
        text_data.Example.fromlist(
            [(row['left_text'], row['right_text']), row['class']],
            [('text', pair_field), ('label', LABEL)]
        )
        for _, row in frame.dropna().iterrows()
    ]
    return text_data.Dataset(deepcopy(examples), {'text': pair_field, 'label': LABEL})

# Paraphraser

In [None]:
class ParaphraserTrainer(Trainer):
    """Trainer for 3-class paraphrase labels using soft binary targets.

    Training labels 0 / 1 / 2 are turned into a 2-way target distribution
    (see ``_soft_targets``), and the model is trained with soft-target NLL.
    The ``criterion`` argument is ignored by ``train_step``; the inherited
    ``val_step`` still uses it.
    """

    def __init__(self, pad_index: Optional[int] = None, silent: bool = False, p: float = 0.5) -> None:
        super().__init__(pad_index=pad_index, silent=silent)
        self.prob = p
        # target distribution for label 1 ("close, but not exactly the same")
        self._semi_dist_v = torch.Tensor([1 - p, p]).detach()

    def _soft_targets(self, labels: torch.Tensor) -> torch.Tensor:
        """Map hard labels to a ``(batch, 2)`` Bernoulli target tensor on ``labels.device``.

        "2" (means the same)                -> [0, 1]
        "0" (different meaning)             -> [1, 0]
        "1" (close, but not exactly same)   -> [1 - p, p]
        Any other label value leaves an all-zero row (zero loss contribution).
        """
        # Built directly on the label device: indexing a CPU tensor with CUDA
        # indices (and assigning CPU values into a CUDA tensor) is rejected by
        # current PyTorch versions.
        targets = torch.zeros((labels.size(0), 2), dtype=torch.float, device=labels.device)
        targets[labels == 2, 1] = 1.
        targets[labels == 0, 0] = 1.
        targets[labels == 1] = self._semi_dist_v.to(labels.device)
        return targets

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:
        """One optimisation step with soft-target NLL; returns loss and accuracy.

        Accuracy compares predictions with the argmax of the soft targets.
        """
        label_dist = self._soft_targets(batch.label)

        tokens, token_type_ids = batch.text
        mask = None if self.pad_index is None else (tokens != self.pad_index).float()
        output = model(tokens, mask, token_type_ids)

        # NLL loss with soft targets: -sum_c q_c * log p_c, averaged over the batch
        log_prob = F.log_softmax(output, dim=1)
        loss = torch.mean(torch.sum(-label_dist * log_prob, 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(output, dim=-1)
        hard_labels = torch.argmax(label_dist, dim=1)
        acc = (preds == hard_labels).type(torch.float).mean()
        return {
            'loss': loss.item(),
            'accuracy': acc.item()
        }


In [None]:
# Pretrained encoder (passed to AutoTokenizer/AutoModel.from_pretrained) and
# batch size for the ParaPhraser experiments.
model_name = 'DeepPavlov/rubert-base-cased'
BATCH_SIZE = 32

In [None]:
# Download and unpack the ParaPhraser corpus (expected to contain the
# paraphrases.xml read in the next cells).
# The URL is quoted so the shell never glob-expands `?`; `unzip -o` overwrites
# existing files instead of prompting, so re-running the cell cannot stall.
! wget 'http://www.paraphraser.ru/download/get?file_id=1' -O tmp.zip
! unzip -o tmp.zip
! rm tmp.zip

In [None]:
# Tokenizer for `model_name` and the ids of its pad / unk / cls / sep tokens.
tokenizer = AutoTokenizer.from_pretrained(model_name)
pad_index, unk_index, cls_index, sep_index = (
    tokenizer.convert_tokens_to_ids(special_token)
    for special_token in (
        tokenizer.pad_token,
        tokenizer.unk_token,
        tokenizer.cls_token,
        tokenizer.sep_token,
    )
)

In [None]:
PAIR = SeqPairField(tokenizer)
LABEL = LabelField()

# The corpus text is Russian: decode it explicitly instead of relying on the
# platform default encoding (assumes the XML is UTF-8 — TODO confirm header).
with open('paraphrases.xml', encoding='utf-8') as f:
    selector = Selector(text=f.read())
examples = []
for pair in selector.xpath('//corpus/paraphrase'):
    left_text = pair.xpath('value[@name="text_1"]/text()').get()
    right_text = pair.xpath('value[@name="text_2"]/text()').get()
    if left_text is None or right_text is None:
        # an empty <value> has no text node; such a pair cannot be tokenized
        continue
    # classes -1 / 0 / 1 are shifted to 0 / 1 / 2 as ParaphraserTrainer expects
    label = int(pair.xpath('value[@name="class"]/text()').get()) + 1
    ex = text_data.Example.fromlist(
        [(left_text, right_text), label],
        [('text', PAIR), ('label', LABEL)]
        )
    examples.append(ex)

train_dataset = text_data.Dataset(
    deepcopy(examples),
    {
        'text': PAIR,
        'label': LABEL
    }
)

# Bucket by the whitespace word count of both texts of a pair.
train_iterator = text_data.BucketIterator(
    train_dataset,
    BATCH_SIZE,
    sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
    device=device,
    train=True
)
val_iterator = text_data.BucketIterator(
    get_val_dataset(tokenizer),
    BATCH_SIZE,
    sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
    device=device,
    train=False
)

In [None]:
# Hyper-parameters for the ParaPhraser trials: passed to wandb.init as the run
# config and read back through wandb.config inside run_trial.
config = {
    'HuggingFace model': model_name,
    'LR': 5e-5,                       # Adam learning rate
    'Epochs': 10,
    'Dropout rate': 0.6,              # dropout inside the classifier head
    'Freezed layers': 10,             # number of leading transformer layers frozen
    'Semi-positive class prob': 0.5,  # p: soft target [1 - p, p] for label 1 pairs
}

In [None]:
def run_trial(seed, tokenizer, trial_config, project_name: str = 'coursework'):
    """Train and evaluate one seeded ParaPhraser run inside its own W&B run.

    The function does the following:
    - Seeds torch / NumPy / ``random``.
    - Forces deterministic cuDNN while the trial runs; the previous setting is
      restored afterwards, even if the trial raises.
    - Trains a TransformerCls with ParaphraserTrainer on the module-level
      ``train_dataset`` and validates it on ``dev.tsv``.
    - Writes accuracy, F1 and ROC-AUC on ``test.tsv`` to the run summary.

    Args:
        seed: seed for all RNGs.
        tokenizer: tokenizer used for the dev/test data and the pad index.
        trial_config: W&B config. It must contain 'LR', 'Epochs',
            'Dropout rate', 'Freezed layers' and 'Semi-positive class prob'.
        project_name: W&B project to log into.
    """
    pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn_det_cache = None
    if torch.cuda.is_available():
        cudnn_det_cache = torch.backends.cudnn.deterministic
        torch.backends.cudnn.deterministic = True

    def pair_word_count(example):
        # bucket pairs by the whitespace word count of both texts
        return len(example.text[0].split()) + len(example.text[1].split())

    try:
        train_iterator = text_data.BucketIterator(
            train_dataset,
            BATCH_SIZE,
            sort_key=pair_word_count,
            device=device,
            train=True
        )
        val_iterator = text_data.BucketIterator(
            get_val_dataset(tokenizer),
            BATCH_SIZE,
            sort_key=pair_word_count,
            device=device,
            train=False
        )

        with wandb.init(project=project_name, config=trial_config, group='Paraphraser', save_code=False):
            # Init model
            model = TransformerCls(
                model_name,
                2,
                dropout=wandb.config['Dropout rate'],
                freeze=wandb.config['Freezed layers'],
            ).to(device)

            criterion = nn.CrossEntropyLoss()
            optim = torch.optim.Adam(model.parameters(), lr=wandb.config['LR'])

            # Train model (best-validation weights are loaded back at the end)
            ParaphraserTrainer(
                pad_index=pad_index,
                silent=True,
                p=wandb.config['Semi-positive class prob']
            ).train(
                model,
                train_iterator, val_iterator,
                criterion, optim,
                total_epochs=wandb.config['Epochs']
            )

            # Test scores
            test_iterator = text_data.BucketIterator(
                get_val_dataset(tokenizer, path='test.tsv', shuffle=True),
                BATCH_SIZE,
                sort_key=pair_word_count,
                device=device,
                shuffle=False
            )

            labels, preds, scores = [], [], []
            model.eval()
            with torch.no_grad():
                for batch in test_iterator:
                    tokens, token_type_ids = batch.text
                    mask = None if pad_index is None else (tokens != pad_index).float()
                    output = model(tokens, mask, token_type_ids)

                    # softmax probability of class 1 feeds ROC-AUC; argmax is the hard prediction
                    scores += torch.softmax(output, dim=1)[:, 1].cpu().tolist()
                    labels += batch.label.cpu().tolist()
                    preds += torch.argmax(output, dim=-1).cpu().tolist()

            test_acc = accuracy_score(labels, preds)
            wandb.run.summary['Test/Accuracy'] = test_acc
            print(f'Test set Accuracy: {test_acc:.4f}')

            test_f1 = f1_score(labels, preds)
            wandb.run.summary['Test/F1'] = test_f1
            print(f'Test set F1 Score: {test_f1:.4f}')

            test_roc_auc = roc_auc_score(labels, scores)
            wandb.run.summary['Test/ROC-AUC'] = test_roc_auc
            print(f'Test set ROC-AUC Score: {test_roc_auc:.4f}')
    finally:
        if cudnn_det_cache is not None:
            torch.backends.cudnn.deterministic = cudnn_det_cache

In [None]:
# Run one full trial per seed (trial numbers are printed 1-based).
for trial_n, s in enumerate(tqdm(SEEDS, unit='trial')):
    print(f'Running trial #{trial_n + 1}')
    run_trial(s, tokenizer, config)
    # release cached, unused GPU memory before the next trial builds a new model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

Running trial #1


Epoch 0 is finished.
Training metrics:
	loss: 0.5957
	accuracy: 0.6364

Validation metrics:
	loss: 0.6354
	accuracy: 0.5573

Epoch 1 is finished.
Training metrics:
	loss: 0.5203
	accuracy: 0.6671

Validation metrics:
	loss: 0.5965
	accuracy: 0.5998

Epoch 2 is finished.
Training metrics:
	loss: 0.5011
	accuracy: 0.6766

Validation metrics:
	loss: 0.4838
	accuracy: 0.7578

Epoch 3 is finished.
Training metrics:
	loss: 0.4813
	accuracy: 0.6891

Validation metrics:
	loss: 0.5546
	accuracy: 0.6597

Epoch 4 is finished.
Training metrics:
	loss: 0.4706
	accuracy: 0.7015

Validation metrics:
	loss: 0.4687
	accuracy: 0.7726

Epoch 5 is finished.
Training metrics:
	loss: 0.4547
	accuracy: 0.7075

Validation metrics:
	loss: 0.4571
	accuracy: 0.7812

Epoch 6 is finished.
Training metrics:
	loss: 0.4395
	accuracy: 0.7127

Validation metrics:
	loss: 0.4395
	accuracy: 0.8099

Epoch 7 is finished.
Training metrics:
	loss: 0.4281
	accuracy: 0.7136

Validation metrics:
	loss: 0.4938
	accuracy: 0.7318



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.38093
accuracy,0.78125
epoch,9.0
_runtime,532.0
_timestamp,1621530952.0
_step,71758.0
val_loss,0.45962
val_accuracy,0.8099
Test/Accuracy,0.75167
Test/F1,0.7286


0,1
loss,█▅▆▄▇▅▇▄▆▃▅▇▄▃▄▄▄▆▂▅▃▂▄▃▃▄▂▃▄▃▄▃▃▆▃▂▃▁▃▂
accuracy,▄▄▃▅▆▅▁▅▃▃▅▄▃▃▄▅▄▂▄▄▃█▅▄▆▄▇▆▃▄▄▅▅▅▄▃▅▄▃▅
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▇▃▅▂▂▁▃▁▂
val_accuracy,▁▂▇▄▇▇█▆█▇


Running trial #2


Epoch 0 is finished.
Training metrics:
	loss: 0.6121
	accuracy: 0.6408

Validation metrics:
	loss: 0.6184
	accuracy: 0.5903

Epoch 1 is finished.
Training metrics:
	loss: 0.5214
	accuracy: 0.6784

Validation metrics:
	loss: 0.6626
	accuracy: 0.5312

Epoch 2 is finished.
Training metrics:
	loss: 0.5010
	accuracy: 0.6903

Validation metrics:
	loss: 0.5168
	accuracy: 0.7196

Epoch 3 is finished.
Training metrics:
	loss: 0.4839
	accuracy: 0.7023

Validation metrics:
	loss: 0.5878
	accuracy: 0.6432

Epoch 4 is finished.
Training metrics:
	loss: 0.4702
	accuracy: 0.7084

Validation metrics:
	loss: 0.5208
	accuracy: 0.7231

Epoch 5 is finished.
Training metrics:
	loss: 0.4568
	accuracy: 0.7125

Validation metrics:
	loss: 0.5060
	accuracy: 0.7378

Epoch 6 is finished.
Training metrics:
	loss: 0.4432
	accuracy: 0.7245

Validation metrics:
	loss: 0.5318
	accuracy: 0.6970

Epoch 7 is finished.
Training metrics:
	loss: 0.4287
	accuracy: 0.7221

Validation metrics:
	loss: 0.5068
	accuracy: 0.7240



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.36757
accuracy,0.78125
epoch,9.0
_runtime,532.0
_timestamp,1621531490.0
_step,71758.0
val_loss,0.52494
val_accuracy,0.76649
Test/Accuracy,0.73833
Test/F1,0.72504


0,1
loss,██▇▄▇▆▇▅▆▃▅█▃▄▅▄▅▆▄▄▃▃▇▄▃▆▂▄▄▃▃▃▃▅▄▂▁▁▃▂
accuracy,▅▂▁▇▇▅▂▄▃▄▆▅▆▄▂▆▂▂▄▄▃█▇▇▇▃█▆▂█▄▅▅▅▆▆▇▆▅▆
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,▆█▃▅▃▂▃▂▁▃
val_accuracy,▃▁▇▄▇▇▆▇█▇


Running trial #3


Epoch 0 is finished.
Training metrics:
	loss: 0.5824
	accuracy: 0.6450

Validation metrics:
	loss: 0.6858
	accuracy: 0.5078

Epoch 1 is finished.
Training metrics:
	loss: 0.5167
	accuracy: 0.6848

Validation metrics:
	loss: 0.5995
	accuracy: 0.6146

Epoch 2 is finished.
Training metrics:
	loss: 0.4948
	accuracy: 0.6928

Validation metrics:
	loss: 0.5197
	accuracy: 0.7092

Epoch 3 is finished.
Training metrics:
	loss: 0.4778
	accuracy: 0.7085

Validation metrics:
	loss: 0.5495
	accuracy: 0.6597

Epoch 4 is finished.
Training metrics:
	loss: 0.4612
	accuracy: 0.7187

Validation metrics:
	loss: 0.5033
	accuracy: 0.7231

Epoch 5 is finished.
Training metrics:
	loss: 0.4492
	accuracy: 0.7304

Validation metrics:
	loss: 0.5161
	accuracy: 0.7196

Epoch 6 is finished.
Training metrics:
	loss: 0.4339
	accuracy: 0.7299

Validation metrics:
	loss: 0.5631
	accuracy: 0.6571

Epoch 7 is finished.
Training metrics:
	loss: 0.4169
	accuracy: 0.7315

Validation metrics:
	loss: 0.5088
	accuracy: 0.7109



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.37286
accuracy,0.71875
epoch,9.0
_runtime,532.0
_timestamp,1621532028.0
_step,71758.0
val_loss,0.49844
val_accuracy,0.72309
Test/Accuracy,0.75833
Test/F1,0.77308


0,1
loss,█▄▆▂▆▅▆▄▆▂▄▇▃▃▄▃▄▅▃▄▃▂▆▃▃▄▃▃▂▃▄▃▂▅▃▂▁▂▃▂
accuracy,▂▅▁▅▆▅▂▃▅▄▆▅▄▃▃▇▄▃▄▆▃▇▅▅▇▄█▆▃▄▅▆▅▃▄▆▇▅▆▅
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▅▂▃▁▂▃▁▂▁
val_accuracy,▁▄█▆██▆███


Running trial #4


Epoch 0 is finished.
Training metrics:
	loss: 0.5934
	accuracy: 0.6358

Validation metrics:
	loss: 0.7051
	accuracy: 0.5078

Epoch 1 is finished.
Training metrics:
	loss: 0.5203
	accuracy: 0.6807

Validation metrics:
	loss: 0.5560
	accuracy: 0.6649

Epoch 2 is finished.
Training metrics:
	loss: 0.5001
	accuracy: 0.6875

Validation metrics:
	loss: 0.5123
	accuracy: 0.7266

Epoch 3 is finished.
Training metrics:
	loss: 0.4795
	accuracy: 0.7047

Validation metrics:
	loss: 0.5249
	accuracy: 0.6988

Epoch 4 is finished.
Training metrics:
	loss: 0.4634
	accuracy: 0.7167

Validation metrics:
	loss: 0.4907
	accuracy: 0.7439

Epoch 5 is finished.
Training metrics:
	loss: 0.4507
	accuracy: 0.7170

Validation metrics:
	loss: 0.5006
	accuracy: 0.7188

Epoch 6 is finished.
Training metrics:
	loss: 0.4353
	accuracy: 0.7253

Validation metrics:
	loss: 0.5459
	accuracy: 0.6892

Epoch 7 is finished.
Training metrics:
	loss: 0.4194
	accuracy: 0.7336

Validation metrics:
	loss: 0.5388
	accuracy: 0.7031



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.36192
accuracy,0.65625
epoch,9.0
_runtime,532.0
_timestamp,1621532565.0
_step,71758.0
val_loss,0.56108
val_accuracy,0.74392
Test/Accuracy,0.73
Test/F1,0.72542


0,1
loss,▇▄▅▄▆▅█▅▆▃▄▇▄▄▅▂▆▅▄▄▃▃▆▄▄▅▂▃▃▅▄▃▃▇▃▂▁▁▄▂
accuracy,▄▅▆▅▇▅▁▅▅▃▆▄▅▄▅▆▄▄▄▄▄█▆▅▇▄█▅▄▆▅▅▅▅▆▅▅▅▅▄
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▃▂▂▁▁▃▃▅▃
val_accuracy,▁▆▇▇█▇▆▇▅▆


Running trial #5


Epoch 0 is finished.
Training metrics:
	loss: 0.6050
	accuracy: 0.6446

Validation metrics:
	loss: 0.6262
	accuracy: 0.5729

Epoch 1 is finished.
Training metrics:
	loss: 0.5177
	accuracy: 0.6887

Validation metrics:
	loss: 0.5712
	accuracy: 0.6493

Epoch 2 is finished.
Training metrics:
	loss: 0.4980
	accuracy: 0.6893

Validation metrics:
	loss: 0.4936
	accuracy: 0.7465

Epoch 3 is finished.
Training metrics:
	loss: 0.4800
	accuracy: 0.7092

Validation metrics:
	loss: 0.5347
	accuracy: 0.6953

Epoch 4 is finished.
Training metrics:
	loss: 0.4650
	accuracy: 0.7202

Validation metrics:
	loss: 0.4716
	accuracy: 0.7943

Epoch 5 is finished.
Training metrics:
	loss: 0.4488
	accuracy: 0.7259

Validation metrics:
	loss: 0.4619
	accuracy: 0.7943

Epoch 6 is finished.
Training metrics:
	loss: 0.4342
	accuracy: 0.7377

Validation metrics:
	loss: 0.4730
	accuracy: 0.7630

Epoch 7 is finished.
Training metrics:
	loss: 0.4167
	accuracy: 0.7403

Validation metrics:
	loss: 0.4355
	accuracy: 0.7986



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss,0.34744
accuracy,0.75
epoch,9.0
_runtime,533.0
_timestamp,1621533104.0
_step,71758.0
val_loss,0.50112
val_accuracy,0.79861
Test/Accuracy,0.73
Test/F1,0.69434


0,1
loss,█▅▆▃▆▃▆▃▅▃▄▆▃▄▅▂▄▄▃▃▃▂▅▄▂▄▂▃▄▃▃▄▂▆▃▂▁▁▃▁
accuracy,▄▄▄▆▇▇▁▅▄▄▆▅▇▃▄▇▆▄▃▆▄▇▇▄▇▆█▅▄▇▅▅▅▄▆▆▇▇▄▆
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▆▃▅▂▂▂▁▂▃
val_accuracy,▁▃▆▅██▇██▆





## Train model on translated data

In [None]:
# Translated QQP training pairs: reuse the cached TSV when it exists, otherwise
# build it from `qqp_df` and `translated` (presumably a qid -> Russian text
# mapping — it is only used through .get here) and cache the complete rows.
try:
    qqp_ru = pd.read_csv('qqp_ru_train.tsv', sep='\t')
except FileNotFoundError:
    # Only a missing cache triggers a rebuild; other errors (e.g. interrupts or a
    # malformed file) are no longer silently swallowed by a bare `except:`.
    qqp_ru = qqp_df.copy().reset_index(drop=True)
    # questions without a translation become NaN and are dropped when caching
    qqp_ru['question1'] = [translated.get(q_id, float('nan')) for q_id in qqp_ru['qid1']]
    qqp_ru['question2'] = [translated.get(q_id, float('nan')) for q_id in qqp_ru['qid2']]
    qqp_ru.dropna().to_csv('qqp_ru_train.tsv', sep='\t', index=False, header=True)

In [None]:
class SkipTrainer(Trainer):
    """``Trainer`` that skips batches longer than the model's position limit.

    Over-length batches are not passed to the model; the step returns an empty
    metrics dict instead. This applies to both training and validation steps.
    """

    @staticmethod
    def _max_len(model: nn.Module) -> int:
        """Return ``max_position_embeddings`` of the wrapped HF model, or 512 if unset."""
        return getattr(model.model.config, "max_position_embeddings", 512)

    def _too_long(self, model: nn.Module, batch) -> bool:
        """Return True when the batch's padded token matrix exceeds the model limit."""
        # batch.text is (token ids, token type ids); dim 1 is the padded length.
        return batch.text[0].size(1) > self._max_len(model)

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:

        if self._too_long(model, batch):
            return {}
        return super().train_step(model, batch, criterion, optimizer, it)

    def val_step(
        self,
        model: nn.Module, batch,
        criterion = None,
        it: Optional[int] = None) -> MutableMapping[str, Optional[float]]:

        # Same guard as training: an over-length validation batch would
        # otherwise index past the position embeddings.
        if self._too_long(model, batch):
            return {}
        return super().val_step(model, batch, criterion, it)

In [None]:
# Backbone checkpoint and batch size for the translated-QQP experiments below.
model_name, BATCH_SIZE = 'DeepPavlov/rubert-base-cased', 32

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ids of the special tokens, looked up one at a time in pad/unk/cls/sep order.
pad_index, unk_index, cls_index, sep_index = (
    tokenizer.convert_tokens_to_ids(special_token)
    for special_token in (
        tokenizer.pad_token,
        tokenizer.unk_token,
        tokenizer.cls_token,
        tokenizer.sep_token,
    )
)

In [None]:
PAIR = SeqPairField(tokenizer)
LABEL = LabelField()

# One torchtext Example per translated question pair; rows with a missing
# translation (NaN) are dropped first.
examples = [
    text_data.Example.fromlist(
        [(row['question1'], row['question2']), row['is_duplicate']],
        [('text', PAIR), ('label', LABEL)]
    )
    for _, row in qqp_ru.dropna().iterrows()
]

# The Dataset receives deep copies, so it does not share Example objects with `examples`.
train_dataset = text_data.Dataset(deepcopy(examples), dict(text=PAIR, label=LABEL))

train_iterator = text_data.BucketIterator(
    train_dataset,
    batch_size=BATCH_SIZE,
    # Bucket by the combined whitespace-token count of both questions.
    sort_key=lambda example: len(example.text[0].split()) + len(example.text[1].split()),
    device=device,
    shuffle=True,
)
val_iterator = text_data.BucketIterator(
    get_val_dataset(tokenizer),
    batch_size=BATCH_SIZE,
    sort_key=lambda example: len(example.text[0].split()) + len(example.text[1].split()),
    device=device,
    shuffle=False,
)

In [None]:
# Hyper-parameters passed to wandb.init; run_trial reads them back via wandb.config.
config = {
    'HuggingFace model': model_name,                # backbone checkpoint
    'Translation model': 'facebook/wmt19-en-ru',    # recorded with the run; not read by run_trial
    'LR': 5e-5,                                     # Adam learning rate
    'Epochs': 3,                                    # total training epochs
    'Dropout rate': 0.6,                            # TransformerCls(dropout=...)
    'Freezed layers': 10,                           # TransformerCls(freeze=...)
}

In [None]:
def run_trial(
    seed, tokenizer,
    trial_config,
    project_name: str = 'coursework',
    group: str = 'Transalted QQP'
):
    pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn_det_cache = None
    if torch.cuda.is_available():
        cudnn_det_cache = torch.backends.cudnn.deterministic
        torch.backends.cudnn.deterministic = True

    train_iterator = text_data.BucketIterator(
        train_dataset,
        BATCH_SIZE,
        sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
        device=device,
        shuffle=True
    )
    val_iterator = text_data.BucketIterator(
        get_val_dataset(tokenizer),
        BATCH_SIZE,
        sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
        device=device,
        shuffle=False
    )

    with wandb.init(project=project_name, config=trial_config, group=group, save_code=False):
        # Init model
        model = TransformerCls(
            model_name,
            2,
            dropout=wandb.config['Dropout rate'],
            freeze=wandb.config['Freezed layers'],
        ).to(device)

        criterion = nn.CrossEntropyLoss()
        optim = torch.optim.Adam(model.parameters(), lr=wandb.config['LR'])

        # Train model
        SkipTrainer(pad_index=pad_index, silent=True).train(
            model,
            train_iterator, val_iterator,
            criterion, optim,
            total_epochs=wandb.config['Epochs']
        )

        # Test scores
        test_iterator = text_data.BucketIterator(
            get_val_dataset(tokenizer, path='test.tsv', shuffle=True),
            BATCH_SIZE,
            sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
            device=device,
            shuffle=False
        )

        labels = []
        preds = []
        scores = []

        model.eval()
        with torch.no_grad():
            for batch in test_iterator:
                tokens, token_type_ids = batch.text
                if pad_index is None:
                    mask = None
                else:
                    mask = (tokens != pad_index).float()
                output = model(tokens, mask, token_type_ids)

                scores += torch.softmax(output, dim=1)[:, 1].cpu().tolist()
                batch_preds = torch.argmax(output, dim=-1)
                labels += batch.label.cpu().tolist()
                preds += batch_preds.cpu().tolist()

        test_acc = accuracy_score(labels, preds)
        wandb.run.summary['Test/Accuracy'] = test_acc
        print(f'Test set Accuracy: {test_acc:.4f}')

        test_f1 = f1_score(labels, preds)
        wandb.run.summary['Test/F1'] = test_f1
        print(f'Test set F1 Score: {test_f1:.4f}')

        test_roc_auc = roc_auc_score(labels, scores)
        wandb.run.summary['Test/ROC-AUC'] = test_roc_auc
        print(f'Test set ROC-AUC Score: {test_roc_auc:.4f}')

    if cudnn_det_cache is not None:
        torch.backends.cudnn.deterministic = cudnn_det_cache

In [None]:
for trial_n, s in enumerate(tqdm(SEEDS, unit='trial')):
    print(f'Running trial #{trial_n + 1}')
    run_trial(s, tokenizer, config)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

## Balanced Dev

In [None]:
dev_df = pd.read_csv('dev.tsv', sep='\t')
pos_dev = dev_df[dev_df['class'] == 1]
neg_dev = dev_df[dev_df['class'] == 0].sample(n=len(pos_dev))
balanced_dev = pd.concat([pos_dev, neg_dev]).sample(frac=1.).reset_index(drop=True)
balanced_dev.to_csv('balanced_dev.tsv', sep='\t', index=False, header=True)

In [None]:
val_iterator = text_data.BucketIterator(
    get_val_dataset(tokenizer, path='balanced_dev.tsv'),
    BATCH_SIZE,
    sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
    device=device,
    shuffle=False
)

In [None]:
config = {
    'HuggingFace model': model_name,
    'Translation model': 'facebook/wmt19-en-ru',
    'LR': 5e-5,
    'Epochs': 3,
    'Dropout rate': 0.6,
    'Freezed layers': 10
}

In [None]:
for trial_n, s in enumerate(tqdm(SEEDS, unit='trial')):
    print(f'Running trial #{trial_n + 1}')
    run_trial(s, tokenizer, config, group='Transalted QQP (Balanced dev)')
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# LaBSE

In [None]:
class LABSECosineTrainer(Trainer):
    """Trainer for a bi-encoder scored by cosine similarity.

    The two questions of a pair are encoded separately by the same ``model``.
    A pair is predicted duplicate when the cosine similarity of the two
    embeddings exceeds 0.5. The ``criterion`` argument is not used.
    Batches whose token length exceeds the model's position limit are skipped
    and return an empty metrics dict.
    """

    @staticmethod
    def cosine_margin_loss(
        x1: torch.Tensor,
        x2: torch.Tensor,
        target: torch.Tensor,
        margin: float = 0.0
    ) -> torch.Tensor:
        """Mean cosine embedding loss for 0/1 targets.

        Positives contribute ``1 - cos``; negatives contribute
        ``max(0, cos - margin)``. With ``margin=0`` this equals
        ``F.cosine_embedding_loss`` with labels 0 mapped to -1, which is the
        loss used in ``train_step``. Validation loss therefore tracks the
        training loss and cannot go negative.
        """
        sim = F.cosine_similarity(x1, x2)
        pos_loss = target * (1 - sim)
        neg_loss = (1 - target) * torch.clamp(sim - margin, min=0)
        return torch.mean(pos_loss + neg_loss)

    @staticmethod
    def _max_len(model: nn.Module) -> int:
        """Return ``max_position_embeddings`` of the wrapped HF model, or 512 if unset."""
        return getattr(model.model.config, "max_position_embeddings", 512)

    def _encode(self, model: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
        """Encode one side of the pair, masking pad positions when ``pad_index`` is set."""
        if self.pad_index is None:
            mask = None
        else:
            mask = (tokens != self.pad_index).float()
        return model(tokens, mask)

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_out = self._encode(model, left_tokens)
        right_out = self._encode(model, right_tokens)
        preds = (F.cosine_similarity(left_out, right_out) > 0.5).type(torch.long)

        # cosine_embedding_loss expects targets in {1, -1}; map label 0 -> -1.
        loss = F.cosine_embedding_loss(
            left_out, right_out,
            torch.where(batch.label == 1, batch.label, -1)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (preds == batch.label).type(torch.float).mean()
        return {
            'loss': loss.item(),
            'accuracy': acc.item()
        }

    def val_step(
        self,
        model: nn.Module, batch,
        criterion = None,
        it: Optional[int] = None) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_out = self._encode(model, left_tokens)
        right_out = self._encode(model, right_tokens)
        loss = self.cosine_margin_loss(left_out, right_out, batch.label.float())

        preds = (F.cosine_similarity(left_out, right_out) > 0.5).type(torch.long)
        acc = (preds == batch.label).type(torch.float).mean()
        return {
            'loss': loss.item(),
            'accuracy': acc.item()
        }

In [None]:
class LABSETrainer(Trainer):
    """Trainer for a siamese classifier that takes both sides of a pair at once.

    ``model(left_tokens, right_tokens, left_mask, right_mask)`` must return
    class logits. Predictions are their argmax. Batches whose token length
    exceeds the model's position limit are skipped and return an empty dict.
    """

    @staticmethod
    def _max_len(model: nn.Module) -> int:
        """Return ``max_position_embeddings`` of the wrapped HF model, or 512 if unset."""
        return getattr(model.model.config, "max_position_embeddings", 512)

    def _pair_masks(self, left_tokens, right_tokens):
        """Return float non-pad masks for both sides, or ``(None, None)`` without ``pad_index``."""
        if self.pad_index is None:
            return None, None
        return (
            (left_tokens != self.pad_index).float(),
            (right_tokens != self.pad_index).float(),
        )

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_mask, right_mask = self._pair_masks(left_tokens, right_tokens)
        output = model(left_tokens, right_tokens, left_mask, right_mask)
        loss = criterion(output, batch.label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(output, dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()
        return {
            'loss': loss.item(),
            'accuracy': acc.item()
        }

    def val_step(
        self,
        model: nn.Module, batch,
        criterion = None,
        it: Optional[int] = None) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_mask, right_mask = self._pair_masks(left_tokens, right_tokens)
        output = model(left_tokens, right_tokens, left_mask, right_mask)

        step_log = {}
        # `criterion` defaults to None; only report a loss when one is given
        # instead of calling None.
        if criterion is not None:
            step_log['loss'] = criterion(output, batch.label).item()

        preds = torch.argmax(output, dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()
        step_log['accuracy'] = acc.item()
        return step_log

In [None]:
class LABSECombinedTrainer(Trainer):
    """Trainer for a siamese classifier with an extra cosine regulariser.

    Total loss is ``criterion(logits, label) + alpha * cosine_reg_loss``. The
    regulariser pulls the two embeddings of duplicate (label 1) pairs together.
    Batches whose token length exceeds the model's position limit are skipped
    and return an empty dict.
    """

    def __init__(self,
                 pad_index: Optional[int] = None,
                 silent: bool = False,
                 alpha: float = 0.1) -> None:
        super().__init__(pad_index=pad_index, silent=silent)
        self.alpha = alpha  # weight of the cosine regulariser in the total loss

    @staticmethod
    def cosine_reg_loss(
        x1: torch.Tensor,
        x2: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Mean ``clamp(1 - cos, 0, 1)`` over pairs with ``target == 1``; 0 if there are none."""
        mask = (target == 1)
        if mask.any():
            sim = F.cosine_similarity(x1[mask], x2[mask])
            return torch.mean(torch.clamp(1 - sim, min=0, max=1))
        # 0-dim zero, matching the shape of torch.mean above.
        return torch.zeros((), dtype=torch.float, device=x1.device)

    @staticmethod
    def _max_len(model: nn.Module) -> int:
        """Return ``max_position_embeddings`` of the wrapped HF model, or 512 if unset."""
        return getattr(model.model.config, "max_position_embeddings", 512)

    def _pair_masks(self, left_tokens, right_tokens):
        """Return float non-pad masks for both sides, or ``(None, None)`` without ``pad_index``."""
        if self.pad_index is None:
            return None, None
        return (
            (left_tokens != self.pad_index).float(),
            (right_tokens != self.pad_index).float(),
        )

    def _losses(self, output, label, criterion):
        """Return ``(clf_loss, cosine_reg, total_loss)`` for a ``return_dict=True`` output.

        ``output['embeddings']`` is indexed as ``[:, side, :]`` with side 0/1
        for the left/right question (assumed (batch, 2, dim) — TODO confirm).
        """
        clf_loss = criterion(output['logits'], label)
        cosine_reg = self.cosine_reg_loss(
            output['embeddings'][:, 0, :],
            output['embeddings'][:, 1, :],
            label
        )
        return clf_loss, cosine_reg, clf_loss + self.alpha * cosine_reg

    def train_step(
        self,
        model: nn.Module, batch,
        criterion, optimizer,
        it: Optional[int] = None
        ) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_mask, right_mask = self._pair_masks(left_tokens, right_tokens)
        output = model(
            left_tokens, right_tokens,
            left_mask, right_mask,
            return_dict=True
        )
        clf_loss, cosine_reg, loss = self._losses(output, batch.label, criterion)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = torch.argmax(output['logits'], dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()
        return {
            'clf_loss': clf_loss.item(),
            'cosine_loss': cosine_reg.item(),
            'loss': loss.item(),
            'accuracy': acc.item()
        }

    def val_step(
        self,
        model: nn.Module, batch,
        criterion = None,
        it: Optional[int] = None) -> MutableMapping[str, Optional[float]]:

        max_len = self._max_len(model)
        left_tokens, right_tokens = batch.text
        if left_tokens.size(1) > max_len or right_tokens.size(1) > max_len:
            return {}

        left_mask, right_mask = self._pair_masks(left_tokens, right_tokens)

        if criterion is None:
            # Without a criterion only accuracy is reported, so plain logits suffice.
            logits = model(
                left_tokens, right_tokens,
                left_mask, right_mask,
                return_dict=False
            )
            step_log = {}
        else:
            output = model(
                left_tokens, right_tokens,
                left_mask, right_mask,
                return_dict=True
            )
            logits = output['logits']
            clf_loss, cosine_reg, loss = self._losses(output, batch.label, criterion)
            step_log = {
                'clf_loss': clf_loss.item(),
                'cosine_loss': cosine_reg.item(),
                'loss': loss.item()
            }

        preds = torch.argmax(logits, dim=-1)
        acc = (preds == batch.label).type(torch.float).mean()
        step_log['accuracy'] = acc.item()

        return step_log

In [None]:
# Backbone checkpoint and batch size for the LaBSE experiments below.
model_name, BATCH_SIZE = 'sentence-transformers/LaBSE', 32

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ids of the special tokens, looked up one at a time in pad/unk/cls/sep order.
pad_index, unk_index, cls_index, sep_index = (
    tokenizer.convert_tokens_to_ids(special_token)
    for special_token in (
        tokenizer.pad_token,
        tokenizer.unk_token,
        tokenizer.cls_token,
        tokenizer.sep_token,
    )
)

In [None]:
PAIR = PairField(tokenizer)
LABEL = LabelField()

qqp_df = pd.read_csv('qqp_train.csv')
# One torchtext Example per English question pair; rows with NaN are dropped first.
examples = [
    text_data.Example.fromlist(
        [(row['question1'], row['question2']), row['is_duplicate']],
        [('text', PAIR), ('label', LABEL)]
    )
    for _, row in qqp_df.dropna().iterrows()
]

# The Dataset receives deep copies, so it does not share Example objects with `examples`.
train_dataset = text_data.Dataset(deepcopy(examples), dict(text=PAIR, label=LABEL))

train_iterator = text_data.BucketIterator(
    train_dataset,
    batch_size=BATCH_SIZE,
    # Each side is encoded separately, so bucket by the longer question.
    sort_key=lambda example: max(len(example.text[0].split()), len(example.text[1].split())),
    device=device,
    shuffle=True,
)
val_iterator = text_data.BucketIterator(
    get_val_dataset(tokenizer, field_cls=PairField),
    batch_size=BATCH_SIZE,
    sort_key=lambda example: max(len(example.text[0].split()), len(example.text[1].split())),
    device=device,
    shuffle=False,
)

In [None]:
# Hyper-parameters passed to wandb.init; run_trial reads them back via wandb.config.
config = {
    'HuggingFace model': model_name,   # backbone checkpoint
    'LR': 1e-5,                        # AdamW learning rate
    'Epochs': 3,                       # total training epochs
    'Freezed layers': 10,              # forwarded as freeze=... to the model
    'alpha': 0.1,                      # cosine-regulariser weight (LABSECombinedTrainer)
}

In [None]:
def run_trial(seed, tokenizer, trial_config, project_name: str = 'coursework', group: str = 'LaBSE QQP'):
    """Run one seeded LaBSE bi-encoder trial scored by cosine similarity.

    The ``TransformerWrapper`` encoder is trained with ``LABSECosineTrainer``.
    Test pairs are predicted duplicate when the cosine similarity of their
    embeddings exceeds 0.5. Accuracy, F1 and ROC-AUC (similarity as score) are
    written to ``wandb.run.summary``.

    Args:
        seed: Seed for torch, numpy and ``random``.
        tokenizer: Tokenizer whose ``pad_token`` id is used for attention masks.
        trial_config: W&B config. Keys read: 'HuggingFace model', 'LR',
            'Epochs', 'Freezed layers'.
        project_name: W&B project name.
        group: W&B run group.
    """
    pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Force deterministic cuDNN for this trial. The caller's setting is
    # restored in `finally`, even if training raises.
    cudnn_det_cache = None
    if torch.cuda.is_available():
        cudnn_det_cache = torch.backends.cudnn.deterministic
        torch.backends.cudnn.deterministic = True

    try:
        train_iterator = text_data.BucketIterator(
            train_dataset,
            BATCH_SIZE,
            sort_key=lambda x: max(len(x.text[0].split()), len(x.text[1].split())),
            device=device,
            shuffle=True
        )
        val_iterator = text_data.BucketIterator(
            get_val_dataset(tokenizer, field_cls=PairField),
            BATCH_SIZE,
            sort_key=lambda x: max(len(x.text[0].split()), len(x.text[1].split())),
            device=device,
            shuffle=False
        )

        with wandb.init(project=project_name, config=trial_config, group=group, save_code=False):
            hf_name = wandb.config['HuggingFace model']
            # Downgrade to previous LaBSE version. The pin is keyed off the
            # model actually being loaded, not the global `model_name`.
            if hf_name == 'sentence-transformers/LaBSE':
                revision = '4572cc21315c3dda0520dce8663d657558900936'
            elif hf_name == 'pvl/labse_bert':
                revision = '3cc5598da072e136eb7e529bbacc6628204f6cd1'
            else:
                revision = None

            # Init model
            model = TransformerWrapper(
                hf_name,
                freeze=wandb.config['Freezed layers'],
                aggregate_n_last_hidden_layers=1,
                revision=revision
            ).to(device)

            # The cosine objective has no criterion. LABSECosineTrainer never
            # calls it, whereas LABSETrainer would call criterion(...) on None
            # and expects a pair classifier, not a single-text encoder.
            criterion = None
            optim = torch.optim.AdamW(model.parameters(), lr=wandb.config['LR'])

            # Train model
            LABSECosineTrainer(pad_index=pad_index, silent=True).train(
                model,
                train_iterator, val_iterator,
                criterion, optim,
                total_epochs=wandb.config['Epochs']
            )

            # Test scores
            test_iterator = text_data.BucketIterator(
                get_val_dataset(tokenizer, path='test.tsv', shuffle=True, field_cls=PairField),
                BATCH_SIZE,
                sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
                device=device,
                shuffle=False
            )

            labels = []
            preds = []
            scores = []

            model.eval()
            with torch.no_grad():
                for batch in test_iterator:
                    left_tokens, right_tokens = batch.text
                    if pad_index is None:
                        mask = None
                    else:
                        mask = (left_tokens != pad_index).float()
                    left_out = model(left_tokens, mask)

                    if pad_index is None:
                        mask = None
                    else:
                        mask = (right_tokens != pad_index).float()
                    right_out = model(right_tokens, mask)

                    batch_scores = F.cosine_similarity(left_out, right_out)
                    scores += batch_scores.cpu().tolist()

                    batch_preds = (batch_scores > 0.5).type(torch.long)
                    labels += batch.label.cpu().tolist()
                    preds += batch_preds.cpu().tolist()

            test_acc = accuracy_score(labels, preds)
            wandb.run.summary['Test/Accuracy'] = test_acc
            print(f'Test set Accuracy: {test_acc:.4f}')

            test_f1 = f1_score(labels, preds)
            wandb.run.summary['Test/F1'] = test_f1
            print(f'Test set F1 Score: {test_f1:.4f}')

            test_roc_auc = roc_auc_score(labels, scores)
            wandb.run.summary['Test/ROC-AUC'] = test_roc_auc
            print(f'Test set ROC-AUC Score: {test_roc_auc:.4f}')
    finally:
        if cudnn_det_cache is not None:
            torch.backends.cudnn.deterministic = cudnn_det_cache

In [None]:
def run_trial(seed, tokenizer, trial_config, project_name: str = 'coursework', group: str = 'LaBSE QQP'):
    """Run one seeded LaBSE siamese-classifier trial with the combined loss.

    ``SiameseClf`` is trained with ``LABSECombinedTrainer`` (cross-entropy plus
    ``alpha`` times the cosine regulariser). Test Accuracy, F1 and ROC-AUC
    (softmax probability of class 1 as score) are written to
    ``wandb.run.summary``.

    Args:
        seed: Seed for torch, numpy and ``random``.
        tokenizer: Tokenizer whose ``pad_token`` id is used for attention masks.
        trial_config: W&B config. Keys read: 'HuggingFace model', 'LR',
            'Epochs', 'Freezed layers', 'alpha'.
        project_name: W&B project name.
        group: W&B run group.
    """
    pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Force deterministic cuDNN for this trial. The caller's setting is
    # restored in `finally`, even if training raises.
    cudnn_det_cache = None
    if torch.cuda.is_available():
        cudnn_det_cache = torch.backends.cudnn.deterministic
        torch.backends.cudnn.deterministic = True

    try:
        train_iterator = text_data.BucketIterator(
            train_dataset,
            BATCH_SIZE,
            sort_key=lambda x: max(len(x.text[0].split()), len(x.text[1].split())),
            device=device,
            shuffle=True
        )
        val_iterator = text_data.BucketIterator(
            get_val_dataset(tokenizer, field_cls=PairField),
            BATCH_SIZE,
            sort_key=lambda x: max(len(x.text[0].split()), len(x.text[1].split())),
            device=device,
            shuffle=False
        )

        with wandb.init(project=project_name, config=trial_config, group=group, save_code=False):
            # Init model
            model = SiameseClf(
                wandb.config['HuggingFace model'],
                2,
                freeze=wandb.config['Freezed layers'],
                aggregate_n_last_hidden_layers=1,
            ).to(device)

            criterion = nn.CrossEntropyLoss()
            optim = torch.optim.AdamW(model.parameters(), lr=wandb.config['LR'])

            # Train model
            LABSECombinedTrainer(
                pad_index=pad_index,
                silent=True,
                alpha=wandb.config['alpha']
            ).train(
                model,
                train_iterator, val_iterator,
                criterion, optim,
                total_epochs=wandb.config['Epochs']
            )

            # Test scores
            test_iterator = text_data.BucketIterator(
                get_val_dataset(tokenizer, path='test.tsv', shuffle=True, field_cls=PairField),
                BATCH_SIZE,
                sort_key=lambda x: len(x.text[0].split()) + len(x.text[1].split()),
                device=device,
                shuffle=False
            )

            labels = []
            preds = []
            scores = []

            model.eval()
            with torch.no_grad():
                for batch in test_iterator:
                    left_tokens, right_tokens = batch.text
                    if pad_index is None:
                        left_mask = None
                        right_mask = None
                    else:
                        left_mask = (left_tokens != pad_index).float()
                        right_mask = (right_tokens != pad_index).float()

                    # Ask for plain logits explicitly, as the trainer's val_step does.
                    output = model(
                        left_tokens, right_tokens,
                        left_mask, right_mask,
                        return_dict=False
                    )

                    batch_scores = torch.softmax(output, dim=1)[:, 1]
                    scores += batch_scores.cpu().tolist()

                    batch_preds = torch.argmax(output, dim=1)
                    labels += batch.label.cpu().tolist()
                    preds += batch_preds.cpu().tolist()

            test_acc = accuracy_score(labels, preds)
            wandb.run.summary['Test/Accuracy'] = test_acc
            print(f'Test set Accuracy: {test_acc:.4f}')

            test_f1 = f1_score(labels, preds)
            wandb.run.summary['Test/F1'] = test_f1
            print(f'Test set F1 Score: {test_f1:.4f}')

            test_roc_auc = roc_auc_score(labels, scores)
            wandb.run.summary['Test/ROC-AUC'] = test_roc_auc
            print(f'Test set ROC-AUC Score: {test_roc_auc:.4f}')
    finally:
        if cudnn_det_cache is not None:
            torch.backends.cudnn.deterministic = cudnn_det_cache

In [None]:
# One W&B run per seed; GPU cache is released between trials.
for trial_number, seed in enumerate(tqdm(SEEDS, unit='trial'), start=1):
    print(f'Running trial #{trial_number}')
    run_trial(seed, tokenizer, config)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()