# Fine-tuning RemBERT on TensorFlow TPU

### RemBERT is a pain to finetune on Kaggle GPUs, but simple on TPU! Hopefully this can help people in the future.

I wrote a custom callback that runs validation every few batches and saves the model weights whenever the validation Jaccard score beats both a minimum threshold and the best score seen so far.  
Note:
- I'll try to update this to include Hugging Face datasets
- If you do more than 8 folds, it will likely run out of space in the output directory

### A huge shoutout to this notebook, which I modified heavily: https://www.kaggle.com/msafi04/tensorflow-hf-qa-using-externaldata-tpu

In [None]:
!pip install -U --no-build-isolation --no-deps ../input/transformers-master/ -qq # necessary for mixed precision roberta and rembert

In [None]:
import os
import gc
import json
import collections
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.backend as K

import transformers
from transformers import AutoTokenizer, TFAutoModel, AutoConfig

warnings.simplefilter(action='ignore', category=Warning)
%env TOKENIZERS_PARALLELISM=true

In [None]:
# Experiment configuration read by the cells below.
CFG = {
    "EPOCHS": 2,
    "MODEL": "../input/rembert-tf",
    "N_FOLDS": 5,
    "SEED": 777,
    "VERBOSE": 1,
    "BATCH_SIZE": 32,
    # Passed to the tokenizer as max_length / stride respectively.
    "MAX_LENGTH": 384,
    "DOC_STRIDE": 128,
    # Fraction of an epoch between two validation passes.
    "VALIDATE_EVERY": 0.3,
    # In the first epoch, skip validation until this fraction of the
    # batches has been processed.
    "VAL_START_BATCH": 0.5,
    # "hindi" or "tamil" to checkpoint on that language's score;
    # None checkpoints on the overall score.
    "LANG_FOCUS": None,
    # Weights are not saved unless the Jaccard score exceeds this value.
    "MIN_SCORE_TO_SAVE": 0.6,
    "LR": 3e-5,
}

In [None]:
# https://www.kaggle.com/hidehisaarai1213/g2net-tf-on-the-fly-cqt-tpu-training?scriptVersionId=71575767&cellId=11
def auto_select_accelerator():
    """Pick a distribution strategy, using a TPU when one can be resolved.

    Connects to and initialises the TPU system on every call. If any step
    of the TPU setup raises ``ValueError``, the current default strategy
    from ``tf.distribute.get_strategy()`` is used instead.

    Returns:
        tuple: ``(strategy, tpu_detected)`` where ``tpu_detected`` is True
        only when the TPU strategy was created successfully.
    """
    on_tpu = False
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        chosen = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
        on_tpu = True
    except ValueError:
        # TPU setup failed: fall back to the default strategy.
        chosen = tf.distribute.get_strategy()

    print(f"Running on {chosen.num_replicas_in_sync} replicas")
    return chosen, on_tpu

# Resolve the accelerator once so the replica count is known up front.
strategy, tpu_detected = auto_select_accelerator()
REPLICAS = strategy.num_replicas_in_sync
# Lets tf.data choose prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE
# Mixed precision is left disabled:
# tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

print(f'REPLICAS: {REPLICAS}')

In [None]:
# Tokenizer matching the checkpoint configured in CFG["MODEL"].
tokenizer = AutoTokenizer.from_pretrained(CFG["MODEL"])
# True when the tokenizer pads on the right. The preprocessing functions use
# this to decide whether the question or the context is the first sequence.
pad_on_right = (tokenizer.padding_side == "right")

In [None]:
def _shuffled(frame):
    """Return `frame` with its rows shuffled (fixed seed) and a fresh 0..n-1 index."""
    return frame.sample(frac=1, random_state=2021).reset_index(drop=True)


# The competition CSV mixes chaii rows with external rows; the "src" column tells them apart.
train = pd.read_csv('../input/chaiimlqaxquad/chaii-mlqa-xquad-5folds.csv')
external_hi = _shuffled(train[train["src"] != 'chaii'])

