In [1]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification,AutoModel
import os
from tqdm import tqdm
import gc

In [2]:
class CONFIG:
    """Static settings for the inference cells of this notebook.

    All attributes are evaluated once, when the class body runs: the
    checkpoint directory is listed and the tokenizer is loaded from disk at
    definition time.
    """

    # Directory that holds the fine-tuned fold checkpoints.
    model_dir = "../input/fork-of-pytorch-lightning-toxic-bert"
    # Checkpoint file names to ensemble. os.listdir() returns entries in
    # arbitrary order, so sort them to make runs (and their logs) reproducible.
    models = sorted(x for x in os.listdir(model_dir) if "toxicbert" in x)
    # NOTE(review): not referenced by any cell visible in this notebook; the
    # encoder and tokenizer are loaded from "../input/unitarytoxicbert".
    model_name = "distilroberta-base"
    # NOTE(review): never passed to a seeding function in the visible cells.
    seed = 101
    k_fold = 5
    # Batch size used by the inference DataLoader.
    val_batch_size = 64
    # Every comment is truncated / padded to this many tokens.
    max_len = 128
    tokenizer = AutoTokenizer.from_pretrained("../input/unitarytoxicbert")
    # Width of the model's final linear layer (one score per comment).
    no_class = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
# Comments that have to be scored; the preview below shows the columns
# `comment_id` and `text`.
COMMENTS_CSV = "../input/jigsaw-toxic-severity-rating/comments_to_score.csv"
df = pd.read_csv(COMMENTS_CSV)
df.head()

Unnamed: 0,comment_id,text
0,114890,"""\n \n\nGjalexei, you asked about whether ther..."
1,732895,"Looks like be have an abuser , can you please ..."
2,1139051,I confess to having complete (and apparently b...
3,1434512,"""\n\nFreud's ideas are certainly much discusse..."
4,2084821,It is not just you. This is a laundry list of ...


