### Imports

In [None]:
from __future__ import annotations

import functools
import itertools
import random
import statistics
import typing as t

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as torch_f
import typing_extensions as t_ext
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from scipy.stats import rankdata

### Datasets

In [None]:
class _TokenizedText(t_ext.TypedDict):
    """Tokenizer output reduced to the two tensors passed to the models in this file."""
    # Token ids produced by the tokenizer.
    input_ids: torch.Tensor
    # Attention mask that accompanies ``input_ids``.
    attention_mask: torch.Tensor


def _preprocess_tokenizer_output(output: t.Dict[str, t.Any]) -> _TokenizedText:
    """Convert the id and mask entries of a tokenizer result into tensors.

    Args:
        output: Mapping that contains at least ``'input_ids'`` and
            ``'attention_mask'``; any other keys are ignored.

    Returns:
        A dict with exactly those two keys, each value wrapped in
        ``torch.tensor``.

    Raises:
        KeyError: If either required key is missing from ``output``.
    """
    wanted_keys = ('input_ids', 'attention_mask')
    return {key: torch.tensor(output[key]) for key in wanted_keys}


def _split_str_to_chunk_list(s: str, chunk_size: int) -> t.List[str]:
    """Split ``s`` into consecutive chunks of at most ``chunk_size`` words.

    Words are the pieces obtained by splitting on a single space character,
    so consecutive spaces yield empty "words" that still count towards the
    chunk size, and other whitespace (tabs, newlines) does not split.

    Args:
        s: Text to split.
        chunk_size: Maximum number of space-separated pieces per chunk.
            Values below 1 behave like 1 (one piece per chunk).

    Returns:
        The chunks in order, each re-joined with single spaces. An empty
        input yields ``['']``.
    """
    words = s.split(' ')
    # A chunk is emitted as soon as it holds at least ``chunk_size`` words,
    # so any size below 1 degenerates to one word per chunk.
    step = max(chunk_size, 1)
    return [' '.join(words[start:start + step]) for start in range(0, len(words), step)]


def predict_collate_fn(
        sample_list: t.List[t.Tuple[str, _TokenizedText]]
        ) -> t.Tuple[t.List[str], _TokenizedText, t.List[slice]]:
    """Merge per-comment chunk tensors into one flat batch.

    Each sample contributes a variable number of rows (its chunks). The rows
    of all samples are concatenated along dim 0, and a slice per sample
    records which rows of the concatenated tensors belong to it.

    Args:
        sample_list: ``(comment_id, tokenized)`` pairs, where ``tokenized``
            holds ``'input_ids'`` and ``'attention_mask'`` tensors.

    Returns:
        A triple ``(ids, tokenized, slices)``: the comment ids in input
        order, the concatenated tensors, and one ``slice`` per sample into
        dim 0 of those tensors.
    """
    idx_list: t.List[str] = [sample[0] for sample in sample_list]
    id_tensors = [sample[1]['input_ids'] for sample in sample_list]
    mask_tensors = [sample[1]['attention_mask'] for sample in sample_list]

    # Running totals of rows give each sample's exclusive end offset; the
    # start offset is the previous sample's end (0 for the first sample).
    row_ends = list(itertools.accumulate(ids.shape[0] for ids in id_tensors))
    row_starts = [0] + row_ends[:-1]
    slice_list: t.List[slice] = [
        slice(begin, end) for begin, end in zip(row_starts, row_ends)
    ]

    tokenized_collated: _TokenizedText = {
        'input_ids': torch.cat(id_tensors, dim=0),
        'attention_mask': torch.cat(mask_tensors, dim=0),
    }
    return idx_list, tokenized_collated, slice_list


