In [None]:
import pathlib
from typing import Any, Callable, Dict, Tuple, List

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

# Authenticate with Weights & Biases; the API key is fetched through Kaggle's
# secrets client so it never appears in the notebook source.
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("wandb_api_key")
wandb.login(key=secret_value)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Kudos to [this Kaggle kernel](https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta).

In [None]:
# --- Configuration: everything a reader might want to tune lives here. ---
ROOT_PATH = pathlib.Path("/kaggle/input/jigsaw-multilingual-toxic-comment-classification/")
MODEL = "distilbert-base-multilingual-cased"  # checkpoint name passed to from_pretrained
BATCH_SIZE = 32
EPOCHS = 1
LR = 1e-5
MAX_DOC_LENGTH = 256  # tokenizer truncation limit (max_length)
SEED = 42

# Seed the global NumPy generator (pandas' `.sample` draws from it when no
# random_state is given) and PyTorch (loader shuffling, classification-head
# initialisation) so a Restart & Run All repeats the same draws.
# NOTE: GPU kernels may still introduce small run-to-run differences.
np.random.seed(SEED)
torch.manual_seed(SEED)

## Data

In [None]:
# Read the three competition splits. The test split's "content" column is
# renamed to "comment_text" so every frame exposes its text under one name.
train_df = pd.read_csv(ROOT_PATH / "jigsaw-toxic-comment-train.csv")
valid_df = pd.read_csv(ROOT_PATH / "validation.csv")
test_df = pd.read_csv(ROOT_PATH / "test.csv").rename(columns={"content": "comment_text"})

In [None]:
# Eyeball a handful of random training rows.
train_df.sample(n=5)

In [None]:
# Mean of the "toxic" column in the training and validation splits, in that order.
tuple(split["toxic"].mean() for split in (train_df, valid_df))

## Data Preparation for PyTorch

In [None]:
class TestCommentsData(Dataset):
    """Unlabelled dataset over the ``comment_text`` column of a DataFrame.

    Each item is the raw comment stored at the requested position; no
    tokenization happens here.
    """

    def __init__(self, df: pd.DataFrame):
        text_column = df["comment_text"]
        self.comments = text_column.values

    def __len__(self) -> int:
        """Return the number of comments held by the dataset."""
        return len(self.comments)

    def __getitem__(self, idx):
        """Return the comment stored at position ``idx``."""
        comment = self.comments[idx]
        return comment

class CommentsData(TestCommentsData):
    """Labelled dataset: each item is a ``(comment, toxic)`` pair.

    Comments come from the parent class; labels are read from the ``toxic``
    column of the same DataFrame.
    """

    def __init__(self, df: pd.DataFrame):
        super().__init__(df)
        label_column = df["toxic"]
        self.toxic = label_column.values

    def __getitem__(self, idx):
        """Return the comment at ``idx`` together with its ``toxic`` value."""
        return super().__getitem__(idx), self.toxic[idx]
    
# Wrap each split in its dataset class; the test split has no labels.
train_ds, valid_ds = CommentsData(train_df), CommentsData(valid_df)
test_ds = TestCommentsData(test_df)

# Loader settings shared by every split.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

# Training batches are shuffled and the trailing partial batch is dropped;
# the evaluation loaders keep the original order and every row.
train_dl = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
valid_dl = DataLoader(valid_ds, shuffle=False, drop_last=False, **_loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, drop_last=False, **_loader_kwargs)

In [None]:
# Pull one training batch to check what the loader yields:
# the number of comments in the batch and the shape of the label tensor.
first_batch = next(iter(train_dl))
x, y = first_batch
len(x), y.shape

## Model

In [None]:
class Model(nn.Module):
    """Pretrained transformer encoder with a single-logit classification head.

    The encoder named by the module-level ``MODEL`` constant is loaded with
    ``AutoModel.from_pretrained``; a linear layer then maps the hidden state
    at the first sequence position to one logit.
    """

    def __init__(self):
        super().__init__()
        self.base = AutoModel.from_pretrained(MODEL)
        # `hidden_size` is the config attribute shared across Hugging Face
        # architectures (DistilBERT exposes its `dim` under that name as well),
        # so the head still sizes correctly if MODEL names another encoder.
        self.head = nn.Linear(self.base.config.hidden_size, 1)

    def forward(self, x: Dict[str, torch.LongTensor]) -> torch.FloatTensor:
        """Compute one un-normalised logit per input sequence.

        Args:
            x: Tokenizer output, unpacked as keyword arguments into the encoder.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the logits.
        """
        sequence_output = self.base(**x).last_hidden_state  # (batch, seq_len, hidden)
        # Hidden state at position 0, used as the summary of the whole sequence.
        cls_token = sequence_output[:, 0, :]
        return self.head(cls_token)