In [4]:
class JigsawDataset(Dataset):
    """Map-style dataset that tokenizes one comment per item.

    Args:
        df: frame with a ``text`` column holding the comments.
        tokenizer: object exposing ``encode_plus``; its result must provide
            ``input_ids`` and ``attention_mask``.
        max_seq_len: length every comment is truncated / padded to.

    Each item is a dict with two ``torch.long`` tensors: ``text_ids`` (the
    tokenizer's ``input_ids``) and ``text_mask`` (its ``attention_mask``).
    """

    def __init__(self, df, tokenizer, max_seq_len):
        self.df = df
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        # Extract the raw values once so item lookup is a plain array index.
        self.text = df['text'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        encoded = self.tokenizer.encode_plus(
            self.text[index],
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_seq_len,
            padding='max_length',
        )
        # (output key, tokenizer key) pairs, in the order the dict is built.
        field_map = (
            ('text_ids', 'input_ids'),
            ('text_mask', 'attention_mask'),
        )
        return {
            out_key: torch.tensor(encoded[src_key], dtype=torch.long)
            for out_key, src_key in field_map
        }

In [5]:
class jigsaw_toxicbert(pl.LightningModule):
    """Transformer encoder with a small head that outputs a toxicity score.

    The encoder is loaded from ``../input/unitarytoxicbert/``. Element ``[1]``
    of its output (the pooled representation for BERT-style encoders --
    assumed, confirm for the checkpoint in use) is passed through
    dropout -> Linear(768, 128) -> dropout -> Linear(128, CONFIG.no_class).

    The training hooks read ``CONFIG.lr``, ``CONFIG.weight_decay``,
    ``CONFIG.min_lr``, ``CONFIG.T_max`` and ``CONFIG.criterion``. The CONFIG
    class shown in this notebook does not define them, so only ``forward`` is
    usable here unless CONFIG is extended.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained("../input/unitarytoxicbert/")
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(768, 128)
        self.dropout_2 = nn.Dropout(0.2)
        self.output = nn.Linear(128, CONFIG.no_class)

    def forward(self, ids, mask):
        """Return the head's output for a batch of token ids and attention masks."""
        encoded = self.model(input_ids=ids, attention_mask=mask, output_hidden_states=False)
        x = self.dropout(encoded[1])
        x = self.linear(x)
        x = self.dropout_2(x)
        return self.output(x)

    def configure_optimizers(self):
        """Build the AdamW optimiser and cosine-annealing schedule from CONFIG."""
        # Fully qualified through `torch`: the bare names `optim` and
        # `lr_scheduler` are not imported anywhere in this notebook.
        # NOTE(review): only the encoder's parameters are optimised, so the
        # head layers (`linear`, `output`) keep their initial weights. Left
        # unchanged to stay compatible with existing checkpoints' optimizer
        # state -- confirm whether `self.parameters()` was intended.
        optimiser = torch.optim.AdamW(
            self.model.parameters(), lr=CONFIG.lr, weight_decay=CONFIG.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimiser, eta_min=CONFIG.min_lr, T_max=CONFIG.T_max
        )
        return [optimiser], [scheduler]

    def _pairwise_loss(self, batch):
        """Score both comments of a pair and apply ``CONFIG.criterion`` to them."""
        more_toxic_pred = self(batch["more_toxic_ids"], batch["more_toxic_mask"])
        # Each comment must be encoded with its OWN attention mask; the
        # less-toxic comment previously reused the more-toxic one.
        less_toxic_pred = self(batch["less_toxic_ids"], batch["less_toxic_mask"])
        return CONFIG.criterion(more_toxic_pred, less_toxic_pred, batch["target"])

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-level training loss for one batch of pairs."""
        loss = self._pairwise_loss(batch)
        self.log('train_margin_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the epoch-level validation loss for one batch of pairs."""
        loss = self._pairwise_loss(batch)
        self.log('val_margin_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss


In [6]:
@torch.no_grad()
def predict(model, dataloader, modelpaths, device,
            checkpoint_dir="../input/fork-of-pytorch-lightning-toxic-bert/"):
    """Average the predictions of several checkpoints over one dataloader.

    Args:
        model: the model class, or any instance of it. It is only used to
            reach ``load_from_checkpoint``; its own weights are never used.
        dataloader: iterable of batches, each a dict with ``text_ids`` and
            ``text_mask`` tensors. Must support ``len()`` and must yield the
            batches in the same order on every pass.
        modelpaths: checkpoint file names, relative to ``checkpoint_dir``.
        device: device the loaded models and the batches are moved to.
        checkpoint_dir: directory that contains the checkpoint files.

    Returns:
        1-D numpy array: the element-wise mean, across checkpoints, of the
        flattened model outputs.

    Raises:
        ValueError: if ``modelpaths`` is empty.
    """
    modelpaths = list(modelpaths)
    if not modelpaths:
        raise ValueError("modelpaths is empty: at least one checkpoint is required")

    # load_from_checkpoint is a classmethod; resolve the class explicitly so
    # both a class and an instance are accepted as `model`.
    model_cls = model if isinstance(model, type) else type(model)

    fold_preds = []
    for path in modelpaths:
        model_infer = model_cls.load_from_checkpoint(os.path.join(checkpoint_dir, path))
        model_infer.to(device)
        model_infer.freeze()
        model_infer.eval()
        print(f'predicting on {path}')

        batch_preds = []
        # Iterate the dataloader that was passed in, not a notebook global.
        for data in tqdm(dataloader, total=len(dataloader)):
            ids = data['text_ids'].to(device, dtype=torch.long)
            mask = data['text_mask'].to(device, dtype=torch.long)
            outputs = model_infer(ids, mask)
            # reshape(-1) also copes with non-contiguous outputs.
            batch_preds.append(outputs.reshape(-1).cpu().numpy())
        fold_preds.append(np.concatenate(batch_preds))

        # Release this fold's model before the next checkpoint is loaded.
        del model_infer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.mean(np.array(fold_preds), axis=0)





In [7]:
# predict() only uses `model` to reach the `load_from_checkpoint` classmethod,
# so hand it the class itself instead of building a throwaway instance (which
# would load the encoder weights one extra time for nothing).
model = jigsaw_toxicbert
infer_dataset = JigsawDataset(df, CONFIG.tokenizer, CONFIG.max_len)
# shuffle=False keeps the predictions in the same row order as `df`.
infer_dataloader = DataLoader(
    infer_dataset,
    batch_size=CONFIG.val_batch_size,
    num_workers=2,
    shuffle=False,
    pin_memory=True,
)
preds = predict(model, infer_dataloader, CONFIG.models, CONFIG.device)
assert len(preds) == len(df), "expected exactly one prediction per comment"

predicting on toxicbert_val_margin_loss=0.02031_fold_0.ckpt


100%|██████████| 118/118 [00:29<00:00,  4.07it/s]


predicting on toxicbert_val_margin_loss=0.01921_fold_1.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.20it/s]


predicting on toxicbert_val_margin_loss=0.02067_fold_1.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.20it/s]


predicting on toxicbert_val_margin_loss=0.01921_fold_4.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.20it/s]


predicting on toxicbert_val_margin_loss=0.02136_fold_2.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.18it/s]


predicting on toxicbert_val_margin_loss=0.01999_fold_3.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.17it/s]


predicting on toxicbert_val_margin_loss=0.01929_fold_4.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.18it/s]


predicting on toxicbert_val_margin_loss=0.02196_fold_2.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.19it/s]


predicting on toxicbert_val_margin_loss=0.02033_fold_3.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.19it/s]


predicting on toxicbert_val_margin_loss=0.02110_fold_0.ckpt


100%|██████████| 118/118 [00:28<00:00,  4.18it/s]


In [8]:
# Submission frame: one row per comment, scored by the rank of its ensemble
# prediction. method='first' breaks ties by order of appearance, so every
# comment receives a distinct rank.
score_rank = pd.Series(preds, index=df.index).rank(method='first')
submit = pd.DataFrame({"comment_id": df["comment_id"], "score": score_rank})

In [9]:
# Write the ranked scores to disk without the frame's index column.
SUBMISSION_PATH = "submission.csv"
submit.to_csv(SUBMISSION_PATH, index=False)