# Telugu rows from TyDi QA, reduced to the four columns used for training.
external_te = pd.read_csv("../input/all-mlqa-xquad-tydiqa/tydiqa.csv")
telugu_rows = external_te["language"].str.contains("telugu")
external_te = _shuffled(
    external_te[telugu_rows][["answer_start", "answer_text", "context", "question"]]
)
# NOTE(review): the other frame identifies its origin in "src", not "source" —
# confirm whether this column name is intentional.
external_te["source"] = "tydiqa"
external_te["fold"] = -1

external = pd.concat([external_hi, external_te], axis=0, ignore_index=True)
# Keep only the chaii rows as the data that gets split into folds.
train = _shuffled(train[train["src"] == 'chaii'])

In [None]:
# Assign every chaii row to a validation fold, stratified by language.
# External rows keep kfold == -1, so they never land in a validation fold.
n_folds = CFG['N_FOLDS']
train['kfold'] = -1
external["kfold"] = -1

# Synthetic ids, unique across both frames.
train["id"] = ["chaii" + str(row) for row in range(len(train))]
external["id"] = ["external" + str(row) for row in range(len(external))]

skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=CFG["SEED"])
fold_splits = skf.split(X=train, y=train['language'].values)
for fold, (_train_rows, val_idx) in enumerate(fold_splits):
    # val_idx holds row positions; they are used as labels here, which
    # relies on `train` having a 0..n-1 index.
    train.loc[val_idx, 'kfold'] = fold
train.head(2)