class PredictDataset(Dataset):
    """Dataset that yields each comment as a stack of tokenized chunks.

    The text of a row is split into chunks of at most ``max_len``
    space-separated words, and every chunk is tokenized with padding and
    truncation to ``max_len`` tokens. A chunk that tokenizes to more than
    ``max_len`` tokens therefore loses its tail to the tokenizer's truncation.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int) -> None:
        """Store the inputs.

        Args:
            df: Frame with ``comment_id`` and ``text`` columns.
            tokenizer: Callable tokenizer applied to every chunk.
            max_len: Used both as words-per-chunk and as the tokenizer's
                ``max_length``.
        """
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._max_len = max_len

    def __len__(self) -> int:
        """Return the number of rows in the underlying frame."""
        return len(self._df)

    def __getitem__(self, idx: int) -> t.Tuple[str, _TokenizedText]:
        """Return ``(comment_id, tokenized_chunks)`` for the row at position ``idx``.

        Both tensors of the returned dict are stacked along a new leading
        dimension, one entry per chunk.
        """
        row = self._df.iloc[idx]
        comment_id = str(row['comment_id'])
        text = str(row['text'])

        encoded_chunks = [
            _preprocess_tokenizer_output(self._tokenizer(
                chunk,
                add_special_tokens=True,
                truncation=True,
                padding='max_length',
                max_length=self._max_len,
                return_attention_mask=True))  # type: ignore
            for chunk in _split_str_to_chunk_list(text, chunk_size=self._max_len)
        ]

        tokenized_text: _TokenizedText = {
            'input_ids': torch.stack(
                [encoded['input_ids'] for encoded in encoded_chunks], dim=0),
            'attention_mask': torch.stack(
                [encoded['attention_mask'] for encoded in encoded_chunks], dim=0),
        }
        return comment_id, tokenized_text

### Models

#### Base

In [None]:
class Model(torch.nn.Module):
    """Common base for the scoring models in this file.

    Subclasses override ``predict_scores`` to turn a tokenized batch into
    one score per input row.
    """

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return scores for the given batch; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class ModelConfig(t.NamedTuple):
    """Bundle of a model, the tokenizer that goes with it, and an identifying name."""
    # Human-readable identifier of the model.
    name: str
    # Scoring model exposing ``predict_scores``.
    model: Model
    # Tokenizer used to prepare text for ``model``.
    tokenizer: AutoTokenizer


def import_checkpoint(model: torch.nn.Module, checkpoint: str, device: str):
    """Load a saved state dict from disk into ``model`` in place.

    Args:
        model: Module whose parameters are overwritten.
        checkpoint: Path of a file written with ``torch.save`` that contains
            a state dict compatible with ``model``.
        device: Passed to ``torch.load`` as ``map_location``.

    NOTE(review): ``torch.load`` unpickles the file, so only load
    checkpoints from trusted sources.
    """
    model.load_state_dict(torch.load(checkpoint, map_location=device))

#### CCC-2017

In [None]:
class _AttentionRegressor(torch.nn.Module):
    """Collapse a feature vector to a scalar using input-dependent weights.

    A bias-free square linear layer produces one logit per feature; the
    softmax of those logits over dim 1 weights the features, and the weighted
    features are summed over dim 1.
    """

    def __init__(self, in_features: int) -> None:
        """Create the bias-free ``in_features x in_features`` attention layer."""
        super().__init__()
        self.attention = torch.nn.Linear(in_features, in_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention-weighted sum of ``x`` over dim 1."""
        attention_logits = self.attention(x)
        attention_weights = torch_f.softmax(attention_logits, dim=1)
        weighted_features = x * attention_weights
        return weighted_features.sum(dim=1)


