# 1. Config

In [None]:
# List the contents of the Kaggle input directory.
!ls /kaggle/input

In [None]:
# Extract the train / test / test-label CSV archives of the Toxic Comment
# Classification Challenge into the current working directory.
!unzip /kaggle/input/jigsaw-toxic-comment-classification-challenge/train.csv.zip
!unzip /kaggle/input/jigsaw-toxic-comment-classification-challenge/test.csv.zip
!unzip /kaggle/input/jigsaw-toxic-comment-classification-challenge/test_labels.csv.zip

In [None]:
# Locations of the pretrained backbone and of the pairwise comparison data.
MODEL_PATH = '/kaggle/input/jigsaw-get-zsc-models/DeBERTa-v3-base-mnli-fever-anli'
DATA_PATH = '/kaggle/input/jigsaw-filter-pseudolabel/validation_data_std_2.0.csv'
VAL_PCT = 0.1  # fraction of the examples held out for validation
SEED = 42

BATCH_SIZE = 16
LR = 1e-4 # from auto LR finder suggestion
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.06
EPOCH = 2
# Total number of training examples. Keep this in sync with the data that is
# actually fed to the trainer, because the LR schedule length is derived from it.
TRAIN_SIZE = 152205
# Batches per epoch: ceiling division, so an exact multiple of BATCH_SIZE
# does not count one batch too many.
STEPS = (TRAIN_SIZE + BATCH_SIZE - 1) // BATCH_SIZE

In [None]:
import torch
import pandas as pd
import pytorch_lightning as pl
import transformers
import sklearn
import gc

# Record the library versions used for this run, one per line.
for _lib in (torch, pd, transformers, pl, sklearn):
    print(_lib.__version__)

# Seed every RNG Lightning knows about (dataloader workers included).
pl.seed_everything(SEED, workers=True)

# 2. Model

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoConfig

# Load backbone weights, tokenizer and config from the same checkpoint directory.
m, tokenizer, config = (
    loader.from_pretrained(MODEL_PATH)
    for loader in (AutoModel, AutoTokenizer, AutoConfig)
)

In [None]:
from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout, ContextPooler
from torch.optim.lr_scheduler import OneCycleLR

class JigsawModel(pl.LightningModule):
    """Transformer backbone with a two-layer head producing one logit per text.

    The head reads the hidden state at sequence position 0 and is trained with
    ``BCEWithLogitsLoss`` against float labels; ``predict_step`` applies a
    sigmoid so predictions fall in (0, 1).

    Args:
        model: backbone called as ``model(ids, attention_mask=..., token_type_ids=...)``
            whose output exposes ``last_hidden_state``.
    """

    def __init__(self, model):
        super().__init__()
        self.deberta = model
        # Size the head from the backbone config when available instead of
        # hard-coding 768, so a backbone of another width works unchanged.
        hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 768)
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            StableDropout(drop_prob=0.1), # original 0.0 ??
            torch.nn.Linear(hidden_size, 1)
        )
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, ids, mask, token_type_ids):
        """Return a 1-D tensor of logits, one per row of ``ids``."""
        out = self.deberta(ids, attention_mask = mask, token_type_ids = token_type_ids)
        # hidden state of the first token of every sequence
        out = out.last_hidden_state[:, 0]
        out = self.dense(out)
        out = torch.reshape(out, (-1, ))

        return out

    def configure_optimizers(self):
        """Build AdamW plus a OneCycleLR schedule that advances every optimizer step."""
        optimizer = transformers.AdamW(self.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
#         scheduler = transformers.get_linear_schedule_with_warmup(
#             optimizer,
#             int(EPOCH * STEPS * WARMUP_RATIO),
#             int(EPOCH * STEPS * (1 - WARMUP_RATIO))
#         )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=LR, steps_per_epoch=STEPS, epochs=EPOCH,
            div_factor=10, final_div_factor=40
        )

        # The schedule is sized in batches (steps_per_epoch * epochs), so it has
        # to be stepped per batch. Lightning's default interval is "epoch", which
        # would advance it only EPOCH times and leave the LR near its start value.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _compute_loss(self, batch):
        """Run the model on a labelled batch and return the BCE-with-logits loss."""
        out = self(batch['ids'], batch['mask'], batch['token_type_ids'])
        return self.loss(out, batch['label'])

    def training_step(self, batch, batch_idx):
        """Return the loss for one training batch."""
        return self._compute_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log and return the loss for one validation batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

        return loss

    def training_epoch_end(self, outputs):
        """Print the mean of the per-batch training losses for the epoch."""
        # manually print loss on each epoch
        losses = [d['loss'] for d in outputs]
        avg_loss = torch.stack(losses).mean()
        print(f'Epoch #{self.current_epoch} | loss: {avg_loss}')

    def validation_epoch_end(self, outputs):
        """Print and log the mean of the per-batch validation losses."""
        # manually print loss on each epoch
        avg_loss = torch.stack(outputs).mean()
        print(f'Epoch #{self.current_epoch} | val_loss: {avg_loss}')
        self.log("val_loss", avg_loss)

    def predict_step(self, batch, batch_idx):
        """Return sigmoid scores for an unlabelled batch."""
        out = self(batch['ids'], batch['mask'], batch['token_type_ids'])
        out = torch.sigmoid(out)

        return out