In [None]:
def prepare_training(examples):
    """Tokenize question/context pairs into training features with answer token labels.

    Long contexts are split into several overlapping features (stride
    ``CFG['DOC_STRIDE']``), each padded/truncated to ``CFG['MAX_LENGTH']``.
    For every feature the answer's start and end token indices are recorded;
    features that do not fully contain the answer (or whose example has no
    answer) are labelled with the index of the CLS token.

    Args:
        examples: DataFrame with 'question', 'context', 'answer_text' and
            'answer_start' columns. Rows are looked up with ``.loc`` using the
            tokenizer's sample indices, so the index must be 0..n-1. The
            'question' column is overwritten in place with left-stripped text.

    Returns:
        The tokenizer output with 'overflow_to_sample_mapping' and
        'offset_mapping' removed and 'start_positions' / 'end_positions'
        lists (one token index per feature) added.
    """
    examples['question'] = [q.lstrip() for q in examples['question']]  # remove leading white space

    # sequence_ids() value that marks context tokens (the second sequence when padding on the right).
    context_id = 1 if pad_on_right else 0

    # Tokenize with truncation and padding, but keep the overflows using a stride. One example
    # can therefore yield several features when its context is long, each feature's context
    # overlapping a bit with the previous one.
    tokenized_examples = tokenizer(
        list(examples['question' if pad_on_right else 'context'].values),
        list(examples['context' if pad_on_right else 'question'].values),
        truncation='only_second' if pad_on_right else 'only_first',
        max_length=CFG['MAX_LENGTH'],
        stride=CFG['DOC_STRIDE'],
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length'
    )

    # Maps each feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop('overflow_to_sample_mapping')

    # Maps each token to its character span in the original text; used to locate the answer.
    offset_mapping = tokenized_examples.pop('offset_mapping')

    tokenized_examples['start_positions'] = []
    tokenized_examples['end_positions'] = []

    for i, offsets in enumerate(offset_mapping):
        # Impossible answers are labelled with the index of the CLS token.
        input_ids = tokenized_examples['input_ids'][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans; this is the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples.loc[sample_index, 'answer_text']
        start_char = examples.loc[sample_index, 'answer_start']

        # Values read from a DataFrame are NaN (never None) when missing, so test with
        # pd.isna; an `is None` check would never fire and len(NaN) would raise below.
        if pd.isna(start_char) or pd.isna(answers):
            tokenized_examples['start_positions'].append(cls_index)
            tokenized_examples['end_positions'].append(cls_index)
            continue

        # Character index one past the end of the answer in the text.
        end_char = start_char + len(answers)

        # First token of the context within this feature.
        token_start_index = 0
        while sequence_ids[token_start_index] != context_id:
            token_start_index += 1

        # Last token of the context within this feature.
        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != context_id:
            token_end_index -= 1

        # Answer not fully inside this span: label the feature with the CLS index.
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            tokenized_examples['start_positions'].append(cls_index)
            tokenized_examples['end_positions'].append(cls_index)
            continue

        # Otherwise move token_start_index and token_end_index to the two ends of the answer.
        # Note: we could go past the last offset if the answer is the last word (edge case).
        while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        tokenized_examples['start_positions'].append(token_start_index - 1)

        while offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        tokenized_examples['end_positions'].append(token_end_index + 1)

    return tokenized_examples

In [None]:
def prepare_validation(examples):
    """Tokenize question/context pairs into validation features.

    Uses the same windowing as training (``CFG['MAX_LENGTH']`` tokens,
    ``CFG['DOC_STRIDE']`` overlap) but records, for each feature, the id of
    the example it came from, and blanks out the offsets of every token that
    is not part of the context.

    Args:
        examples: DataFrame with 'question', 'context' and 'id' columns. Rows
            are looked up with ``.loc`` using the tokenizer's sample indices.
            The 'question' column is overwritten in place with left-stripped
            text.

    Returns:
        The tokenizer output with 'overflow_to_sample_mapping' removed, an
        'example_id' list added, and each 'offset_mapping' entry set to None
        for tokens outside the context.
    """
    examples['question'] = [text.lstrip() for text in examples['question']]

    first_col = 'question' if pad_on_right else 'context'
    second_col = 'context' if pad_on_right else 'question'
    encoded = tokenizer(
        list(examples[first_col].values),
        list(examples[second_col].values),
        truncation='only_second' if pad_on_right else 'only_first',
        max_length=CFG['MAX_LENGTH'],
        stride=CFG['DOC_STRIDE'],
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    # feature index -> index of the example the feature was cut from
    sample_mapping = encoded.pop('overflow_to_sample_mapping')

    # id column from the dataset, one entry per feature
    encoded['example_id'] = []
    example_ids = encoded['example_id']

    # sequence_ids() value that marks context tokens
    context_id = 1 if pad_on_right else 0
    all_offsets = encoded['offset_mapping']

    for feat_idx, offsets in enumerate(all_offsets):
        seq_ids = encoded.sequence_ids(feat_idx)
        example_ids.append(examples.loc[sample_mapping[feat_idx], 'id'])
        # Keep character spans only for context tokens, so later code can
        # tell them apart from question/special/padding tokens.
        all_offsets[feat_idx] = [
            span if seq_ids[pos] == context_id else None
            for pos, span in enumerate(offsets)
        ]

    return encoded

In [None]:
def build_tf_dataset(df, batch_size=4, flag='train'):
    """Turn a DataFrame of QA examples into a batched ``tf.data.Dataset``.

    Args:
        df: DataFrame accepted by ``prepare_training`` (when ``flag`` is
            'train') or ``prepare_validation`` (otherwise).
        batch_size: Number of features per batch.
        flag: 'train' builds a shuffled dataset yielding ``(inputs, labels)``
            where the labels are one-hot start/end position vectors of length
            ``CFG['MAX_LENGTH']``. Any other value (normally 'valid') builds
            an unshuffled, inputs-only dataset.

    Returns:
        tuple: ``(dataset, features)`` where ``features`` is the tokenizer
        output returned by the preprocessing function.
    """
    is_train = flag == 'train'
    features = prepare_training(df) if is_train else prepare_validation(df)

    input_ids = features['input_ids']
    attn_masks = features['attention_mask']

    if not is_train:
        # Every non-train flag was already preprocessed as validation data,
        # so build the validation dataset rather than silently returning None.
        dataset = tf.data.Dataset.from_tensor_slices((input_ids, attn_masks))
        dataset = dataset.map(lambda x1, x2: ({'input_ids': x1, 'attention_mask': x2}))
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=AUTO)
        return dataset, features

    def one_hot(positions):
        # One-hot (rather than integer) labels let the loss apply label smoothing.
        encoded = np.zeros((len(positions), CFG['MAX_LENGTH']))
        encoded[np.arange(len(positions)), np.asarray(positions, dtype=np.int64)] = 1
        return encoded

    start_positions = one_hot(features['start_positions'])
    end_positions = one_hot(features['end_positions'])

    train_dataset = tf.data.Dataset.from_tensor_slices((input_ids, attn_masks, start_positions, end_positions))
    train_dataset = train_dataset.map(
        lambda x1, x2, y1, y2: (
            {'input_ids': x1, 'attention_mask': x2},
            {'start_positions': y1, 'end_positions': y2},
        )
    )
    # Shuffle individual features BEFORE batching. Shuffling after batching
    # only reorders whole batches, so every batch would hold the same
    # consecutive features (often overlapping windows of one context) in
    # every epoch.
    train_dataset = train_dataset.shuffle(max(1, len(input_ids)))
    train_dataset = train_dataset.batch(batch_size)
    train_dataset = train_dataset.prefetch(AUTO)

    return train_dataset, features

In [None]:
def build_model(num_steps):
    """Build and compile the span-prediction model on top of the pretrained encoder.

    The encoder's first output is fed to two identical heads (dropout, a
    1-unit dense layer, flatten, softmax over the sequence), named
    'start_positions' and 'end_positions'.

    Args:
        num_steps: Number of optimizer steps over which the learning rate
            decays linearly from ``CFG['LR']`` to 1e-6.

    Returns:
        The compiled ``tf.keras`` model taking 'input_ids' and
        'attention_mask' inputs.
    """
    seq_len = CFG["MAX_LENGTH"]
    backbone = TFAutoModel.from_pretrained(CFG['MODEL'])

    input_ids = tf.keras.layers.Input(shape=(seq_len, ), name='input_ids', dtype=tf.int32)
    attention_mask = tf.keras.layers.Input(shape=(seq_len, ), name='attention_mask', dtype=tf.int32)

    token_states = backbone(input_ids=input_ids, attention_mask=attention_mask)[0]

    def position_head(output_name):
        # Layers are created in the same order for each head so that
        # auto-generated layer names stay stable.
        hidden = tf.keras.layers.Dropout(0.1)(token_states)
        hidden = tf.keras.layers.Dense(1, dtype=tf.float32)(hidden)
        hidden = tf.keras.layers.Flatten()(hidden)
        return tf.keras.layers.Activation('softmax', name=output_name, dtype=tf.float32)(hidden)

    start_probs = position_head('start_positions')
    end_probs = position_head('end_positions')

    model = tf.keras.models.Model(
        inputs=[input_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )

    # Linear decay (power=1.0) of the learning rate over the whole run.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=CFG["LR"],
        decay_steps=num_steps,
        end_learning_rate=1e-6,
        power=1.0,
        cycle=False,
        name=None,
    )
    # Author's note: tfa AdamW (weight_decay=0.01) did very poorly here; SWA
    # (https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/SWA)
    # was another option that is left unused.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # The heads already output probabilities, hence from_logits=False.
    span_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0.1)
    model.compile(loss=[span_loss, span_loss], optimizer=optimizer)

    return model