In [None]:
class LightningModel(pl.LightningModule):
    """Lightning module that fine-tunes a binary classifier on raw text batches.

    Batches arrive as ``(comments, labels)`` where ``comments`` is a sequence
    of strings; tokenization happens inside the step, so the data loaders only
    have to deliver text.

    Args:
        model: Network mapping a dict of tokenizer tensors to logits, expected
            to be shaped ``(batch, 1)``.
        tokenizer: Callable turning a batch of strings into a dict of PyTorch
            tensors (it is called with ``return_tensors="pt"``).
        loss_fn: Loss called as ``loss_fn(logits, labels.float())``.
        lr: Learning rate handed to the Adam optimizer.
        thresh: Probability above which a sample is predicted positive when
            computing accuracy.
    """

    def __init__(self, model: nn.Module, tokenizer, loss_fn: Callable, lr: float, thresh: float=0.5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.loss_fn = loss_fn
        self.lr = lr
        self.thresh = thresh

    def common_step(self, batch):
        """Tokenize a batch, run the model and return ``(loss, accuracy)``.

        Shared by the training and validation steps.
        """
        x, y = batch
        tokenized_x = self.tokenizer(
            x,
            max_length=MAX_DOC_LENGTH,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        # The tokenizer builds its tensors on the CPU; move them next to the model.
        tokenized_x_device = {k: v.to(self.device) for k, v in tokenized_x.items()}
        # Drop only the trailing logit dimension. A bare squeeze() would also
        # remove the batch dimension when a batch holds a single sample, and
        # the resulting 0-d tensor no longer matches the shape of `y`.
        logits = self.model(tokenized_x_device).squeeze(-1)
        loss = self.loss_fn(logits, y.float())
        probabilities = torch.sigmoid(logits)
        y_pred = (probabilities > self.thresh).long()
        accuracy = (y_pred == y).float().mean()

        return loss, accuracy

    def training_step(self, batch: Tuple[List[str], torch.LongTensor], *args: Any):
        """Log the training loss/accuracy per step and per epoch; return the loss."""
        loss, accuracy = self.common_step(batch)
        self.log("training_loss", loss, on_step=True, on_epoch=True)
        self.log("training_accuracy", accuracy, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch: Tuple[List[str], torch.LongTensor], *args: Any):
        """Log the validation loss/accuracy, aggregated over the epoch."""
        loss, accuracy = self.common_step(batch)
        self.log("validation_loss", loss, on_step=False, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Optimize every parameter of the wrapped model with Adam."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [None]:
# Directory handed to the W&B logger as its save_dir; create it up front
# (parents included, no error if it already exists).
LOG_DIR = "/kaggle/working/logs/"
pathlib.Path(LOG_DIR).mkdir(parents=True, exist_ok=True)

model = Model()
tokenizer = AutoTokenizer.from_pretrained(MODEL)
loss_fn = nn.BCEWithLogitsLoss()  # consumes raw logits, so the model applies no sigmoid
lightning_model = LightningModel(model, tokenizer, loss_fn, LR)

logger = WandbLogger(name="toxic comments - pt", save_dir=LOG_DIR, project="Toxic Comments")
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=torch.cuda.device_count(),
    gradient_clip_val=1.0,
    logger=logger,
    precision=16,
)
trainer.fit(lightning_model, train_dl, valid_dl)

## Background reading on the model

- Transformers: https://jalammar.github.io/illustrated-transformer/
- BERT: https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/

## Submission

In [None]:
# Score the test set with the trained model and write the submission file.
model = model.eval().to(device)  # eval mode for inference (e.g. dropout off)
y_preds = []
with torch.no_grad():
    for x in tqdm(test_dl):
        tokenized_x = tokenizer(
            x,
            max_length=MAX_DOC_LENGTH,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        tokenized_x_device = {k: v.to(device) for k, v in tokenized_x.items()}
        # Drop only the trailing logit dimension. A bare squeeze() would turn a
        # final batch of size 1 into a 0-d tensor, which torch.cat rejects.
        logits = model(tokenized_x_device).squeeze(-1)
        probabilities = torch.sigmoid(logits)
        y_preds.append(probabilities.cpu())

sub = pd.read_csv(ROOT_PATH / "sample_submission.csv")
sub["toxic"] = torch.cat(y_preds).numpy()
sub.to_csv("submission.csv", index=False)