model = JigsawModel(m)

# 3. Data

In [None]:
# load
df_raw = pd.read_csv(DATA_PATH)
# remove duplicate: keep the first row per distinct `less_toxic` text,
# then the first row per distinct `more_toxic` text
df_raw = df_raw.drop_duplicates(subset=['less_toxic'])
df_raw = df_raw.drop_duplicates(subset=['more_toxic'])
# stack both sides of every pair as (text, label) rows, skipping rows whose
# score is the placeholder "-"
df = pd.concat(
    [
        (
            df_raw.query(f"{side}_score != '-'")
                  .loc[:, [side, f'{side}_score']]
                  .rename(columns={side: 'text', f'{side}_score': 'label'})
        )
        for side in ('less_toxic', 'more_toxic')
    ],
    axis=0
)
# set label data type
df['label'] = df['label'].astype('float64')
# show
df

In [None]:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
df2_text = pd.read_csv('./test.csv')
df2_label = pd.read_csv('./test_labels.csv')
# join (column-wise, rows matched by position)
df2 = pd.concat([df2_text, df2_label], axis=1)
# select used column
df2 = df2[['comment_text', 'toxic', 'severe_toxic']]
# remove row with value -1; copy so the new `label` column is not written
# into a slice of the unfiltered frame
df2 = df2[
    (df2['toxic'] != -1)
    &
    (df2['severe_toxic'] != -1)
].copy()
# set label with smoothing (vectorized):
#   toxic == 0 and severe_toxic == 0 -> 0.05
#   severe_toxic == 1                -> 0.95
#   anything else                    -> 0.5
label = pd.Series(0.5, index=df2.index)
label = label.mask((df2['toxic'] == 0) & (df2['severe_toxic'] == 0), 0.05)
label = label.mask(df2['severe_toxic'] == 1, 0.95)
df2['label'] = label
# under sample label 0.0
# df2 = pd.concat([
#     df2[
#         (df2['label'] == 0.5)
#         |
#         (df2['label'] == 1.0)
#     ],
#     df2[df2['label'] == 0.0].sample(n=20000, random_state=SEED)
# ])
df2

In [None]:
# https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification
df3_pub = pd.read_csv('/kaggle/input/jigsaw-unintended-bias-in-toxicity-classification/test_public_expanded.csv')
df3_prv = pd.read_csv('/kaggle/input/jigsaw-unintended-bias-in-toxicity-classification/test_private_expanded.csv')
# join
df3 = pd.concat([df3_pub, df3_prv], axis=0)
# select used column
df3 = df3[['comment_text', 'toxicity', 'severe_toxicity']]
# undersample row with value 0.0 and smooth it to 0.05
# NOTE(review): replace() runs over the whole sampled frame, so a 0.0 in
# `severe_toxicity` also becomes 0.05 (0.125 after the rescale below) and can
# then win the comparison against `toxicity` -- confirm this is intended.
df3 = pd.concat([
    df3[df3['toxicity'] > 0],
    (
        df3.query('toxicity == 0.0')
           .sample(n=13500, random_state=SEED)
           .replace(0.0, 0.05)
    )
], axis=0)
# severe_toxicity max score is 0.4, so multiply it by 2.5
df3['severe_toxicity'] = df3['severe_toxicity'] * 2.5
# set label (vectorized): use the rescaled severe score when it is positive
# and larger than `toxicity`, otherwise half of `toxicity`
severe = df3['severe_toxicity']
prefer_severe = (severe > 0) & (severe > df3['toxicity'])
label = (df3['toxicity'] / 2).mask(prefer_severe, severe)
# assign by position: the row index repeats after the concatenations above
df3['label'] = label.to_numpy()
# undersample label <= 0.4
# df3 = pd.concat([
#     df3[df3['label'] <= 0.1].sample(n=9000, random_state=SEED),
#     df3[(df3['label'] > 0.1) & (df3['label'] <= 0.2)].sample(n=3000, random_state=SEED),
#     df3[(df3['label'] > 0.2) & (df3['label'] <= 0.3)].sample(n=3000, random_state=SEED),
#     df3[(df3['label'] > 0.3) & (df3['label'] <= 0.4)].sample(n=3000, random_state=SEED),
#     df3[df3['label'] > 0.4]
# ])
df3