In [None]:
def jaccard(str1, str2):
    """Return the word-level Jaccard similarity of two strings.

    Both strings are lower-cased and split on whitespace; the score is the
    size of the intersection of the two word sets divided by the size of
    their union.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        float: Similarity in [0.0, 1.0]. Two strings without any words are
        treated as identical and score 1.0.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union_size = len(a | b)
    if union_size == 0:
        # Both inputs are empty or whitespace-only; avoid dividing by zero.
        return 1.0
    return float(len(a & b)) / union_size

In [None]:
def post_process_predictions(examples, features, start, end, n_best_size=20, max_answer_length=30):
    """Convert per-feature start/end scores into one answer string per example.

    For each example, every feature cut from it is searched for the
    (start token, end token) pair with the highest summed score among the
    ``n_best_size`` best start and end candidates. Pairs that fall outside
    the context, run backwards, or span more than ``max_answer_length``
    tokens are ignored.

    Args:
        examples: DataFrame with 'id' and 'context' columns. Its index may be
            anything; features are matched to examples through 'id'.
        features: Mapping with 'example_id', 'input_ids' and 'offset_mapping'
            entries (one item per feature), where offsets of non-context
            tokens are None.
        start: Per-feature start scores, indexable as ``start[feature][token]``.
        end: Per-feature end scores, same layout as ``start``.
        n_best_size: How many top start and top end tokens to combine.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        collections.OrderedDict: example id -> predicted answer text, in the
        row order of ``examples``. An example with no valid candidate gets "".
    """
    # Group feature indices by the id of the example they were cut from.
    # Keying by id (not row position) keeps this correct for any DataFrame index.
    features_per_example = collections.defaultdict(list)
    for feature_index, example_id in enumerate(features['example_id']):
        features_per_example[example_id].append(feature_index)

    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features['input_ids'])} features.")

    for _, example in examples.iterrows():
        context = example['context']
        best_text = ""
        best_score = None

        for feature_index in features_per_example.get(example['id'], []):
            start_scores = start[feature_index]
            end_scores = end[feature_index]
            # Maps token positions back to character spans of the original context.
            offset_mapping = features['offset_mapping'][feature_index]
            n_tokens = len(offset_mapping)

            # Indices of the n_best_size highest scores, best first.
            start_indexes = np.argsort(start_scores)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_scores)[-1: -n_best_size - 1: -1].tolist()

            for start_index in start_indexes:
                # Skip starts that are out of bounds or not part of the context.
                if start_index >= n_tokens or offset_mapping[start_index] is None:
                    continue
                for end_index in end_indexes:
                    if end_index >= n_tokens or offset_mapping[end_index] is None:
                        continue
                    # Skip answers with negative length or longer than max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    score = start_scores[start_index] + end_scores[end_index]
                    # Strict '>' keeps the first candidate on ties, matching a stable sort.
                    if best_score is None or score > best_score:
                        best_score = score
                        start_char = offset_mapping[start_index][0]
                        end_char = offset_mapping[end_index][1]
                        best_text = context[start_char: end_char]

        predictions[example['id']] = best_text

    return predictions

In [None]:
class JaccardVal(tf.keras.callbacks.Callback):
    """Validate with word-level Jaccard during training and checkpoint the best weights.

    Every ``validate_steps`` batches the model predicts on the validation
    dataset, the predictions are post-processed into answer strings and
    scored against 'answer_text'. Overall, Hindi and Tamil mean scores are
    printed; the score selected by ``CFG['LANG_FOCUS']`` decides whether the
    weights are saved to ``fold{fold}/tf_model.h5``.

    Arguments:
        valid_df: Validation DataFrame with 'id', 'context', 'answer_text'
            and 'language' columns. 'pred' and 'scores' columns are
            (over)written on it at each validation.
        valid_dataset: Dataset passed to ``model.predict``.
        valid_enc: Validation features passed to ``post_process_predictions``.
        fold: Fold number; used in output paths and as key into ``best_scores``.
        epoch_steps: Number of training batches per epoch.
        best_scores: Dict fold -> best score so far; updated in place.
    """

    def __init__(self, valid_df, valid_dataset, valid_enc, fold, epoch_steps, best_scores):
        super().__init__()
        self.valid_df = valid_df
        self.valid_dataset = valid_dataset
        self.valid_enc = valid_enc
        self.epoch = 0
        self.fold = fold
        # At least 1, so the modulo in on_batch_end never divides by zero on very short epochs.
        self.validate_steps = max(1, int(epoch_steps * CFG["VALIDATE_EVERY"]))
        self.val_start_batch = epoch_steps * CFG["VAL_START_BATCH"]
        self.best_scores = best_scores

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch  # starts counting at 0, printed epoch adds 1

    def on_batch_end(self, batch, logs=None):
        """Run validation when ``batch`` (index within the current epoch) is due."""
        if self.epoch == 0 and batch < self.val_start_batch:
            # don't validate in first epoch until passed val_start_batch
            return
        if batch % self.validate_steps != 0:
            return

        # The dataset is already batched, so no batch_size is passed here.
        start_pred, end_pred = self.model.predict(self.valid_dataset, verbose=1)
        valid_preds = post_process_predictions(self.valid_df, self.valid_enc, start_pred, end_pred)

        # Use the frame handed to the callback, not a same-named global.
        valid_df = self.valid_df
        valid_df["pred"] = valid_df["id"].map(valid_preds)
        valid_df["scores"] = [
            jaccard(pred, truth)
            for pred, truth in valid_df[["pred", "answer_text"]].values
        ]
        valid_df[["id", "pred"]].to_csv(f"predictions_fold{self.fold}.csv", index=False)

        j_score = np.mean(valid_df["scores"])
        print(f'Jaccard Score after epoch {self.epoch}, batch {batch}: {j_score}')
        hi_score = np.mean(valid_df[valid_df["language"] == "hindi"]["scores"])
        print(f'HINDI - Jaccard Score after epoch {self.epoch}, batch {batch}: {hi_score}')
        ta_score = np.mean(valid_df[valid_df["language"] == "tamil"]["scores"])
        print(f'TAMIL - Jaccard Score after epoch {self.epoch}, batch {batch}: {ta_score}')

        # Pick the score that drives checkpointing.
        if CFG['LANG_FOCUS'] == "hindi":
            tracked_score, label = hi_score, "Hindi"
        elif CFG['LANG_FOCUS'] == "tamil":
            tracked_score, label = ta_score, "Tamil"
        else:
            tracked_score, label = j_score, "Overall"

        if tracked_score > self.best_scores[self.fold]:
            self.best_scores[self.fold] = tracked_score
            # Make sure the target directory exists before writing the weights.
            os.makedirs(f"fold{self.fold}", exist_ok=True)
            self.model.save_weights(f"fold{self.fold}" + "/tf_model.h5")
            print(f"New best {label} Jaccard Score in epoch {self.epoch}, batch {batch}. Saving model.")

In [None]:
# Best validation score per fold. Starting at MIN_SCORE_TO_SAVE means weights
# are only saved once a fold beats that threshold.
best_scores = {f: CFG['MIN_SCORE_TO_SAVE'] for f in range(CFG['N_FOLDS'])}

for fold in range(CFG['N_FOLDS']):
    output_dir = f"fold{fold}"
    # Plain Python instead of the IPython-only %mkdir magic; also safe on re-runs.
    os.makedirs(output_dir, exist_ok=True)
    print('#########' * 15)
    print(f"Fold: {fold}")
    print('#########' * 15)

    train_df = train[train['kfold'] != fold]
    valid_df = train[train['kfold'] == fold]

    # External data is only ever trained on, never validated on.
    train_df = pd.concat([train_df, external], axis=0, ignore_index=True)

    # The preprocessing functions look rows up by label, so they need a 0..n-1 index.
    valid_df = valid_df.reset_index(drop=True)
    print(train_df.shape, valid_df.shape)

    K.clear_session()

    # Re-resolve the accelerator for every fold (this re-initialises the TPU system).
    strategy, tpu_detected = auto_select_accelerator()

    train_dataset, train_enc = build_tf_dataset(train_df, batch_size=CFG["BATCH_SIZE"], flag='train')
    valid_dataset, valid_enc = build_tf_dataset(valid_df, batch_size=CFG["BATCH_SIZE"], flag='valid')

    epoch_steps = len(train_dataset)
    total_steps = epoch_steps * CFG['EPOCHS']

    with strategy.scope():
        model = build_model(total_steps)
        model.load_weights("../input/remb-pretrain-tpu-tf/tf_model.h5")

    if fold == 0:
        # summary() prints by itself and returns None, so it is not wrapped in print().
        model.summary()

    # train_dataset is already batched, so batch_size is not passed to fit().
    history = model.fit(
        train_dataset,
        epochs=CFG['EPOCHS'],
        callbacks=[
            JaccardVal(valid_df, valid_dataset, valid_enc, fold, epoch_steps, best_scores)
        ],
        verbose=1,
    )

    # Drop the per-fold objects before the next fold builds its own.
    del model
    del train_dataset
    del train_enc
    del valid_dataset
    del valid_enc
    del train_df
    del valid_df
    gc.collect()
    
# Report the best score each fold reached for the metric that drove checkpointing.
focus = CFG["LANG_FOCUS"] if CFG["LANG_FOCUS"] in {"hindi", "tamil"} else "Overall"

for fold, score in best_scores.items():
    # A value still (almost) equal to the threshold means this fold never
    # beat the minimum score to save.
    never_beaten = score - CFG["MIN_SCORE_TO_SAVE"] < 1e-5
    score = f'<{CFG["MIN_SCORE_TO_SAVE"]}' if never_beaten else round(score, 3)
    print(f"Best {focus} score - fold {fold}: {score}")