Special thanks to Z by HP & NVIDIA for sponsoring me with a Z4 Workstation with dual A6000 GPUs!

Training large transformer models with 48GB VRAM is awesome!

In [None]:
from argparse import ArgumentParser
from lightgbm import train
import numpy as np
import pandas as pd
import random

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from transformers import AutoModel, AutoTokenizer, AutoConfig

import wandb

from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from pytorch_lightning import LightningDataModule, LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler

from argparse import Namespace
from scipy.stats import rankdata

from tqdm.notebook import tqdm

In [None]:
# Load the validation texts, the validation pair mapping (text1/text2 row
# indices plus real_label) and the grouped/ranked pair table anchors come from.
test_df, test_mapping, pairs = (
    pd.read_pickle('../input/jigsaw-data/valid_text.pkl'),
    pd.read_csv('../input/jigsaw-data/valid_set.csv'),
    pd.read_pickle('../input/jigsaw-data/pairs_v2_v4.pkl'),
)

In [None]:
# Peek at the first five rows of the anchor pair table.
pairs.head(n=5)

In [None]:
# Fix the global random seed (via Lightning's helper) so runs are repeatable.
seed_everything(seed=1991)

In [None]:
def make_anchor(anchor_per_group=50, fold=0, pairs_df=None):
    """Select anchor texts and their rank-bucket weights from the pair table.

    For each value of ``group`` the rows whose ``uni_rank`` satisfies
    ``uni_rank % (len(group) // anchor_per_group) == fold`` are taken, which
    yields roughly ``anchor_per_group`` anchors per group. Each selected row's
    weight is ``uni_rank // (len(group) // 10) + 1``; assuming ``uni_rank``
    runs 0..len(group)-1 (TODO confirm), that is the 1-based tenth of the group
    the row's rank falls into.

    Args:
        anchor_per_group: approximate number of anchors to draw per group.
        fold: residue selecting which interleaved slice of each group to use.
        pairs_df: DataFrame with ``group``, ``uni_rank`` and ``text`` columns;
            defaults to the module-level ``pairs`` table.

    Returns:
        ``(anchors, score_weights)``: two equally long lists, anchor texts and
        their integer weights, in group order.
    """
    if pairs_df is None:
        pairs_df = pairs
    anchors, score_weights = [], []
    for group in pairs_df['group'].unique():
        temp = pairs_df[pairs_df['group'] == group]
        # Guard against zero divisors for small groups (fewer than
        # anchor_per_group or 10 rows): pandas evaluates integer `% 0` / `// 0`
        # to NaN/inf instead of raising, which silently dropped every anchor
        # of the group or produced infinite weights.
        divider = max(1, len(temp) // anchor_per_group)
        score_divider = max(1, len(temp) // 10)
        selected = temp[temp['uni_rank'] % divider == fold]
        anchors.extend(selected['text'].tolist())
        score_weights.extend((selected['uni_rank'] // score_divider + 1).tolist())
    return anchors, score_weights

In [None]:
class JigsawPairValidDataset(Dataset):
    """Positional dataset over the ``text`` column of a DataFrame.

    Each item is ``(text, position)`` where ``position`` is a one-element
    float tensor holding the requested index. The DataFrame index is kept in
    ``_id`` and, when a ``toxic`` column exists, its values in ``_y``
    (otherwise ``_y`` is ``None``).
    """

    def __init__(self, df):
        has_toxic = "toxic" in df.keys()
        self._X = df["text"].values
        self._id = df.index.tolist()
        self._y = df["toxic"].values if has_toxic else None

    def __len__(self):
        return len(self._X)

    def __getitem__(self, idx):
        sample_text = self._X[idx]
        position = torch.FloatTensor([idx])
        return sample_text, position
    
class ValidCollator:
    """Collate ``(text, index_tensor)`` samples into model-ready batches.

    Each call returns ``(features_a, features_b, label)``:
    ``features_a`` is the tokenized batch text, ``features_b`` the tokenized
    anchor list, and ``label`` the stacked per-sample index tensors. Both
    encodings are padded/truncated to ``config.max_length`` as PyTorch tensors.
    """

    def __init__(self, config, anchors):
        super().__init__()
        self.max_length = config.max_length
        self.tokenizer = AutoTokenizer.from_pretrained(config.backbone_name)
        self.token_pad_value = self.tokenizer.pad_token_id
        self.type_pad_value = self.tokenizer.pad_token_type_id
        self.anchors = anchors
        self.n_anchors = len(anchors)
        # The anchor list is identical for every batch, so tokenize it once
        # here instead of re-encoding all anchors on each __call__.
        self._anchor_features = self._encode(list(anchors))

    def _encode(self, texts):
        """Tokenize ``texts`` into ``max_length``-padded PyTorch tensors."""
        return self.tokenizer.batch_encode_plus(texts,
                                                return_tensors='pt',
                                                padding='max_length',
                                                max_length=self.max_length, truncation=True)

    def __call__(self, batch):
        import copy  # only needed here, to hand out copies of the cached anchors

        text, label = zip(*batch)
        features_a = self._encode(list(text))
        # Shallow copy (the tensors are shared, the container is not) so a
        # caller that moves the returned encoding to a device, e.g.
        # ``a.to('cuda')`` in ``predict``, cannot alter the cached anchors.
        features_b = copy.copy(self._anchor_features)
        label = torch.stack(label)
        return features_a, features_b, label

In [None]:
class JigsawPairModel(pl.LightningModule):
    """Score texts by comparing them against a fixed batch of anchor texts.

    One transformer backbone encodes both the input batch and the anchors.
    For every (sample, anchor) pair a dropout + linear head over three
    concatenated ``hidden_size`` vectors gives a sigmoid probability;
    ``forward`` then dots each sample's per-anchor probabilities with
    ``score_weights`` to get one score per sample.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        if config.predict:
            # Architecture only (no pretrained download); the weights are
            # presumably restored afterwards from a Lightning checkpoint.
            bert_config = AutoConfig.from_pretrained(config.backbone_name)
            self.bert = AutoModel.from_config(bert_config)
        else:
            self.bert = AutoModel.from_pretrained(config.backbone_name)

        # Input is three hidden_size vectors concatenated in forward().
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.bert.config.hidden_size*3, 1)
        )

        self.save_hyperparameters(config)
        # Backbone output for the anchor batch; filled on the first forward()
        # call and reused by every later call.
        self.anchor_outputs = None

    def masked_mean_pooling(self, emb, attention_mask):
        """Average ``emb`` over the sequence axis, weighting by ``attention_mask``.

        Args:
            emb: token embeddings, assumed (B, L, D).
            attention_mask: (B, L) mask, assumed 1 for real tokens, 0 for padding.

        Returns:
            (B, D) tensor. The clamp keeps all-padding rows from dividing by zero.
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(emb.size()).float()
        sum_embeddings = torch.sum(emb * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        out = sum_embeddings / sum_mask
        return out

    def inference(self, seq1, seq2, mask1=None, mask2=None):
        """Cross-attend ``seq1`` and ``seq2`` and mean-pool each side.

        Args:
            seq1: (B, l1, D) tensor.
            seq2: (B, l2, D) tensor.
            mask1: (B, l1) mask for ``seq1``; ``None`` means no padding.
            mask2: (B, l2) mask for ``seq2``; ``None`` means no padding. A
                (1, l2) mask also works because it broadcasts over B.

        Returns:
            ``(new_seq1, new_seq2)``, each (B, D): ``seq1`` positions attending
            over ``seq2`` averaged over the unmasked ``seq1`` positions, and the
            symmetric result for ``seq2``.

        NOTE(review): the softmax itself is not masked, so padded positions
        still receive attention weight (their values are zeroed). Kept as-is
        because existing checkpoints were trained with this behavior.
        """
        # The defaults used to crash (None * tensor); treat a missing mask as
        # "every position is a real token".
        if mask1 is None:
            mask1 = seq1.new_ones(seq1.shape[:2])
        if mask2 is None:
            mask2 = seq2.new_ones(seq2.shape[:2])
        score = torch.bmm(seq1, seq2.permute(0, 2, 1))
        new_seq1 = torch.bmm(torch.softmax(score, dim=-1), seq2*mask2.unsqueeze(-1))
        new_seq1 = torch.sum(new_seq1*mask1.unsqueeze(-1),dim=1)/torch.sum(mask1, dim=1).unsqueeze(-1)

        new_seq2 = torch.bmm(torch.softmax(score, dim=1).permute(0, 2, 1), seq1*mask1.unsqueeze(-1))
        new_seq2 = torch.sum(new_seq2*mask2.unsqueeze(-1), dim=1)/torch.sum(mask2, dim=1).unsqueeze(-1)
        return new_seq1, new_seq2

    def forward(self, x, anchors, score_weights):
        """Return one weighted anchor-comparison score per sample in ``x``.

        Args:
            x: tokenizer output for the samples (passed as ``**x`` to the
                backbone; must contain ``attention_mask``).
            anchors: tokenizer output for the anchor texts. The backbone pass
                over it runs only on the first call and is cached in
                ``self.anchor_outputs``; later calls reuse that cache, though
                ``anchors['attention_mask']`` is still read on every call when
                ``config.inference`` is set.
            score_weights: one weight per anchor, dotted with the per-anchor
                probabilities.

        Returns:
            1-D numpy array with one score per sample. With ``config.inference``
            the sample and anchor encodings must share the same sequence length
            (``expand_as`` below).
        """
        out1 = self.bert(**x)
        if self.anchor_outputs is None:
            self.anchor_outputs = self.bert(**anchors)
            # Backbones without a pooler get a masked mean over tokens instead.
            if not hasattr(self.anchor_outputs, 'pooler_output'):
                self.anchor_outputs.pooler_output = self.masked_mean_pooling(self.anchor_outputs.last_hidden_state, anchors['attention_mask'])

        if not hasattr(out1, 'pooler_output'):
            out1.pooler_output = self.masked_mean_pooling(out1.last_hidden_state, x['attention_mask'])

        scores = []
        for i in range(len(out1.pooler_output)):
            sample_out = out1.pooler_output[i,:]
            if self.config.inference:
                # Repeat this sample's token states once per anchor so each
                # (anchor, sample) pair is cross-attended in one bmm.
                sample_hidden_out = out1.last_hidden_state[i,:,:].unsqueeze(0).expand_as(self.anchor_outputs.last_hidden_state)

                new_anchor_emb, new_sample_emb = self.inference(self.anchor_outputs.last_hidden_state, sample_hidden_out,
                                                anchors['attention_mask'], x['attention_mask'][i].unsqueeze(0))
                pred = torch.sigmoid(self.head(torch.cat([
                                            new_sample_emb,
                                            new_anchor_emb,
                                            self.anchor_outputs.pooler_output-sample_out], axis=1))).flatten().cpu().numpy()
            else:
                pred = torch.sigmoid(self.head(torch.cat([
                                            self.anchor_outputs.pooler_output,
                                            sample_out.expand_as(self.anchor_outputs.pooler_output),
                                            self.anchor_outputs.pooler_output-sample_out], axis=1))).flatten().cpu().numpy()
            score = np.dot(pred, score_weights)
            scores.append(score)
        return np.array(scores)

In [None]:
def get_score(all_pred, mapping=None):
    """Print and return pairwise-ranking agreement of ``all_pred`` on validation pairs.

    For every row of ``mapping`` whose ``text1``/``text2`` indices both fall
    inside ``all_pred``, ``score`` counts pairs with
    ``all_pred[text1] < all_pred[text2]``, and ``real score`` counts pairs
    where that comparison (as 0/1) equals ``real_label``. Both are divided by
    the number of pairs actually evaluated.

    Side effect: writes ``score1``/``score2`` columns (NaN for out-of-range
    indices) into ``mapping`` in place.

    Args:
        all_pred: indexable predictions, one per validation text.
        mapping: DataFrame with ``text1``, ``text2`` and ``real_label`` columns;
            defaults to the module-level ``test_mapping``.

    Returns:
        ``(score, real_score)`` floats; both NaN when no pair could be evaluated.
    """
    if mapping is None:
        mapping = test_mapping
    n_pred = len(all_pred)

    def _lookup(idx):
        # Out-of-range indices previously raised IndexError here, before the
        # skip guard in the loop below could ever take effect.
        return all_pred[idx] if idx < n_pred else np.nan

    mapping['score1'] = mapping['text1'].apply(_lookup)
    mapping['score2'] = mapping['text2'].apply(_lookup)

    match, real_match, evaluated = 0, 0, 0
    for idx1, idx2, real_label in zip(mapping['text1'].tolist(), mapping['text2'].tolist(), mapping['real_label'].tolist()):
        if idx1 >= n_pred or idx2 >= n_pred:
            continue
        evaluated += 1
        first_lower = int(all_pred[idx1] < all_pred[idx2])
        match += first_lower
        if first_lower == real_label:
            real_match += 1
    if evaluated == 0:
        print('score: nan, real score: nan')
        return float('nan'), float('nan')
    score, real_score = match / evaluated, real_match / evaluated
    print(f'score: {score}, real score: {real_score}')
    return score, real_score

In [None]:
def predict(config, to_pred, device='cuda', num_workers=2):
    """Score every row of ``to_pred`` with the checkpoint described by ``config``.

    Anchors for ``config.fold`` come from ``make_anchor``. The model is loaded
    from ``config.path`` with ``strict=False`` and run in eval mode without
    gradients.

    Args:
        config: Namespace read here for ``fold``, ``batch_size`` and ``path``,
            and passed on to ``ValidCollator`` and ``JigsawPairModel``.
        to_pred: DataFrame with a ``text`` column.
        device: device the model and every batch are moved to.
        num_workers: number of DataLoader worker processes.

    Returns:
        1-D numpy array with one score per row of ``to_pred``, in row order
        (the loader does not shuffle).
    """
    if len(to_pred) == 0:
        # Nothing to score; np.concatenate([]) below would raise ValueError.
        return np.array([])

    anchors, score_weights = make_anchor(fold=config.fold)

    dataset = JigsawPairValidDataset(to_pred)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False,
                            num_workers=num_workers,
                            collate_fn=ValidCollator(config, anchors))
    model = JigsawPairModel.load_from_checkpoint(config.path, config=config, strict=False)
    model.eval()
    model.to(device)

    preds = []
    with torch.no_grad():
        for x, a, _ in tqdm(dataloader):
            preds.append(model(x.to(device), a.to(device), score_weights))
    return np.concatenate(preds)

In [None]:
# Settings for the checkpoint named roberta-base_0-v1; anchors use fold 0.
# predict=True makes JigsawPairModel build the backbone from its config only;
# inference=True selects the cross-attention branch in its forward().
rbase_config = Namespace(**{
    'backbone_name': '../input/rbase-inf/',
    'max_length': 128,
    'batch_size': 128,
    'predict': True,
    'focal': False,
    'inference': True,
    'fold': 0,
    'path': '../input/rbase-inf/roberta-base_0-v1.ckpt',
})

In [None]:
# Settings for the checkpoint named deberta-v3_0; anchors use fold 1.
# predict=True makes JigsawPairModel build the backbone from its config only;
# inference=True selects the cross-attention branch in its forward().
dbase_config = Namespace(**{
    'backbone_name': '../input/dbase-late/',
    'max_length': 128,
    'batch_size': 128,
    'predict': True,
    'focal': False,
    'inference': True,
    'fold': 1,
    'path': '../input/dbase-late/deberta-v3_0.ckpt',
})

In [None]:
# Settings for the checkpoint named electra-base_0; anchors use fold 2.
# predict=True makes JigsawPairModel build the backbone from its config only;
# inference=True selects the cross-attention branch in its forward().
ebase_config = Namespace(**{
    'backbone_name': '../input/ebase-inf/',
    'max_length': 128,
    'batch_size': 128,
    'predict': True,
    'focal': False,
    'inference': True,
    'fold': 2,
    'path': '../input/ebase-inf/electra-base_0.ckpt',
})

In [None]:
# Settings for the checkpoint named funnel-intermediate_0; anchors use fold 3.
# predict=True makes JigsawPairModel build the backbone from its config only;
# inference=True selects the cross-attention branch in its forward().
finter_config = Namespace(**{
    'backbone_name': '../input/finter-inf/',
    'max_length': 128,
    'batch_size': 128,
    'predict': True,
    'focal': False,
    'inference': True,
    'fold': 3,
    'path': '../input/finter-inf/funnel-intermediate_0.ckpt',
})

In [None]:
# Settings for the checkpoint named xlnet-base_0; anchors use fold 4.
# predict=True makes JigsawPairModel build the backbone from its config only;
# inference=True selects the cross-attention branch in its forward().
xbase_config = Namespace(**{
    'backbone_name': '../input/xbase-inf/',
    'max_length': 128,
    'batch_size': 128,
    'predict': True,
    'focal': False,
    'inference': True,
    'fold': 4,
    'path': '../input/xbase-inf/xlnet-base_0.ckpt',
})

## validation

In [None]:
# finter_score = predict(finter_config, test_df)
# ebase_score = predict(ebase_config, test_df)
# dbase_score = predict(dbase_config, test_df)
# rbase_score = predict(rbase_config, test_df)

In [None]:
# np.corrcoef([finter_score,ebase_score,dbase_score,rbase_score])

In [None]:
# get_score(np.mean([finter_score,ebase_score,dbase_score,rbase_score], axis=0))

In [None]:
# get_score(rbase_score)

## test

In [None]:
# Competition test comments that need a severity score.
comment_to_score = pd.read_csv(
    filepath_or_buffer='../input/jigsaw-toxic-severity-rating/comments_to_score.csv',
)

In [None]:
# Inspect the first five comments before scoring.
comment_to_score.head(n=5)

In [None]:
# Only the dbase_config checkpoint is scored for this submission; the other
# backbones are left commented out.
# finter_score = predict(finter_config, comment_to_score)
# ebase_score = predict(ebase_config, comment_to_score)
dbase_score = predict(config=dbase_config, to_pred=comment_to_score)
# rbase_score = predict(rbase_config, comment_to_score)
# xbase_score = predict(xbase_config, comment_to_score)

In [None]:
# rank = np.mean([
#     rankdata(rbase_score), rankdata(dbase_score), rankdata(ebase_score),
#     rankdata(xbase_score), rankdata(xbase_score),rankdata(xbase_score),
#     rankdata(finter_score),rankdata(finter_score)], axis=0)

In [None]:
# np.corrcoef([rbase_score, dbase_score, ebase_score, xbase_score, finter_score])

In [None]:
# Submit ranks rather than raw scores (tied scores share their average rank).
comment_to_score['score'] = rankdata(dbase_score, method='average')

In [None]:
# Check the first five rows now that the score column is filled in.
comment_to_score.head(n=5)

In [None]:
# Write only the id and score columns, without the DataFrame index.
submission = comment_to_score[['comment_id', 'score']]
submission.to_csv('submission.csv', index=False)