class _CCC2017M1Model(Model):
    """Encoder with a "blind" regression head, a multi-label head and an attention combiner.

    The final score is produced by ``_AttentionRegressor`` from the sigmoid
    of the classifier logits concatenated with the blind regressor output.
    """

    def __init__(self, checkpoint: str, output_logits: int, num_classes: int):
        """Build the encoder and heads.

        Args:
            checkpoint: Name or path passed to ``AutoModel.from_pretrained``.
            output_logits: Size of the encoder's pooled output, used as the
                input width of both heads.
            num_classes: Number of labels predicted by the classifier head.
        """
        # Zero-argument super() follows the full MRO starting at ``Model``,
        # matching the other models in this file; ``super(Model, self)``
        # would skip any initialisation ``Model`` itself defines.
        super().__init__()
        self.encoder = AutoModel.from_pretrained(checkpoint, return_dict=False)
        self.blind_regressor = torch.nn.Sequential(
            torch.nn.Linear(output_logits, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid())
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(output_logits, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_classes))
        # One input per class probability plus one for the blind regression.
        self.regressor = _AttentionRegressor(in_features=num_classes + 1)

    def forward_scores(self, blind_reg_output: torch.Tensor, label_preds: torch.Tensor) -> torch.Tensor:
        """Combine the two heads' outputs into the final score.

        Args:
            blind_reg_output: Output of the blind regressor (already passed
                through a sigmoid).
            label_preds: Raw classifier logits; a sigmoid is applied here.

        Returns:
            The attention-weighted sum over the concatenated features.
        """
        features = torch.cat([torch.sigmoid(label_preds), blind_reg_output], dim=1)
        return self.regressor(features)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder and all heads.

        Returns:
            ``(blind_reg_output, label_preds, scores)`` where ``label_preds``
            are the raw classifier logits.
        """
        # With return_dict=False the encoder returns a tuple; the second
        # element is used as the pooled representation.
        _, pooled_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask)
        blind_reg_output = self.blind_regressor(pooled_output)
        label_preds = self.classifier(pooled_output)
        scores = self.forward_scores(blind_reg_output, label_preds)
        return blind_reg_output, label_preds, scores

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return only the final scores (third element of ``forward``)."""
        return self.forward(input_ids, attention_mask)[2]


