In [None]:
import os
import gc
import copy
import time
import random
import string
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from glob import glob
from tqdm.notebook import tqdm

from collections import defaultdict
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold, KFold

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig, AdamW

def set_seed(seed=42):
    """Seed every random-number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and forces deterministic cuDNN kernel selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    # `random` is imported at the top of the notebook but was previously never seeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs explicitly (queued lazily, harmless without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only affects interpreters launched after this point; the running
    # process's str-hash seed is fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
# Load the three context-toxicity corpora, in CCC / gc / gn order.
df_CCC, df_gc, df_gn = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../input/context-toxicitymaster/CCC.csv",
        "../input/context-toxicitymaster/gc.csv",
        "../input/context-toxicitymaster/gn.csv",
    )
)

In [None]:
# Preview the first rows of the CCC corpus.
df_CCC.head(5)

In [None]:
# Preview the first rows of the gc corpus.
df_gc.head(5)

In [None]:
# Preview the first rows of the gn corpus.
df_gn.head(5)

In [None]:
# Pool the "text" column of all three corpora; the first print is the raw count.
text_list = (
    list(df_CCC["text"].values)
    + list(df_gc["text"].values)
    + list(df_gn["text"].values)
)
print(len(text_list))
# Drop missing texts before de-duplicating. np.unique would coerce the mixed
# str/NaN list into a fixed-width string array (memory ~ n * longest_text), and
# NaN would survive as the literal comment "nan". sorted() reproduces
# np.unique's code-point ordering of the surviving strings.
text_list = sorted(pd.Series(text_list, dtype=object).dropna().astype(str).unique())
print(len(text_list))
target_dataset = pd.DataFrame({"comment_text": text_list})

In [None]:
# Preview the de-duplicated comment pool.
target_dataset.head(5)

In [None]:
class CCDataset(Dataset):
    """Map-style dataset that tokenizes ``df['comment_text']`` lazily.

    Each item is a dict with ``'ids'`` and ``'mask'`` LongTensors. The tokenizer
    is asked to pad to ``max_length`` and to truncate longer texts.

    Args:
        df: DataFrame with a ``comment_text`` column.
        tokenizer: Object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_length: Target sequence length passed to the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_length
        # Keep the raw strings as an array for fast positional lookup.
        self.text = df['comment_text'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        encoded = self.tokenizer.encode_plus(
            self.text[index],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        # (output key, tokenizer key) pairs; 'ids' is built before 'mask'.
        key_pairs = (('ids', 'input_ids'), ('mask', 'attention_mask'))
        return {
            out_key: torch.tensor(encoded[src_key], dtype=torch.long)
            for out_key, src_key in key_pairs
        }


@torch.no_grad()
def valid_fn(model, dataloader, device):
    """Run ``model`` over every batch and return its flattened predictions.

    Args:
        model: Module called as ``model(ids, mask)``.
        dataloader: Iterable of dicts with ``'ids'`` and ``'mask'`` tensors;
            it must also support ``len()`` (used for the progress bar).
        device: Device the batch tensors are moved to.

    Returns:
        1-D ``np.ndarray`` with each batch's output flattened via ``view(-1)``
        and concatenated in loader order. It is empty if the loader yields no
        batches.
    """
    model.eval()

    batch_preds = []
    for data in tqdm(dataloader, total=len(dataloader)):
        ids = data['ids'].to(device, dtype=torch.long)
        mask = data['mask'].to(device, dtype=torch.long)

        outputs = model(ids, mask)
        # No .detach() needed: the @torch.no_grad() decorator disables autograd.
        batch_preds.append(outputs.view(-1).cpu().numpy())

    gc.collect()

    # np.concatenate([]) raises ValueError, so an empty loader is handled explicitly.
    if not batch_preds:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_preds)


def inference(model_paths, dataloader, device):
    """Average the predictions of several checkpoints over one dataloader.

    For each path a fresh ``JigsawModel`` is built and its ``state_dict`` is
    loaded. ``valid_fn`` collects its predictions, and the per-model arrays are
    averaged element-wise.

    Args:
        model_paths: Iterable of paths to saved ``state_dict`` files.
        dataloader: Loader passed to ``valid_fn``. It must yield batches in the
            same order on every pass (e.g. ``shuffle=False``) so the per-model
            predictions line up.
        device: Device each model is moved to and checkpoints are mapped onto.

    Returns:
        ``np.ndarray`` with the mean prediction per example.

    Raises:
        ValueError: If ``model_paths`` is empty.
    """
    # Materialise so generators work and emptiness can be checked up front.
    model_paths = list(model_paths)
    if not model_paths:
        raise ValueError("inference() requires at least one model path")

    final_preds = []
    for i, path in enumerate(model_paths):
        model = JigsawModel(CONFIG['model_name'])
        model.to(device)
        # map_location lets GPU-saved checkpoints load on a CPU-only host as well.
        model.load_state_dict(torch.load(path, map_location=device))

        print(f"Getting predictions for model {i+1}")
        final_preds.append(valid_fn(model, dataloader, device))

        # Free this fold's weights before the next model is built so two
        # models are never resident at the same time.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.mean(np.stack(final_preds), axis=0)

In [None]:
###############
# CONFIG
###############

CONFIG = dict(
    seed = 42,
    model_name = '../input/jigsaw-multilingual-toxic-xlm-roberta/model',
    test_batch_size = 128,
    max_length = 128,
    num_classes = 1,
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
)

CONFIG["tokenizer"] = AutoTokenizer.from_pretrained(CONFIG['model_name'])
set_seed(CONFIG['seed'])


# Directory holding the per-fold checkpoints. It already includes '../input/',
# so do not prefix it again (the old paths expanded to '../input/../input/...').
MAIN_PATH = '../input/jigsaw-exp019-toxic-xlm-roberta'
N_FOLDS = 5

MODEL_PATHS = [f'{MAIN_PATH}/Loss-Fold-{fold}.bin' for fold in range(N_FOLDS)]


###############
# MODEL
###############

class JigsawModel(nn.Module):
    """Pretrained transformer backbone with a linear head on the first token.

    Args:
        model_name: Path or hub id passed to ``AutoConfig``/``AutoModel``.
        num_classes: Output width of the linear head. Defaults to
            ``CONFIG['num_classes']``, which preserves the original behaviour.
    """

    def __init__(self, model_name, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = CONFIG['num_classes']

        config = AutoConfig.from_pretrained(model_name)
        # Zero the backbone's dropout probabilities and request all hidden states.
        config.update({
            "output_hidden_states": True,
            "hidden_dropout_prob": 0.0,
            "attention_probs_dropout_prob": 0.0,
        })
        self.model = AutoModel.from_pretrained(model_name, config=config)
        # Size the head from the backbone instead of hard-coding 768. This is
        # identical for base-size models and also works for wider backbones.
        self.linear = nn.Linear(config.hidden_size, num_classes)

    def forward(self, ids, mask):
        """Return head logits computed from the hidden state at position 0."""
        out = self.model(
            input_ids=ids,
            attention_mask=mask,
        )
        # last_hidden_state is indexed [batch, seq, hidden]; take sequence position 0.
        return self.linear(out.last_hidden_state[:, 0, :])

###############
# INFERENCE
###############

validation_df = pd.read_csv('../input/jigsaw-toxic-severity-rating/validation_data.csv')
# Every comment appearing on either side of a validation pair.
valid_comments = pd.concat(
    [validation_df['less_toxic'], validation_df['more_toxic']]
).unique()
# Keep only pooled comments that never occur in the validation pairs.
comments = (
    target_dataset
    .loc[~target_dataset['comment_text'].isin(valid_comments)]
    .reset_index(drop=True)
)
display(comments)

test_dataset = CCDataset(
    comments,
    CONFIG['tokenizer'],
    max_length=CONFIG['max_length'],
)

# shuffle=False keeps prediction order aligned with the rows of `comments`.
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=CONFIG['test_batch_size'],
    pin_memory=True,
    num_workers=2,
)

preds = inference(MODEL_PATHS, test_loader, CONFIG['device'])
comments = comments.assign(pseudo_label=preds)
comments.to_csv('PseudoLabelDataset.csv', index=False)

# Release the (potentially large) inference artefacts.
del comments, test_dataset, test_loader
gc.collect()