# 1. Config

In [None]:
# List the checkpoint files available in the training output; WEIGHT_PATH in the
# next cell points at one file inside this directory.
!ls /kaggle/input/jigsaw-train-debertav3/model_cp

In [None]:
# Reproducibility and inference batch settings.
SEED = 42
BATCH_SIZE = 16

# Pretrained backbone directory: tokenizer, config and encoder weights are all read from here.
MODEL_PATH = '/kaggle/input/jigsaw-get-zsc-models/DeBERTa-v3-base-mnli-fever-anli'
# Fine-tuned Lightning checkpoint that is restored into JigsawModel for prediction.
WEIGHT_PATH = '/kaggle/input/jigsaw-train-debertav3/model_cp/jigsaw-debertav3-epoch=01-val_loss=0.3983.ckpt'

In [None]:
import torch
import pandas as pd
import pytorch_lightning as pl
import transformers
import itertools

# Log library versions (torch, pandas, transformers, lightning -- one per line)
# so results can be tied to a specific environment, then seed every RNG.
print(*(lib.__version__ for lib in (torch, pd, transformers, pl)), sep='\n')

pl.seed_everything(SEED, workers=True)

# 2. Model

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoConfig

# Tokenizer, config and bare encoder (no task head) all come from the same pretrained directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
config = AutoConfig.from_pretrained(MODEL_PATH)
m = AutoModel.from_pretrained(MODEL_PATH)

In [None]:
from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout, ContextPooler

class JigsawModel(pl.LightningModule):
    """Encoder + small head producing one toxicity logit per input text.

    Args:
        model: a Hugging Face encoder (e.g. the ``AutoModel`` loaded above) whose
            output exposes ``last_hidden_state`` and which carries a ``config``.

    The attribute names ``deberta`` / ``dense`` and the layer order inside
    ``dense`` define the state-dict keys of the saved checkpoint, so they must
    not be renamed or reordered.
    """

    def __init__(self, model):
        super().__init__()
        self.deberta = model
        # Take the width from the encoder instead of hard-coding 768, so the head
        # also fits non-"base" backbones (it is 768 for DeBERTa-v3-base).
        hidden_size = model.config.hidden_size
        # NOTE: there is no activation between the two Linear layers, so the head
        # is effectively linear; kept as-is to match the trained checkpoint.
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            StableDropout(drop_prob=0.1),  # NOTE(review): original author questioned whether this should be 0.0
            torch.nn.Linear(hidden_size, 1),
        )
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, ids, mask, token_type_ids):
        """Return one logit per sequence, shape ``(batch,)``, from the position-0 hidden state."""
        hidden = self.deberta(ids, attention_mask=mask, token_type_ids=token_type_ids).last_hidden_state
        logits = self.dense(hidden[:, 0])
        return torch.reshape(logits, (-1,))

    def _logits(self, batch):
        # Shared by the train / val / predict steps: unpack the encoded batch and run forward.
        return self(batch['ids'], batch['mask'], batch['token_type_ids'])

    def configure_optimizers(self):
        # NOTE(review): LR, WEIGHT_DECAY, EPOCH, STEPS and WARMUP_RATIO are not defined
        # in this notebook (presumably they come from the training notebook); this
        # method is only needed when training, not for trainer.predict.
        # Assumes STEPS is the number of optimizer steps per epoch.
        optimizer = transformers.AdamW(self.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(EPOCH * STEPS * WARMUP_RATIO),
            # Total number of training steps (warmup included), so the linear decay
            # reaches zero exactly at the end of training rather than early.
            num_training_steps=int(EPOCH * STEPS),
        )
        # Warmup/decay lengths are counted in optimizer steps, so the scheduler must
        # be stepped every step; Lightning's default interval is once per epoch.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},
        }

    def training_step(self, batch, batch_idx):
        return self.loss(self._logits(batch), batch['label'])

    def validation_step(self, batch, batch_idx):
        loss = self.loss(self._logits(batch), batch['label'])
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def training_epoch_end(self, outputs):
        # Lightning 1.x hook: ``outputs`` holds the {'loss': tensor} dict of each training step.
        avg_loss = torch.stack([d['loss'] for d in outputs]).mean()
        print(f'Epoch #{self.current_epoch} | loss: {avg_loss}')

    def validation_epoch_end(self, outputs):
        # ``outputs`` holds the loss tensor returned by each validation_step.
        avg_loss = torch.stack(outputs).mean()
        print(f'Epoch #{self.current_epoch} | val_loss: {avg_loss}')

    def predict_step(self, batch, batch_idx):
        # Upcast before the sigmoid: under precision=16 the logits may be float16,
        # and a float16 sigmoid rounds every probability above ~0.9995 to exactly
        # 1.0, creating ties between the most confident comments.
        return torch.sigmoid(self._logits(batch).float())

    
