In [None]:
import pathlib

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from transformers import TFAutoModel, AutoTokenizer
from tqdm.auto import tqdm

Kudos to [this Kaggle kernel](https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta).

In [None]:
# Directory the competition CSV files are read from.
ROOT_PATH = pathlib.Path("/kaggle/input/jigsaw-multilingual-toxic-comment-classification/")
# Checkpoint name passed to both AutoTokenizer and TFAutoModel.
MODEL = "distilbert-base-multilingual-cased"
# Training batch size and number of epochs passed to model.fit().
BATCH_SIZE, EPOCHS = 32, 1
# Token limit used for truncation and as the width of the model's input layers.
MAX_DOC_LENGTH = 256

## Data

In [None]:
# Read the train / validation / test CSV files from ROOT_PATH.
train_df = pd.read_csv(ROOT_PATH / "jigsaw-toxic-comment-train.csv")
valid_df = pd.read_csv(ROOT_PATH / "validation.csv")
test_df = pd.read_csv(ROOT_PATH / "test.csv")
# The test frame's "content" column is renamed to "comment_text", the column
# name the later cells read from all three frames.
test_df = test_df.rename(columns={"content": "comment_text"})
train_df.sample(5)

In [None]:
# Mean of the "toxic" column per split (the positive rate, assuming 0/1 labels).
tuple(split["toxic"].mean() for split in (train_df, valid_df))

In [None]:
class Tokenizer:
    """Callable wrapper that applies one fixed set of encoding options.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(x, max_length=..., truncation=..., padding=..., return_tensors=...)``.
            In this notebook it is a Hugging Face tokenizer.
        max_doc_length: Value forwarded as ``max_length`` on every call.
        padding: Value forwarded as ``padding`` on every call. For Hugging Face
            tokenizers ``True`` pads to the longest sequence of the current call,
            while ``"max_length"`` pads to ``max_doc_length``.
    """

    def __init__(
        self,
        tokenizer,
        max_doc_length: int,
        padding = True,
    ) -> None:
        self.tokenizer, self.max_doc_length, self.padding = tokenizer, max_doc_length, padding

    def __call__(self, x):
        """Encode ``x`` with the stored options and return the tokenizer's result.

        Truncation is always enabled and TensorFlow tensors are always requested.
        """
        encode_options = {
            "max_length": self.max_doc_length,
            "truncation": True,
            "padding": self.padding,
            "return_tensors": "tf",
        }
        return self.tokenizer(x, **encode_options)
    
# Pad every comment to exactly MAX_DOC_LENGTH tokens. The wrapper's default
# (padding=True) pads only to the longest comment of each call, so separately
# tokenized chunks can come out with different widths; such chunks cannot be
# concatenated along axis 0, nor fed to a model whose inputs are declared with
# a fixed width of MAX_DOC_LENGTH.
tokenizer = Tokenizer(AutoTokenizer.from_pretrained(MODEL), MAX_DOC_LENGTH, padding="max_length")

In [None]:
def get_tokenized_values(text, tokenizer, batch_size):
    """Tokenize ``text`` chunk by chunk and stack the results.

    Args:
        text: Sliceable sequence of documents (sliced as ``text[i:i + batch_size]``).
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` entries for a chunk of documents.
        batch_size: Number of documents passed to ``tokenizer`` per call.

    Returns:
        A pair ``(input_ids, attention_mask)``, each the per-chunk outputs
        concatenated along axis 0.

    Note:
        ``tf.concat`` along axis 0 needs every chunk to share its remaining
        dimensions, so the tokenizer must pad all chunks to the same width.
    """
    ids_chunks, mask_chunks = [], []
    chunk_starts = range(0, len(text), batch_size)
    for start in tqdm(chunk_starts):
        encoded = tokenizer(text[start: start + batch_size])
        ids_chunks.append(encoded["input_ids"])
        mask_chunks.append(encoded["attention_mask"])

    stacked_ids = tf.concat(ids_chunks, axis=0)
    stacked_mask = tf.concat(mask_chunks, axis=0)
    return stacked_ids, stacked_mask

# Tokenization is inference-only, so it runs in chunks four times the training batch size.
TOKENIZE_BATCH_SIZE = BATCH_SIZE * 4

