In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense , Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from kaggle_datasets import KaggleDatasets
import transformers
from tqdm.notebook import tqdm
from tokenizers import BertWordPieceTokenizer

In [2]:
def fast_encode(texts, tokenizer, chunk_size = 256, maxlen = 512):
    """Tokenize ``texts`` chunk by chunk and return all token ids as one array.

    The tokenizer is configured in place: truncation and padding are both
    enabled with ``maxlen`` before any text is encoded.

    Args:
        texts: Sliceable collection of strings. Slices exposing ``.tolist()``
            (pandas Series, NumPy arrays) and plain Python sequences such as
            lists or tuples are both accepted.
        tokenizer: Object providing ``enable_truncation``, ``enable_padding``
            and ``encode_batch``; each encoding returned by ``encode_batch``
            must expose an ``ids`` attribute.
        chunk_size: Number of texts passed to ``encode_batch`` per call.
        maxlen: Length handed to the tokenizer's truncation and padding setup.

    Returns:
        ``np.ndarray`` built from the ``ids`` of every encoding, in input order.
    """
    tokenizer.enable_truncation(max_length = maxlen)
    try:
        tokenizer.enable_padding(max_length = maxlen)
    except TypeError:
        # Newer `tokenizers` releases renamed this keyword to `length`.
        tokenizer.enable_padding(length = maxlen)
    all_ids = []

    for start in tqdm(range(0, len(texts), chunk_size)):
        chunk = texts[start:start + chunk_size]
        # Series / ndarray slices are converted with .tolist(); plain
        # sequences (list, tuple) have no such method and are copied instead.
        text_chunk = chunk.tolist() if hasattr(chunk, 'tolist') else list(chunk)
        encs = tokenizer.encode_batch(text_chunk)
        all_ids.extend(enc.ids for enc in encs)

    return np.array(all_ids)
    

In [3]:
def build_model(transformer, max_len = 512):
    """Build and compile a binary classifier on top of ``transformer``.

    Token ids go through ``transformer``; the slice ``[:, 0, :]`` of the first
    element of its output is fed to a single sigmoid unit.

    Args:
        transformer: Callable layer/model applied to the token-id input. Its
            output is indexed with ``[0]`` and then ``[:, 0, :]`` (presumably a
            (batch, sequence, hidden) tensor -- confirm for the model used).
        max_len: Length of the input token-id sequences.

    Returns:
        The compiled Keras ``Model`` (Adam with learning rate 1e-5,
        binary cross-entropy loss, accuracy metric).
    """
    input_word_ids = Input(shape = (max_len,), dtype = tf.int32, name = 'input_word_ids')
    sequence_output = transformer(input_word_ids)[0]
    # Keep only sequence position 0 as the representation fed to the classifier.
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation = 'sigmoid')(cls_token)

    model = Model(inputs = input_word_ids, outputs = out)
    # `lr` is a deprecated alias that recent Keras optimizers reject;
    # `learning_rate` is the supported keyword.
    model.compile(Adam(learning_rate = 1e-5), loss = 'binary_crossentropy', metrics = ['accuracy'])

    return model

In [4]:
# Look for a TPU; a ValueError raised while resolving it is treated as "no TPU".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU' , tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU found: fall back to TensorFlow's current distribution strategy.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the cluster and initialise the TPU system before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("Replicas", strategy.num_replicas_in_sync)

Running on TPU grpc://10.0.0.2:8470
Replicas 8


In [5]:
# Training configuration.
EPOCHS = 3
MAX_LEN = 192  # sequence length used for encoding and for the model input
# Global batch size: 16 times the number of replicas in the strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * 16

# Passed to `prefetch` so tf.data tunes the buffer size itself.
AUTO = tf.data.experimental.AUTOTUNE

#GCS_DS_PATH = KaggleDatasets.get_gcs_path()

In [6]:
# Fetch the pretrained DistilBERT tokenizer and save its files to the working
# directory; 'vocab.txt' is read back from there just below.
tokenizer = transformers.DistilBertTokenizer.from_pretrained(
    'distilbert-base-multilingual-cased'
)
tokenizer.save_pretrained('.')

# Build a `tokenizers` WordPiece tokenizer from the saved vocabulary, keeping case.
fast_tokenizer = BertWordPieceTokenizer(
    'vocab.txt',
    lowercase = False,
)
fast_tokenizer

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