# ``load_from_checkpoint`` is a classmethod that *returns* a new, restored module;
# it does not load weights into an existing instance. Calling it on a fresh
# ``JigsawModel(m)`` and discarding the result left the randomly initialised head
# in use, so bind the returned module instead. JigsawModel does not call
# ``save_hyperparameters``, so its constructor argument is supplied again via
# ``model=m``. ``map_location='cpu'`` makes loading independent of the device the
# checkpoint was saved from; the Trainer moves the module to the GPU later.
model = JigsawModel.load_from_checkpoint(WEIGHT_PATH, model=m, map_location='cpu')

In [None]:
class TextDataset(torch.utils.data.Dataset):
    """Tokenize texts lazily, one item at a time, for a ``DataLoader``.

    Args:
        tokenizer: Hugging Face-style callable returning ``input_ids``,
            ``attention_mask`` and (optionally) ``token_type_ids``.
        text: indexable sequence of strings.
        label: optional indexable sequence of targets aligned with ``text``;
            when given, every item also carries a float ``label`` tensor.
        max_length: every sequence is truncated / padded to this many tokens.
    """

    def __init__(self, tokenizer, text, label=None, max_length=512):
        self.tokenizer = tokenizer
        self.text = text
        self.label = label
        self.max_length = max_length

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        inputs = self.tokenizer(
            self.text[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )
        ids = inputs['input_ids']
        # Tokenizers without segment ids (e.g. RoBERTa-style) omit this key;
        # fall back to all zeros instead of raising KeyError.
        token_type_ids = inputs.get('token_type_ids')
        if token_type_ids is None:
            token_type_ids = [0] * len(ids)

        item = {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        }
        if self.label is not None:
            # Float target, as required by BCEWithLogitsLoss.
            item['label'] = torch.tensor(self.label[idx], dtype=torch.float)
        return item

# 3. Predict

In [None]:
# Comments to score. The loader keeps file order (shuffle=False) so the
# concatenated predictions line up row-for-row with df_test.
df_test = pd.read_csv('/kaggle/input/jigsaw-toxic-severity-rating/comments_to_score.csv')

test_ds = TextDataset(tokenizer, df_test['text'].tolist())
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

In [None]:
trainer = pl.Trainer(
    gpus=1,
    precision=16,
)

# trainer.predict returns one 1-D tensor of probabilities per batch, in loader
# order. Concatenate them into a single flat array aligned with df_test rows.
# Upcast to float32 so mixed precision does not leave a float16 score column,
# and move to CPU before converting to numpy. An empty prediction list yields
# an empty array instead of a torch.cat error.
y_pred = trainer.predict(model, test_dl)
y_pred = torch.cat(y_pred) if y_pred else torch.empty(0)
y_pred = y_pred.float().cpu().numpy()

In [None]:
# The submission is every column of the input file except the raw text, plus the predicted score.
df_submission = df_test.drop(columns=['text'])
df_submission['score'] = y_pred
df_submission.to_csv('submission.csv', index=False)
df_submission