train_input_ids, train_attention_mask = get_tokenized_values(
    train_df["comment_text"].values.tolist(), tokenizer, TOKENIZE_BATCH_SIZE
)
valid_input_ids, valid_attention_mask = get_tokenized_values(
    valid_df["comment_text"].values.tolist(), tokenizer, TOKENIZE_BATCH_SIZE
)
test_input_ids, test_attention_mask = get_tokenized_values(
    test_df["comment_text"].values.tolist(), tokenizer, TOKENIZE_BATCH_SIZE
)

# Label arrays for the two labelled splits (the test split gets none here).
y_train = train_df["toxic"].values
y_valid = valid_df["toxic"].values

In [None]:
# Let tf.data choose the prefetch depth. prefetch() counts elements of its
# input, which are whole batches once it follows batch(), so a depth derived
# from BATCH_SIZE would buffer dozens of batches in memory.
# (tf.data.experimental.AUTOTUNE is the spelling available across TF 2.x.)
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Training stream: ((input_ids, attention_mask), label) elements.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices(((train_input_ids, train_attention_mask), y_train))
    # Shuffle before repeating so the shuffle buffer never mixes examples from
    # two consecutive passes over the data.
    .shuffle(2048)
    # Endless stream; the number of steps per epoch is given to model.fit().
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation data: same element structure, fixed order, single pass.
valid_dataset = (
    tf.data.Dataset
    .from_tensor_slices(((valid_input_ids, valid_attention_mask), y_valid))
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Test data: the all-ones array is only a placeholder label so that elements
# keep the same (inputs, label) structure as the other two datasets.
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices(((test_input_ids, test_attention_mask), np.ones(len(test_input_ids))))
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [None]:
# Pull one batch to sanity-check the shapes of the two input tensors.
x, y = next(iter(train_dataset))
tuple(part.shape for part in x)

## Model

- Transformers: https://jalammar.github.io/illustrated-transformer/
- BERT: https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/

In [None]:
# Pretrained encoder loaded from the MODEL checkpoint.
bert_model = TFAutoModel.from_pretrained(MODEL)

# Two fixed-width integer inputs, matching the (input_ids, attention_mask)
# tuples produced by the datasets.
input_ids = keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.int32)
attention_mask = keras.layers.Input(shape=(MAX_DOC_LENGTH,), dtype=tf.int32)
# Pass the mask by keyword: the position of `attention_mask` in the call
# signature is not the same for every architecture / library version.
sequence_output = bert_model(input_ids, attention_mask=attention_mask)[0]
# Use the hidden state of the first token of every sequence as the document vector.
cls_token = sequence_output[:, 0, :]
# Single sigmoid unit: one probability per document.
out = keras.layers.Dense(1, activation="sigmoid")(cls_token)

model = keras.models.Model(inputs=(input_ids, attention_mask), outputs=out)
model.compile(
    # `learning_rate` is the supported argument name; `lr` is a deprecated
    # alias that newer Keras versions ignore or reject.
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy", keras.metrics.AUC()],
)

In [None]:
# Print the layer-by-layer overview and parameter counts of the compiled model.
model.summary()

## Training
Also worth trying: https://datascience.stackexchange.com/a/13496/32796

In [None]:
# The training dataset repeats forever, so an epoch is defined here as the number
# of full batches that fit into the tokenized training set.
n_steps = train_input_ids.shape[0] // BATCH_SIZE
fit_options = dict(
    steps_per_epoch=n_steps,
    validation_data=valid_dataset,
    epochs=EPOCHS,
)
train_history = model.fit(train_dataset, **fit_options)

In [None]:
# Loss and the compiled metrics (accuracy, AUC) on the validation dataset,
# shown as the cell output.
model.evaluate(valid_dataset)

## Submission

In [None]:
# Start from the sample submission so the output keeps its columns and row order.
sub = pd.read_csv(ROOT_PATH / "sample_submission.csv")
# The model ends in a single sigmoid unit, so predict() yields an (n_rows, 1)
# array; flatten it to a 1-D vector before assigning it as a column instead of
# relying on pandas' implicit handling of a single-column 2-D array.
sub["toxic"] = model.predict(test_dataset, verbose=1).ravel()
sub.to_csv("submission.csv", index=False)