---
## [Jigsaw Rate Severity of Toxic Comments][1]
---

**Comments**: Thanks to the authors of the following notebooks, whose data preprocessing this notebook builds on.

1. [☣️ Jigsaw - Incredibly Simple Naive Bayes [0.768]][2]
2. [AutoNLP for toxic ratings ;)][3]
3. [Regression Ensemble LB=0.78][4]
4. [Jigsaw Ensemble [0.86]][5]


[1]: https://www.kaggle.com/c/jigsaw-toxic-severity-rating/overview
[2]: https://www.kaggle.com/julian3833/jigsaw-incredibly-simple-naive-bayes-0-768
[3]: https://www.kaggle.com/abhishek/autonlp-for-toxic-ratings
[4]: https://www.kaggle.com/ekaterinadranitsyna/regression-ensemble-lb-0-78/notebook
[5]: https://www.kaggle.com/andrej0marinchenko/jigsaw-ensemble-0-86

# 0. Settings

In [None]:
# Import dependencies 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
%matplotlib inline

import os
import pathlib
import gc
import sys
import re
import math 
import random
import time 
from tqdm import tqdm 

import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import KFold 
from sklearn.model_selection import StratifiedKFold 

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers.experimental import preprocessing

import transformers 
import datasets 

print('import done!')

In [None]:
# global config
config = {
    'model_path': '../input/roberta-base-211212',
    'tokenizer_path': '../input/roberta-base-tokenizer-211212',
    'batch_size': 8,
    'n_folds': 30,
    'nontoxic_n_factor': 0.4,
    'num_words': 3,
    'under_over_ratio_factor': 1.0,
    'clipping_score': 10.0,
    'upsampling_threshold': 4.0,
    'capped': False,
}

AUTOTUNE = tf.data.experimental.AUTOTUNE

# For reproducible results    
def seed_all(s):
    random.seed(s)
    np.random.seed(s)
    tf.random.set_seed(s)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    os.environ['PYTHONHASHSEED'] = str(s) 
    print('Seeds set!')
global_seed = 42
seed_all(global_seed)

# 1. Data Preprocessing

### 1.1 Create train data

For training data, I used the [Toxic Comment Classification Challenge][1] dataset.

[1]: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data

In [None]:
# Load the labeled comments and flag every comment that has at least one toxic category.
df = pd.read_csv('../input/jigsaw-toxic-comment-classification-challenge/train.csv')

df['toxic_label'] = (df[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].sum(axis=1) > 0 ).astype(int)

categories = df.loc[:, 'toxic':'identity_hate'].sum()
plt.title('Category Frequency')
plt.bar(categories.index, categories.values)
plt.show()

In the previous competition the task was multi-label classification: a text sample could be labeled with one or several categories, or with none. Non-toxic comments make up the majority of samples; toxic comments are a minority class, and severely toxic comments are rarer still.

In this competition we have to score texts by their level of toxicity. To derive a toxicity score from the previous data, we use the following approach:
- Weight the values in the DataFrame according to the severity of each category (for example, "toxic" and "severe_toxic" should contribute different scores), then sum the weighted values per row.
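
As a small worked illustration of the weighted sum, here is a sketch on a made-up three-row frame. The weights are the ones used in this notebook's `cat_mtpl` dictionary; the toy labels and the names `toy` and `toy_score` are assumptions for illustration only.

```python
import pandas as pd

# Category weights, same values as the notebook's cat_mtpl dictionary.
weights = {'toxic': 1.0, 'severe_toxic': 2.5, 'obscene': 1.0,
           'threat': 2.0, 'insult': 1.5, 'identity_hate': 2.0}

# Toy binary labels (made up for illustration, not taken from the dataset).
toy = pd.DataFrame(
    [[0, 0, 0, 0, 0, 0],   # non-toxic comment
     [1, 0, 0, 0, 0, 0],   # toxic only
     [1, 1, 1, 0, 1, 0]],  # toxic + severe_toxic + obscene + insult
    columns=list(weights),
)

# Multiply each column by its weight, then add across columns per row.
toy_score = (toy * pd.Series(weights)).sum(axis=1)
print(toy_score.tolist())  # [0.0, 1.0, 6.0]
```

The last row scores 1.0 + 2.5 + 1.0 + 1.5 = 6.0. A comment carrying all six labels would reach the maximum of 10.0.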

In [None]:
# Multiplication factors for categories.
cat_mtpl = {'toxic': 1.0, 'severe_toxic': 2.5, 'obscene': 1.0,
            'threat': 2.0, 'insult': 1.5, 'identity_hate': 2.0}

for category in cat_mtpl:
    df[category] = df[category] * cat_mtpl[category]

df['score'] = df.loc[:, 'toxic':'identity_hate'].sum(axis=1)
#df['score'] = df['score'] / df['score'].max()

bins = math.ceil(df['score'].max())

plt.hist(df['score'], bins=bins)
plt.title('Scores Distribution: Adjusted Sum')
plt.show()

### 1.2 Downsampling
The dataset is very unbalanced. Here we downsample the majority class.
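
To make the sampling arithmetic concrete, here is a minimal sketch on a made-up frame with 5 toxic and 20 non-toxic rows. The frame and the variable names are hypothetical; only the factor of 0.4 comes from `config['nontoxic_n_factor']`.

```python
import pandas as pd

# Toy data: 5 toxic rows and 20 non-toxic rows (illustrative only).
toy = pd.DataFrame({'toxic_label': [1] * 5 + [0] * 20})

factor = 0.4  # same value as config['nontoxic_n_factor']
n_toxic = int((toy['toxic_label'] == 1).sum())
n_keep = int(round(n_toxic * factor))  # 5 * 0.4 = 2 non-toxic rows are kept

kept_nontoxic = toy[toy['toxic_label'] == 0].sample(n=n_keep, random_state=42)
balanced = pd.concat([kept_nontoxic, toy[toy['toxic_label'] == 1]]).reset_index(drop=True)
print(balanced['toxic_label'].value_counts().to_dict())
```

With a factor below 1.0, the non-toxic class ends up smaller than the toxic class, so the majority flips after sampling.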

In [None]:
df['toxic_label'].value_counts(normalize=True)

In [None]:
factor = config['nontoxic_n_factor']
n_samples_toxic = (df['toxic_label'] == 1).sum()
n_samples_toxic = round(n_samples_toxic * factor)

df_untoxic_undersample = df[df['toxic_label'] == 0].sample(n=n_samples_toxic, random_state=global_seed)
df_toxic = df.query('toxic_label == 1')
train_df = pd.concat([df_untoxic_undersample, df_toxic]).reset_index(drop=True)
train_df['toxic_label'].value_counts()

In [None]:
print(f'Mean toxicity score: {train_df["score"].mean()}\n'
      f'Standard deviation: {train_df["score"].std()}')

plt.hist(train_df['score'], bins=bins)
plt.title('Scores Distribution: Adjusted Sum')
plt.show()

### 1.3 Text Cleaning

In [None]:
from bs4 import BeautifulSoup
def text_cleaning(text: str) -> str:
    """Function cleans text removing special characters,
    extra spaces, embedded URL links, HTML tags and emojis.
    Code source: https://www.kaggle.com/manabendrarout/pytorch-roberta-ranking-baseline-jrstc-infer
    :param text: Original text
    :return: Preprocessed text
    """
    template = re.compile(r'https?://\S+|www\.\S+')  # website links
    text = template.sub(r'', text)

    soup = BeautifulSoup(text, 'lxml')  # HTML tags
    only_text = soup.get_text()
    text = only_text

    emoji_pattern = re.compile("["
                               u"\U0001F600-\U0001F64F"  # emoticons
                               u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                               u"\U0001F680-\U0001F6FF"  # transport & map symbols
                               u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                               u"\U00002702-\U000027B0"
                               u"\U000024C2-\U0001F251"
                               "]+", flags=re.UNICODE)
    text = emoji_pattern.sub(r'', text)

    text = re.sub(r"[^a-zA-Z\d]", " ", text)  # special characters
    text = re.sub(' +', ' ', text)  # extra spaces
    # Replace repeating characters more than 3 times to length of 3
    text = re.sub(r'([*!?\'])\1\1{2,}', r'\1\1\1', text)    
    # Add space around repeating characters
    text = re.sub(r'([*!?\']+)', r' \1 ', text)    
    # patterns with repeating characters 
    text = re.sub(r'([a-zA-Z])\1{2,}\b', r'\1\1', text)
    text = re.sub(r'([a-zA-Z])\1\1{2,}\B', r'\1\1\1', text)
    text = re.sub(r'[ ]{2,}', ' ', text)
    text = text.strip()  # spaces at the beginning and at the end of string

    return text

train_df['comment_text'] = train_df['comment_text'].apply(text_cleaning)
print('cleaning done!')

In [None]:
import nltk
from nltk.corpus import stopwords
stop = stopwords.words('english')

def text_cleaning_2(data, col):
    
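    # regex=True is explicit because pandas >= 2.0 treats patterns as literal strings by default.
    data[col] = data[col].str.replace(r'https?://\S+|www\.\S+', ' social medium ', regex=True)      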
        
    data[col] = data[col].str.lower()
    data[col] = data[col].str.replace("4", "a") 
    data[col] = data[col].str.replace("2", "l")
    data[col] = data[col].str.replace("5", "s") 
    data[col] = data[col].str.replace("1", "i") 
    data[col] = data[col].str.replace("!", "i") 
    data[col] = data[col].str.replace("|", "i") 
    data[col] = data[col].str.replace("0", "o") 
    data[col] = data[col].str.replace("l3", "b") 
    data[col] = data[col].str.replace("7", "t") 
    data[col] = data[col].str.replace("7", "+") 
    data[col] = data[col].str.replace("8", "ate") 
    data[col] = data[col].str.replace("3", "e") 
    data[col] = data[col].str.replace("9", "g")
    data[col] = data[col].str.replace("6", "g")
    data[col] = data[col].str.replace("@", "a")
    data[col] = data[col].str.replace("$", "s")
    data[col] = data[col].str.replace("#ofc", " of fuckin course ")
    data[col] = data[col].str.replace("fggt", " faggot ")
    data[col] = data[col].str.replace("your", " your ")
    data[col] = data[col].str.replace("self", " self ")
    data[col] = data[col].str.replace("cuntbag", " cunt bag ")
    data[col] = data[col].str.replace("fartchina", " fart china ")    
    data[col] = data[col].str.replace("youi", " you i ")
    data[col] = data[col].str.replace("cunti", " cunt i ")
    data[col] = data[col].str.replace("sucki", " suck i ")
    data[col] = data[col].str.replace("pagedelete", " page delete ")
    data[col] = data[col].str.replace("cuntsi", " cuntsi ")
    data[col] = data[col].str.replace("i'm", " i am ")
    data[col] = data[col].str.replace("offuck", " of fuck ")
    data[col] = data[col].str.replace("centraliststupid", " central ist stupid ")
    data[col] = data[col].str.replace("hitleri", " hitler i ")
    data[col] = data[col].str.replace("i've", " i have ")
    data[col] = data[col].str.replace("i'll", " sick ")
    data[col] = data[col].str.replace("fuck", " fuck ")
    data[col] = data[col].str.replace("f u c k", " fuck ")
    data[col] = data[col].str.replace("shit", " shit ")
    data[col] = data[col].str.replace("bunksteve", " bunk steve ")
    data[col] = data[col].str.replace('wikipedia', ' social medium ')
    data[col] = data[col].str.replace("faggot", " faggot ")
    data[col] = data[col].str.replace("delanoy", " delanoy ")
    data[col] = data[col].str.replace("jewish", " jewish ")
    data[col] = data[col].str.replace("sexsex", " sex ")
    data[col] = data[col].str.replace("allii", " all ii ")
    data[col] = data[col].str.replace("i'd", " i had ")
    data[col] = data[col].str.replace("'s", " is ")
    data[col] = data[col].str.replace("youbollocks", " you bollocks ")
    data[col] = data[col].str.replace("dick", " dick ")
    data[col] = data[col].str.replace("cuntsi", " cuntsi ")
    data[col] = data[col].str.replace("mothjer", " mother ")
    data[col] = data[col].str.replace("cuntfranks", " cunt ")
    data[col] = data[col].str.replace("ullmann", " jewish ")
    data[col] = data[col].str.replace("mr.", " mister ")
    data[col] = data[col].str.replace("aidsaids", " aids ")
    data[col] = data[col].str.replace("njgw", " nigger ")
    data[col] = data[col].str.replace("wiki", " social medium ")
    data[col] = data[col].str.replace("administrator", " admin ")
    data[col] = data[col].str.replace("gamaliel", " jewish ")
    data[col] = data[col].str.replace("rvv", " vandalism ")
    data[col] = data[col].str.replace("admins", " admin ")
    data[col] = data[col].str.replace("pensnsnniensnsn", " penis ")
    data[col] = data[col].str.replace("pneis", " penis ")
    data[col] = data[col].str.replace("pennnis", " penis ")
    data[col] = data[col].str.replace("pov.", " point of view ")
    data[col] = data[col].str.replace("vandalising", " vandalism ")
    data[col] = data[col].str.replace("cock", " dick ")
    data[col] = data[col].str.replace("asshole", " asshole ")
    data[col] = data[col].str.replace("youi", " you ")
    data[col] = data[col].str.replace("afd", " all fucking day ")
    data[col] = data[col].str.replace("sockpuppets", " sockpuppetry ")
    data[col] = data[col].str.replace("iiprick", " iprick ")
    data[col] = data[col].str.replace("penisi", " penis ")
    data[col] = data[col].str.replace("warrior", " warrior ")
    data[col] = data[col].str.replace("loil", " laughing out insanely loud ")
    data[col] = data[col].str.replace("vandalise", " vandalism ")
    data[col] = data[col].str.replace("helli", " helli ")
    data[col] = data[col].str.replace("lunchablesi", " lunchablesi ")
    data[col] = data[col].str.replace("special", " special ")
    data[col] = data[col].str.replace("ilol", " i lol ")
    # The patterns below are regular expressions, so regex=True is passed explicitly
    # (pandas >= 2.0 treats patterns as literal strings by default).
    data[col] = data[col].str.replace(r'\b[uU]\b', 'you', regex=True)
    data[col] = data[col].str.replace(r"what's", "what is ", regex=True)
    data[col] = data[col].str.replace(r"\'s", " is ", regex=True)
    data[col] = data[col].str.replace(r"\'ve", " have ", regex=True)
    data[col] = data[col].str.replace(r"can't", "cannot ", regex=True)
    data[col] = data[col].str.replace(r"n't", " not ", regex=True)
    data[col] = data[col].str.replace(r"i'm", "i am ", regex=True)
    data[col] = data[col].str.replace(r"\'re", " are ", regex=True)
    data[col] = data[col].str.replace(r"\'d", " would ", regex=True)
    data[col] = data[col].str.replace(r"\'ll", " will ", regex=True)
    data[col] = data[col].str.replace(r"\'scuse", " excuse ", regex=True)
    data[col] = data[col].str.replace(r'\s+', ' ', regex=True)  # collapse runs of whitespace into one space
#     text = re.sub(r'\b([^\W\d_]+)(\s+\1)+\b', r'\1', re.sub(r'\W+', ' ', text).strip(), flags=re.I)  # remove repeating words coming immediately one after another
    data[col] = data[col].str.replace(r'(.)\1+', r'\1\1', regex=True)  # runs of 2 or more identical characters are reduced to 2
#     text = re.sub(r'((\b\w+\b.{1,2}\w+\b)+).+\1', r'\1', text, flags = re.I)
    data[col] = data[col].str.replace("[:|♣|'|§|♠|*|/|?|=|%|&|-|#|•|~|^|>|<|►|_]", '', regex=True)
    
    
    # Second pass over contractions; regex=True is explicit for pandas >= 2.0.
    data[col] = data[col].str.replace(r"what's", "what is ", regex=True)    
    data[col] = data[col].str.replace(r"\'ve", " have ", regex=True)
    data[col] = data[col].str.replace(r"can't", "cannot ", regex=True)
    data[col] = data[col].str.replace(r"n't", " not ", regex=True)
    data[col] = data[col].str.replace(r"i'm", "i am ", regex=True)
    data[col] = data[col].str.replace(r"\'re", " are ", regex=True)
    data[col] = data[col].str.replace(r"\'d", " would ", regex=True)
    data[col] = data[col].str.replace(r"\'ll", " will ", regex=True)
    data[col] = data[col].str.replace(r"\'scuse", " excuse ", regex=True)
    data[col] = data[col].str.replace(r"\'s", " ", regex=True)

    # Clean some punctuation
    data[col] = data[col].str.replace('\n', ' \n ')
    data[col] = data[col].str.replace(r'([a-zA-Z]+)([/!?.])([a-zA-Z]+)', r'\1 \2 \3', regex=True)
    # Replace repeating characters more than 3 times to length of 3
    data[col] = data[col].str.replace(r'([*!?\'])\1\1{2,}', r'\1\1\1', regex=True)    
    # Add space around repeating characters
    data[col] = data[col].str.replace(r'([*!?\']+)', r' \1 ', regex=True)    
    # patterns with repeating characters 
    data[col] = data[col].str.replace(r'([a-zA-Z])\1{2,}\b', r'\1\1', regex=True)
    data[col] = data[col].str.replace(r'([a-zA-Z])\1\1{2,}\B', r'\1\1\1', regex=True)
    data[col] = data[col].str.replace(r'[ ]{2,}', ' ', regex=True).str.strip()   
    data[col] = data[col].apply(lambda x: ' '.join([word for word in x.split() if word not in (stop)]))
    
    print('cleaning done!')
    
    return data

train_df = text_cleaning_2(train_df,'comment_text')

In [None]:
print(len(train_df))

def text_cleaning_3(data, col):
    data[col] = data[col].apply(lambda x: '' if len(x.split(' ')) < config['num_words'] else x)
    print('cleaning done!')
    return data

train_df = text_cleaning_3(train_df,'comment_text')
train_df = train_df[train_df['comment_text'] != ''].reset_index(drop=True)
print(len(train_df))

### 1.4 Score clipping
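Scores are capped at `config['clipping_score']`. The maximum possible weighted sum is 10.0, so with the default value of 10.0 nothing changes; clipping only takes effect when the setting is below 10.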
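
A short sketch of what the clipping does, on made-up scores with a hypothetical cap of 4.0. It also shows that the `where`-based expression used in this notebook gives the same result as pandas' `Series.clip(upper=...)`.

```python
import pandas as pd

scores = pd.Series([0.0, 2.5, 6.0, 10.0])  # toy scores, not from the dataset
upper = 4.0                                # hypothetical cap for illustration

# Form used in the notebook: keep values below the cap, replace the rest with the cap.
clipped_where = scores.where(scores < upper, upper)

# Equivalent built-in form.
clipped_clip = scores.clip(upper=upper)

print(clipped_where.tolist())  # [0.0, 2.5, 4.0, 4.0]
```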

In [None]:
train_df['score'].value_counts()

In [None]:
clipping_score = config['clipping_score']

train_df['score'] = train_df['score'].where(train_df['score'] < clipping_score, clipping_score)
train_df['score'].value_counts()

In [None]:
print(f'Mean toxicity score: {train_df["score"].mean()}\n'
      f'Standard deviation: {train_df["score"].std()}')

bins = math.ceil(train_df['score'].max())

plt.hist(train_df['score'], bins=bins)
plt.title('Scores Distribution: Adjusted Sum')
plt.show()

### 1.5 Upsampling

In [None]:
threshold = config['upsampling_threshold']

train_df_over = train_df.query(f'score >= {threshold}')
n_over = len(train_df_over)

train_df_under = train_df.query(f'score < {threshold}')
n_under = len(train_df_under)

under_over_ratio_factor = config['under_over_ratio_factor']
under_over_ratio = round(n_under / n_over * under_over_ratio_factor)
print(under_over_ratio)
train_df_over_repeat = pd.concat([train_df_over] * under_over_ratio)
train_df_upsampling = pd.concat([train_df_under, train_df_over_repeat])
train_df_upsampling = train_df_upsampling.reset_index(drop=True)
train_df = train_df_upsampling

plt.hist(train_df['score'], bins=bins)
plt.title('Scores Distribution: Adjusted Sum')
plt.show()
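
The repetition-based upsampling above can be illustrated on a toy frame. This is only a sketch: the scores are made up, and the threshold of 4.0 and ratio factor of 1.0 mirror the values in `config`.

```python
import pandas as pd

# Toy scores: 6 rows below the threshold, 2 rows at or above it (illustrative only).
toy = pd.DataFrame({'score': [0.0, 0.0, 1.0, 1.0, 2.5, 2.5, 4.5, 6.0]})
threshold = 4.0     # same value as config['upsampling_threshold']
ratio_factor = 1.0  # same value as config['under_over_ratio_factor']

over = toy[toy['score'] >= threshold]
under = toy[toy['score'] < threshold]

# Repeat the high-score rows so that both groups have roughly the same size.
ratio = round(len(under) / len(over) * ratio_factor)  # 6 / 2 = 3
upsampled = pd.concat([under] + [over] * ratio).reset_index(drop=True)
print(ratio, len(upsampled))  # 3 12
```

Because whole copies of the high-score group are appended, the repeated rows are exact duplicates; they change the balance of the training set, not its variety.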

In [None]:
# capping the overflow
n_6 = len(train_df[train_df['score'] == 6.0])
n_0 = len(train_df[train_df['score'] == 0.0])
print(n_6, n_0)

# Sample at most n_6 rows; asking for more rows than exist would raise a ValueError.
train_df_score_6_undersample = train_df[train_df['score'] == 6.0].sample(n=min(n_6, n_0), random_state=global_seed)
train_df_score_not_6 = train_df.query('score != 6.0')
train_df_capped = pd.concat([train_df_score_6_undersample, train_df_score_not_6]).reset_index(drop=True)

n_6 = len(train_df_capped[train_df_capped['score'] == 6.0])
n_0 = len(train_df_capped[train_df_capped['score'] == 0.0])
print(n_6, n_0)
#train_df_capped['score'].value_counts()

if config['capped']:
    train_df = train_df_capped
    plt.hist(train_df_capped['score'], bins=bins)
    plt.title('Scores Distribution: Adjusted Sum')
    plt.show()

In [None]:
train_df['score'].value_counts()

### 1.6 Validation Data Split

In [None]:
train_df.describe()

In [None]:
n_folds = config['n_folds']

skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=global_seed)
for nfold, (train_index, val_index) in enumerate(skf.split(X=train_df.index,
                                                           y=train_df.toxic_label)):
    train_df.loc[val_index, 'fold'] = nfold
#print(train_df.groupby(['fold', train_df.toxic_label]).size())

p_fold = 0
p_train_df = train_df.query(f'fold != {p_fold}').reset_index(drop=True)
p_valid_df = train_df.query(f'fold == {p_fold}').reset_index(drop=True)

print(len(p_train_df))
print(len(p_valid_df))

In [None]:
p_train_df.describe()

In [None]:
p_valid_df.describe()

In [None]:
p_train_df = p_train_df[['comment_text', 'score']].rename(columns={'comment_text': 'text'})
p_valid_df = p_valid_df[['comment_text', 'score']].rename(columns={'comment_text': 'text'})
print('done!')

# 2. Dataset

In [None]:
train_ds = datasets.Dataset.from_pandas(p_train_df)
valid_ds = datasets.Dataset.from_pandas(p_valid_df)

print(train_ds)
print(valid_ds)

In [None]:
checkpoint = 'roberta-base'

# Downloading tokenizer (Internet required)
#tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(config['tokenizer_path'])

def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, max_length=128)

tokenized_train_ds = train_ds.map(tokenize_function, batched=True)
tokenized_valid_ds = valid_ds.map(tokenize_function, batched=True)

print(tokenized_train_ds)
print(tokenized_valid_ds)

In [None]:
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

tf_train_ds = tokenized_train_ds.to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["score"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=config['batch_size'],
)

tf_valid_ds = tokenized_valid_ds.to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["score"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=config['batch_size'],
)

print(len(tf_train_ds))
print(len(tf_valid_ds))

# 3. Model Training

In [None]:
from transformers import TFAutoModel

# Downloading model (Internet required)
#roberta_model = TFAutoModel.from_pretrained(checkpoint)
roberta_model = TFAutoModel.from_pretrained(config['model_path'])

### 3.1 Model

In [None]:
class MultiHeadAttentionRegressor(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.roberta_layer = roberta_model

        self.query_1 = tf.keras.layers.Dense(128, use_bias=False, activation=None)
        self.query_2 = tf.keras.layers.Dense(128, use_bias=False, activation=None)
        self.query_3 = tf.keras.layers.Dense(128, use_bias=False, activation=None)

        self.key_1 = tf.keras.layers.Dense(128, use_bias=False, activation=None)
        self.key_2 = tf.keras.layers.Dense(128, use_bias=False, activation=None)
        self.key_3 = tf.keras.layers.Dense(128, use_bias=False, activation=None)
        
        self.regressor_1 = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation='selu'),
            tf.keras.layers.Dropout(0.2),
        ])
        
        self.regressor_2 = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation='selu'),
            tf.keras.layers.Dropout(0.2)
        ])
        
        self.regressor_3 = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation='selu'),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense(128, activation='selu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(1, activation=None)
            ])
        
    def call(self, inputs, training=None):
        roberta_outputs = self.roberta_layer(inputs)
        
        pooler_outputs = roberta_outputs['pooler_output'] ## TensorShape([batch_num, 768])
        output_1 = self.regressor_1(pooler_outputs) ## TensorShape([batch_num, 512])
        
        attention_mask = tf.expand_dims(inputs['attention_mask'], -1) ## TensorShape([batch_num, max_len, 1])
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)
        last_hidden_states = roberta_outputs['last_hidden_state'] * attention_mask ## TensorShape([batch_num, max_len, 768])

        lhs_1 = last_hidden_states[:, :, :256] ## TensorShape([batch_num(8), max_len, 256])
        lhs_2 = last_hidden_states[:, :, 256:512]
        lhs_3 = last_hidden_states[:, :, 512:]

        q_1 = self.query_1(lhs_1) ## TensorShape([8, max_len, 128])
        k_1 = tf.expand_dims(self.key_1(pooler_outputs), -1) ## TensorShape([8, 128, 1])
        a_scores_1 = tf.linalg.matmul(q_1, k_1) / tf.math.sqrt(128.) ## TensorShape([8, max_len, 1])
        a_weights_1 = tf.keras.layers.Softmax(axis=1)(a_scores_1) ## TensorShape([8, max_len, 1])
        average_hidden_states_1 = tf.math.reduce_sum(lhs_1 * a_weights_1, axis=1) ## TensorShape([8, 256])

        q_2 = self.query_2(lhs_2)
        k_2 = tf.expand_dims(self.key_2(pooler_outputs), -1)
        a_scores_2 = tf.linalg.matmul(q_2, k_2) / tf.math.sqrt(128.)
        a_weights_2 = tf.keras.layers.Softmax(axis=1)(a_scores_2)
        average_hidden_states_2 = tf.math.reduce_sum(lhs_2 * a_weights_2, axis=1)

        q_3 = self.query_3(lhs_3)
        k_3 = tf.expand_dims(self.key_3(pooler_outputs), -1)
        a_scores_3 = tf.linalg.matmul(q_3, k_3) / tf.math.sqrt(128.)
        a_weights_3 = tf.keras.layers.Softmax(axis=1)(a_scores_3)
        average_hidden_states_3 = tf.math.reduce_sum(lhs_3 * a_weights_3, axis=1)

        average_hidden_states = tf.concat([average_hidden_states_1,
                                           average_hidden_states_2,
                                           average_hidden_states_3], axis=-1) ## TensorShape([8, 768])
        output_2 = self.regressor_2(average_hidden_states) ## TensorShape([8, 512])
        
        output_3 = tf.concat([output_1, output_2], axis=-1) ## TensorShape([8, 1024])
        outputs = self.regressor_3(output_3)
        
        return outputs

model = MultiHeadAttentionRegressor()
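
Each of the three attention blocks in `call` pools one 256-wide slice of the last hidden state into a single vector per sample. Below is a minimal NumPy sketch of one such block. Random arrays stand in for the RoBERTa outputs and for the bias-free `Dense(128)` weights, and the toy batch and sequence sizes are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, max_len, d_slice, d_att = 2, 5, 256, 128  # toy batch/length; 256 and 128 mirror the model

hidden = rng.normal(size=(batch, max_len, d_slice))  # stand-in for one slice of last_hidden_state
pooled = rng.normal(size=(batch, 768))               # stand-in for pooler_output
w_query = rng.normal(size=(d_slice, d_att)) * 0.02   # stand-in for a query Dense layer
w_key = rng.normal(size=(768, d_att)) * 0.02         # stand-in for a key Dense layer

q = hidden @ w_query                 # (batch, max_len, 128)
k = (pooled @ w_key)[:, :, None]     # (batch, 128, 1)
scores = q @ k / np.sqrt(d_att)      # (batch, max_len, 1)

# Softmax over the token axis (axis=1), computed in a numerically stable way.
scores = scores - scores.max(axis=1, keepdims=True)
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# Attention-weighted sum of token vectors.
pooled_slice = (hidden * weights).sum(axis=1)  # (batch, 256)
print(pooled_slice.shape)
```

The sketch omits the attention mask; in the model above, padded positions are zeroed before this step. Concatenating three such pooled slices gives the 768-wide vector fed to `regressor_2`.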

### 3.2 Training

In [None]:
model.roberta_layer.trainable = False

num_epochs = 2
num_train_steps = len(tf_train_ds) * num_epochs

lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-4, end_learning_rate=5e-5, decay_steps=num_train_steps
)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_scheduler),
              loss=tf.keras.losses.MeanSquaredError()
             )

for data, label in tf_train_ds.take(1):
    example = data
result = model(example)
print(result)
model.summary()
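
With its default `power=1.0`, `PolynomialDecay` interpolates linearly from the initial learning rate to the end learning rate over `decay_steps`, then stays at the end value. A pure-Python sketch of that formula follows; the function name and the `decay_steps=1000` default are made up for illustration, whereas the notebook uses `len(tf_train_ds) * num_epochs`.

```python
def polynomial_decay(step, initial_lr=5e-4, end_lr=5e-5, decay_steps=1000, power=1.0):
    """Learning rate at a given step (sketch of the schedule, not the Keras class)."""
    step = min(step, decay_steps)  # the rate is held at end_lr after decay_steps
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr

for s in (0, 500, 1000, 2000):
    print(s, polynomial_decay(s))
```

Halfway through the schedule the rate is the midpoint of the two endpoints, 2.75e-4 for the values used in the frozen-backbone stage.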

In [None]:
fit_history = model.fit(tf_train_ds,
                        epochs=num_epochs,
                        validation_data=tf_valid_ds,
                        verbose=1)

In [None]:
model.roberta_layer.trainable = True

num_epochs = 2
num_train_steps = len(tf_train_ds) * num_epochs

lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=2e-5, end_learning_rate=2e-6, decay_steps=num_train_steps
)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_scheduler),
              loss=tf.keras.losses.MeanSquaredError()
             )

#result = model(example)
#print(result)
model.summary()

In [None]:
fit_history = model.fit(tf_train_ds,
                        epochs=num_epochs,
                        validation_data=tf_valid_ds,
                        verbose=1)

# 4. Prediction & Submit

In [None]:
test_df = pd.read_csv("../input/jigsaw-toxic-severity-rating/comments_to_score.csv")
test_df['text'] = test_df['text'].apply(text_cleaning)
test_df = text_cleaning_2(test_df,'text')
test_df.head()

In [None]:
test_ds = datasets.Dataset.from_pandas(test_df)

tokenized_test_ds = test_ds.map(tokenize_function, batched=True)
tf_test_ds = tokenized_test_ds.to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=config['batch_size'],
)

print(len(tf_test_ds))

In [None]:
result = model.predict(tf_test_ds)
test_df['score'] = result
submission_df = test_df[['comment_id', 'score']]

submission_df.to_csv("submission.csv", index=False)
submission_df

### 4.1 Validation

In [None]:
# New data for validation: text pairs.
data_valid = pd.read_csv('../input/jigsaw-toxic-severity-rating/validation_data.csv')
data_valid['less_toxic'] = data_valid['less_toxic'].apply(text_cleaning)
data_valid['more_toxic'] = data_valid['more_toxic'].apply(text_cleaning)
data_valid = text_cleaning_2(data_valid,'less_toxic')
data_valid = text_cleaning_2(data_valid,'more_toxic')
data_valid.head()

In [None]:
more_or_less_toxic_ds = datasets.Dataset.from_pandas(data_valid)
more_or_less_toxic_ds

In [None]:
def less_toxic_tokenize(example):
    return tokenizer(example['less_toxic'], truncation=True, max_length=128)

def more_toxic_tokenize(example):
    return tokenizer(example['more_toxic'], truncation=True, max_length=128)

less_toxic_ds = more_or_less_toxic_ds.map(less_toxic_tokenize, batched=True)
more_toxic_ds = more_or_less_toxic_ds.map(more_toxic_tokenize, batched=True)

print(less_toxic_ds)
print(more_toxic_ds)

In [None]:
tf_less_ds = less_toxic_ds.to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=config['batch_size'],
)

tf_more_ds = more_toxic_ds.to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=config['batch_size'],
)

print(len(tf_less_ds))
print(len(tf_more_ds))

In [None]:
less_scores = model.predict(tf_less_ds)
more_scores = model.predict(tf_more_ds)

data_valid['less_score'] = less_scores
data_valid['more_score'] = more_scores

data_valid['correct'] = 1
data_valid['correct'] = data_valid['correct'].where(data_valid['less_score'] < data_valid['more_score'], 0)

accuracy = data_valid['correct'].sum() / len(data_valid)
accuracy
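
The accuracy above is the fraction of pairs in which the comment annotated as more toxic receives the strictly higher predicted score. A tiny sketch with made-up scores (not model outputs) shows the computation, including how a tie is counted:

```python
import numpy as np

# Toy predicted scores for four (less_toxic, more_toxic) pairs; values are illustrative.
less = np.array([0.2, 1.5, 3.0, 0.7])
more = np.array([0.9, 1.1, 4.2, 0.7])

# A pair is correct only if less < more; the tie in the last pair counts as incorrect.
correct = less < more
pair_accuracy = float(correct.mean())
print(correct.tolist(), pair_accuracy)  # [True, False, True, False] 0.5
```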