In [None]:
# Flatten the three sources into parallel Python lists, always in the order
# df -> df2 -> df3 so texts and labels stay aligned.
X = []
y = []
for texts, labels in (
    (df['text'], df['label']),
    (df2['comment_text'], df2['label']),
    (df3['comment_text'], df3['label']),
):
    X.extend(texts.tolist())
    y.extend(labels.tolist())

In [None]:
!rm train.csv test.csv test_labels.csv

del m
del df_raw
del df
del df2_text
del df2_label
del df2
del df3_pub
del df3_prv
del df3
del label
gc.collect()

In [None]:
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one text per access.

    Args:
        tokenizer: callable invoked as
            ``tokenizer(text, truncation=True, padding='max_length', max_length=512)``
            and returning ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        text: indexable collection of strings.
        label: optional indexable collection of float targets aligned with
            ``text``; when omitted, items carry no ``'label'`` entry.
    """

    def __init__(self, tokenizer, text, label=None):
        self.tokenizer = tokenizer
        self.text = text
        self.label = label

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        """Return a dict of long tensors (plus a float ``label`` when available)."""
        encoded = self.tokenizer(
            self.text[idx],
            truncation=True,
            padding='max_length',
            max_length=512
        )
        # Rename the tokenizer outputs to the batch keys the model steps read.
        sample = {
            out_key: torch.tensor(encoded[src_key], dtype=torch.long)
            for out_key, src_key in (
                ('ids', 'input_ids'),
                ('mask', 'attention_mask'),
                ('token_type_ids', 'token_type_ids'),
            )
        }
        if self.label is not None:
            sample['label'] = torch.tensor(self.label[idx], dtype=torch.float)
        return sample

class JigsawDM(pl.LightningDataModule):
    """Data module that splits (X, y) into train / validation ``TextDataset``s.

    Args:
        tokenizer: tokenizer handed to every ``TextDataset``.
        X: list of input texts.
        y: list of float labels aligned with ``X``.
        batch_size: batch size used by both dataloaders.
        num_workers: worker processes per dataloader (defaults to the
            previously hard-coded value of 2).
    """

    def __init__(self, tokenizer, X, y, batch_size, num_workers=2):
        super().__init__()
        self.tokenizer = tokenizer
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_ds = None
        self.val_ds = None

    def setup(self, stage=None):
        """Create the train / validation datasets once.

        Later calls are no-ops: the raw lists are released after the first
        split, so re-running the split would fail on the missing data.
        """
        if self.train_ds is not None and self.val_ds is not None:
            return

        X_train, X_val, y_train, y_val = train_test_split(
            self.X, self.y, test_size=VAL_PCT, random_state=SEED
        )

        self.train_ds = TextDataset(self.tokenizer, X_train, y_train)
        self.val_ds = TextDataset(self.tokenizer, X_val, y_val)

        # Release the unsplit lists; the datasets hold the split copies.
        self.X = None
        self.y = None

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.train_ds, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True
        )

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return torch.utils.data.DataLoader(
            self.val_ds, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, pin_memory=True
        )


dm = JigsawDM(tokenizer, X, y, batch_size=BATCH_SIZE)

# 4. Train

In [None]:
from pytorch_lightning.callbacks import StochasticWeightAveraging, ModelCheckpoint

# Weight averaging configured with swa_epoch_start=1 and one annealing epoch.
swa_callback = StochasticWeightAveraging(swa_epoch_start=1, annealing_epochs=1)
# Keep up to three checkpoints ranked by the logged `val_loss`.
checkpoint_callback = ModelCheckpoint(
    dirpath='./model_cp',
    filename='jigsaw-debertav3-{epoch:02d}-{val_loss:.4f}',
    monitor='val_loss',
    save_top_k=3,
)

# Single-GPU, 16-bit precision run for EPOCH epochs.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=EPOCH,
    default_root_dir='./trainer_cp',
    log_every_n_steps=10,
    callbacks=[swa_callback, checkpoint_callback],
)

trainer.fit(model, datamodule=dm)