def load_ccc2017_m1(device: str) -> ModelConfig:
    """Build the CCC-2017 M1 model, load its checkpoint and pair it with its tokenizer.

    Args:
        device: Passed to ``import_checkpoint`` as the device to map the
            checkpoint tensors to.

    Returns:
        A ``ModelConfig`` holding the loaded model and a RoBERTa-base tokenizer.
    """
    backbone_path = '../input/roberta-base'
    weights_path = '../input/jt-models-to-ensemble/ccc-2017-multilabel-v3-cls-att-blind-reg.pt'

    # 768 is the head input width and 6 the number of classes.
    model = _CCC2017M1Model(backbone_path, 768, 6)
    import_checkpoint(model, weights_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained(backbone_path)
    return ModelConfig(
        name='ccc-2017-multilabel-v3-cls-att-blind-reg',
        model=model,
        tokenizer=tokenizer)

In [None]:
class _WeightedAverageLinearRegressor(torch.nn.Linear):
    """Bias-free single-output linear layer whose weights are softmax-normalised.

    Because the softmax makes the weights positive and sum to one, the output
    is a learned weighted average of the input features.
    """

    def __init__(self, in_features: int, device: t.Optional[str] = None, dtype: t.Optional[str] = None):
        """Create the underlying ``Linear`` with one output and no bias."""
        super().__init__(in_features, 1, bias=False, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer using the softmax (over dim 1) of the stored weight."""
        normalised_weight = torch_f.softmax(self.weight, dim=1)
        # ``self.bias`` is None because the layer is built with bias=False.
        return torch_f.linear(x, normalised_weight, self.bias)


class _CCC2017M3Model(Model):
    """Encoder with a multi-label head whose probabilities are averaged into a score.

    Training run: ccc-2017-multilabel-harder-cls-loss_0p5-v2-valfreq_dynamic_v1
    """

    def __init__(self, checkpoint: str, output_logits: int, num_classes: int):
        """Build the encoder, the classifier head and the averaging regressor.

        Args:
            checkpoint: Name or path passed to ``AutoModel.from_pretrained``.
            output_logits: Size of the encoder's pooled output, used as the
                classifier's input width.
            num_classes: Number of labels predicted by the classifier.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(checkpoint, return_dict=False)
        classifier_layers = [
            torch.nn.Linear(output_logits, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_classes),
        ]
        self.classifier = torch.nn.Sequential(*classifier_layers)
        self.regressor = _WeightedAverageLinearRegressor(in_features=num_classes)

    def forward_scores(self, label_preds: torch.Tensor) -> torch.Tensor:
        """Map per-label values to a single score with the regressor."""
        return self.regressor(label_preds)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and heads.

        Returns:
            ``(label_preds, scores)``: the raw classifier logits and the
            score computed from their sigmoid.
        """
        # With return_dict=False the encoder returns a tuple; the second
        # element is used as the pooled representation.
        _, pooled = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        label_preds = self.classifier(pooled)
        label_probabilities = torch.sigmoid(label_preds)
        return label_preds, self.forward_scores(label_probabilities)

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return only the scores (second element of ``forward``)."""
        _label_preds, scores = self.forward(input_ids, attention_mask)
        return scores


def load_ccc2017_m3(device: str) -> ModelConfig:
    """Build the CCC-2017 M3 model, load its checkpoint and pair it with its tokenizer.

    Args:
        device: Passed to ``import_checkpoint`` as the device to map the
            checkpoint tensors to.

    Returns:
        A ``ModelConfig`` holding the loaded model and a RoBERTa-base tokenizer.
    """
    backbone_path = '../input/roberta-base'
    weights_path = '../input/jt-models-to-ensemble/ccc-2017-multilabel-harder-cls-loss_0p5-v2-valfreq_dynamic_v1.pt'

    # 768 is the head input width and 6 the number of classes.
    model = _CCC2017M3Model(backbone_path, 768, 6)
    import_checkpoint(model, weights_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained(backbone_path)
    return ModelConfig(
        name='ccc-2017-multilabel-harder-cls-loss_0p5-v2-valfreq_dynamic_v1',
        model=model,
        tokenizer=tokenizer)

#### Ruddit

In [None]:
class _RudditM1Model(Model):
    """Encoder followed by a single linear layer and a sigmoid."""

    def __init__(self, checkpoint: str, output_logits: int, dropout: float):
        """Build the encoder and the regression head.

        Args:
            checkpoint: Name or path passed to ``AutoModel.from_pretrained``.
            output_logits: Size of the encoder's pooled output, used as the
                head's input width.
            dropout: Accepted but not used anywhere in this class.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(checkpoint, return_dict=False)
        head_layers = [
            torch.nn.Linear(output_logits, 1),
            torch.nn.Sigmoid(),
        ]
        self.regressor = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the encoder's pooled representation."""
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # With return_dict=False the encoder returns a tuple; the second
        # element is used as the pooled representation.
        _, pooled = encoder_outputs
        return self.regressor(pooled)

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the scores; identical to ``forward``."""
        return self.forward(input_ids, attention_mask)


def load_ruddit_m1(device: str) -> ModelConfig:
    """Build the Ruddit M1 model, load its checkpoint and pair it with its tokenizer.

    Args:
        device: Passed to ``import_checkpoint`` as the device to map the
            checkpoint tensors to.

    Returns:
        A ``ModelConfig`` holding the loaded model and a RoBERTa-base tokenizer.
    """
    backbone_path = '../input/roberta-base'
    weights_path = '../input/jt-models-to-ensemble/ruddit-v3-mse-2ep-pure_reg.pt'

    # 768 is the head input width; the 0.6 dropout argument is ignored by the model.
    model = _RudditM1Model(backbone_path, 768, 0.6)
    import_checkpoint(model, weights_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained(backbone_path)
    return ModelConfig(
        name='ruddit-v3-mse-2ep-pure_reg',
        model=model,
        tokenizer=tokenizer)

In [None]:
class _RudditM2Model(Model):
    """Encoder followed by a two-layer (Linear-Tanh-Linear) head and a sigmoid."""

    def __init__(self, checkpoint: str, output_logits: int, dropout: float):
        """Build the encoder and the regression head.

        Args:
            checkpoint: Name or path passed to ``AutoModel.from_pretrained``.
            output_logits: Size of the encoder's pooled output, used as the
                head's input width.
            dropout: Accepted but not used anywhere in this class.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(checkpoint, return_dict=False)
        head_layers = [
            torch.nn.Linear(output_logits, 256),
            torch.nn.Tanh(),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid(),
        ]
        self.regressor = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the encoder's pooled representation."""
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # With return_dict=False the encoder returns a tuple; the second
        # element is used as the pooled representation.
        _, pooled = encoder_outputs
        return self.regressor(pooled)

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the scores; identical to ``forward``."""
        return self.forward(input_ids, attention_mask)


def load_ruddit_m2(device: str) -> ModelConfig:
    """Build the Ruddit M2 model, load its checkpoint and pair it with its tokenizer.

    Args:
        device: Passed to ``import_checkpoint`` as the device to map the
            checkpoint tensors to.

    Returns:
        A ``ModelConfig`` holding the loaded model and the tokenizer from the
        same backbone directory.
    """
    backbone_path = '../input/unbiasedtoxicroberta'
    weights_path = '../input/jt-models-to-ensemble/ruddit-v3-mse-2ep-pure_reg-unbiased_toxic_roberta-2layer_reg.pt'

    # 768 is the head input width; the 0.6 dropout argument is ignored by the model.
    model = _RudditM2Model(backbone_path, 768, 0.6)
    import_checkpoint(model, weights_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained(backbone_path)
    return ModelConfig(
        name='ruddit-v3-mse-2ep-pure_reg-unbiased_toxic_roberta-2layer_reg',
        model=model,
        tokenizer=tokenizer)

#### Wiki Talk Labels

In [None]:
class _WikiTalkLabelsM1Model(Model):
    """Encoder followed by a two-layer (Linear-Tanh-Linear) head and a sigmoid."""

    def __init__(self, checkpoint: str, output_logits: int, dropout: float):
        """Build the encoder and the regression head.

        Args:
            checkpoint: Name or path passed to ``AutoModel.from_pretrained``.
            output_logits: Size of the encoder's pooled output, used as the
                head's input width.
            dropout: Accepted but not used anywhere in this class.
        """
        # Zero-argument super() follows the full MRO starting at ``Model``,
        # matching the other models in this file; ``super(Model, self)``
        # would skip any initialisation ``Model`` itself defines.
        super().__init__()
        self.bert = AutoModel.from_pretrained(checkpoint, return_dict=False)
        self.regressor = torch.nn.Sequential(
            torch.nn.Linear(output_logits, 256),
            torch.nn.Tanh(),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the encoder's pooled representation."""
        # With return_dict=False the encoder returns a tuple; the second
        # element is used as the pooled representation.
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask)
        return self.regressor(pooled_output)

    def predict_scores(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the scores; identical to ``forward``."""
        return self.forward(input_ids, attention_mask)


def load_wiki_talk_labels_m1(device: str) -> ModelConfig:
    """Build the Wiki Talk Labels M1 model, load its checkpoint and pair it with its tokenizer.

    Args:
        device: Passed to ``import_checkpoint`` as the device to map the
            checkpoint tensors to.

    Returns:
        A ``ModelConfig`` holding the loaded model and a RoBERTa-base tokenizer.
    """
    backbone_path = '../input/roberta-base'
    weights_path = '../input/jt-models-to-ensemble/wiki-talk-labels-v1-1ep.pt'

    # 768 is the head input width; the 0.6 dropout argument is ignored by the model.
    model = _WikiTalkLabelsM1Model(backbone_path, 768, 0.6)
    import_checkpoint(model, weights_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained(backbone_path)
    return ModelConfig(
        name='wiki-talk-labels-v1-1ep',
        model=model,
        tokenizer=tokenizer)

### Inference

In [None]:
def do_prediction_iteration(
        in_df: pd.DataFrame,
        batch_size: int,
        model_getter: t.Callable[[str], ModelConfig],
        max_len: int,
        num_workers: int,
        device: str,) -> np.ndarray:
    """Score every row of ``in_df`` with one model.

    Each comment is split into chunks by ``PredictDataset``; the model scores
    every chunk and the comment's score is the maximum over its chunks.

    Args:
        in_df: Frame with the comments to score.
        batch_size: Number of comments (not chunks) per batch.
        model_getter: Called with ``device``; returns the model and tokenizer.
        max_len: Chunk size and tokenizer length passed to ``PredictDataset``.
        num_workers: Worker count for the ``DataLoader``.
        device: Device the model and batches are moved to.

    Returns:
        Array with one score per row of ``in_df``, in row order.
    """
    config = model_getter(device)
    net = config.model.to(device)

    loader = DataLoader(
        dataset=PredictDataset(
            df=in_df,
            tokenizer=config.tokenizer,
            max_len=max_len),
        batch_size=batch_size,
        shuffle=False,  # keep scores aligned with the rows of in_df
        num_workers=num_workers,
        collate_fn=predict_collate_fn,
        pin_memory=device.startswith('cuda'))

    net.eval()
    collected_scores = []
    with torch.no_grad():
        for _comment_ids, batch, chunk_slices in tqdm(loader, desc='Prediction'):
            chunk_scores = net.predict_scores(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),)
            # Reduce each comment's chunk scores to their maximum.
            per_comment = [
                torch.max(chunk_scores[rows], dim=0, keepdim=True)[0]
                for rows in chunk_slices
            ]
            batch_scores = torch.cat(per_comment, dim=0)
            collected_scores.extend(batch_scores.flatten().tolist())
    return np.array(collected_scores)

In [None]:
# Load the comments to be scored.
eval_df = pd.read_csv('../input/jigsaw-toxic-severity-rating/comments_to_score.csv')

In [None]:
# Score all comments with the CCC-2017 M1 model.
ccc2017_m1_score_arr = do_prediction_iteration(
    in_df=eval_df,
    batch_size=16,
    model_getter=load_ccc2017_m1,
    max_len=256,
    num_workers=2,
    device='cuda',
)

In [None]:
# Score all comments with the CCC-2017 M3 model.
ccc2017_m3_score_arr = do_prediction_iteration(
    in_df=eval_df,
    batch_size=16,
    model_getter=load_ccc2017_m3,
    max_len=256,
    num_workers=2,
    device='cuda',
)

In [None]:
# Score all comments with the Ruddit M1 model.
ruddit_m1_score_arr = do_prediction_iteration(
    in_df=eval_df,
    batch_size=16,
    model_getter=load_ruddit_m1,
    max_len=256,
    num_workers=2,
    device='cuda',
)

In [None]:
# Score all comments with the Ruddit M2 model.
ruddit_m2_score_arr = do_prediction_iteration(
    in_df=eval_df,
    batch_size=16,
    model_getter=load_ruddit_m2,
    max_len=256,
    num_workers=2,
    device='cuda',
)

In [None]:
# Score all comments with the Wiki Talk Labels M1 model.
wiki_talk_labels_m1_score_arr = do_prediction_iteration(
    in_df=eval_df,
    batch_size=16,
    model_getter=load_wiki_talk_labels_m1,
    max_len=256,
    num_workers=2,
    device='cuda',
)

In [None]:
def ensemble_score_arr_list(score_arr_list: t.List[np.ndarray]) -> np.ndarray:
    """Average the per-model ranks of each item.

    Every score array is converted to ordinal ranks (1-based, ties broken by
    order of occurrence), so that models with different score scales
    contribute equally; the ranks are then averaged across models.

    Args:
        score_arr_list: One score array per model, all of the same length.

    Returns:
        Array of the mean rank of each item across the models.

    Raises:
        ValueError: If ``score_arr_list`` is empty (raised by ``np.stack``).
    """
    rank_rows = []
    for model_scores in score_arr_list:
        rank_rows.append(rankdata(model_scores, method='ordinal'))
    rank_matrix = np.stack(rank_rows, axis=0)
    return rank_matrix.mean(axis=0)

In [None]:
# Combine the five models' scores into one ensembled score per comment.
score_arr = ensemble_score_arr_list([
    ccc2017_m1_score_arr,
    ccc2017_m3_score_arr,
    ruddit_m1_score_arr,
    ruddit_m2_score_arr,
    wiki_talk_labels_m1_score_arr,
])

In [None]:
# Pair each comment id with its ensembled score and write the submission CSV
# (no index column).
pd.DataFrame([
    {'comment_id': comment_id, 'score': score}
    for comment_id, score in zip(eval_df['comment_id'].tolist(), score_arr.tolist())
]).to_csv('/kaggle/working/submission.csv', index=False)

In [None]:
# Notebook shell escape: show the first four lines of the submission file.
!head -n 4 /kaggle/working/submission.csv