Tokenizer(vocabulary_size=119547, model=BertWordPiece, add_special_tokens=True, unk_token=[UNK], sep_token=[SEP], cls_token=[CLS], clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=False, wordpieces_prefix=##)

In [7]:
# Directory holding the competition CSV files; change it here to run the
# notebook against a different copy of the data.
INPUT_DIR = "/kaggle/input/jigsaw-multilingual-toxic-comment-classification"

train1 = pd.read_csv(f"{INPUT_DIR}/jigsaw-toxic-comment-train.csv")
train2 = pd.read_csv(f"{INPUT_DIR}/jigsaw-unintended-bias-train.csv")
# Round `toxic` to the nearest integer so the column holds integer labels
# (presumably fractional scores in [0, 1] in this file -- confirm).
train2["toxic"] = train2["toxic"].round().astype(int)
valid = pd.read_csv(f"{INPUT_DIR}/validation.csv")
test = pd.read_csv(f"{INPUT_DIR}/test.csv")
sub = pd.read_csv(f"{INPUT_DIR}/sample_submission.csv")

In [8]:
# Training set: all of train1, every toxic row of train2, and a fixed random
# sample of 150,000 non-toxic rows of train2.
#
# The parts are stacked in source order, which would leave two long
# single-label stretches (all toxic, then all non-toxic) at the end. The
# tf.data pipeline in this notebook only shuffles through a 2048-element
# buffer, far too small to mix those stretches, so the rows are shuffled once
# here with a fixed seed. `ignore_index=True` avoids the duplicate index
# labels that stacking the frames would otherwise produce.
train = pd.concat(
    [
        train1[['comment_text', 'toxic']],
        train2[['comment_text', 'toxic']].query('toxic==1'),
        train2[['comment_text', 'toxic']].query('toxic==0').sample(n = 150000, random_state = 0),
    ],
    ignore_index = True,
).sample(frac = 1, random_state = 0)

In [9]:
# Encode the three splits (train, validation, test, in that order) to token ids.
# Note that the test frame keeps its text in `content`, not `comment_text`.
x_train, x_valid, x_test = (
    fast_encode(text_column.astype(str), fast_tokenizer, maxlen = MAX_LEN)
    for text_column in (train.comment_text, valid.comment_text, test.content)
)

# Labels for the two splits that are used as fit targets.
y_train, y_valid = train.toxic.values, valid.toxic.values

HBox(children=(FloatProgress(value=0.0, max=1898.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))




In [10]:
# Training pipeline: repeat indefinitely, shuffle through a 2048-element
# buffer, batch, and prefetch.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .repeat()
    .shuffle(2048)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Validation pipeline: fixed order, batched, cached after batching, prefetched.
valid_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_valid, y_valid))
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(AUTO)
)

# Test pipeline: inputs only (no labels), batched.
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices(x_test)
    .batch(BATCH_SIZE)
)

In [11]:
%%time
# Load the pretrained encoder and build the classifier inside the strategy scope.
with strategy.scope():
    transformer_layer = transformers.TFDistilBertModel.from_pretrained(
        'distilbert-base-multilingual-cased'
    )
    model = build_model(transformer_layer, max_len = MAX_LEN)

model.summary()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=618.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=910749124.0, style=ProgressStyle(descri…


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_word_ids (InputLayer)  [(None, 192)]             0         
_________________________________________________________________
tf_distil_bert_model (TFDist ((None, 192, 768),)       134734080 
_________________________________________________________________
tf_op_layer_strided_slice (T [(None, 768)]             0         
_________________________________________________________________
dense (Dense)                (None, 1)                 769       
Total params: 134,734,849
Trainable params: 134,734,849
Non-trainable params: 0
_________________________________________________________________
CPU times: user 31.7 s, sys: 10 s, total: 41.7 s
Wall time: 45.3 s


In [12]:
# The training dataset repeats forever, so Keras is given an explicit step
# count: the number of full batches in the training array.
n_steps = len(x_train) // BATCH_SIZE
train_history = model.fit(
    train_dataset,
    epochs = EPOCHS,
    steps_per_epoch = n_steps,
    validation_data = valid_dataset,
)

Train for 3795 steps, validate for 63 steps
Epoch 1/3
Epoch 2/3
Epoch 3/3


In [13]:
# Continue training on the validation split itself, for twice as many epochs.
# `n_steps` is re-bound here to the number of full batches in that split.
n_steps = len(x_valid) // BATCH_SIZE
train_history_2 = model.fit(
    valid_dataset.repeat(),
    epochs = EPOCHS * 2,
    steps_per_epoch = n_steps,
)

Train for 62 steps
Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [14]:
# Score the test set, store the predictions in the submission frame's `toxic`
# column, and write the frame out without the index.
sub['toxic'] = model.predict(
    test_dataset,
    verbose = 1,
)
sub.to_csv(
    'submission.csv',
    index = False,
)

