# RoBERTa using PyTorch
## This is a RoBERTa version of @abhishek's [BERT Base Uncased using PyTorch](https://www.kaggle.com/abhishek/bert-base-uncased-using-pytorch)

# All the important imports

In [None]:
# Version string handed to the PyTorch/XLA setup script via `--version` below
# (presumably a nightly build date - TODO confirm).
VERSION = 20200515
# Download the PyTorch/XLA environment setup script into the working directory.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the script for the chosen version, also requesting the apt packages libomp5 and libopenblas-dev.
!python pytorch-xla-env-setup.py --version $VERSION --apt-packages libomp5 libopenblas-dev

In [None]:
import os

# A `!export ...` line runs in a throwaway subshell, so the variable never reaches
# this kernel process (or the workers forked from it). Setting it on os.environ makes
# it visible to torch_xla, which is imported in a later cell.
os.environ["XLA_USE_BF16"] = "1"

In [None]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import lr_scheduler

from sklearn import model_selection
from sklearn import metrics
import transformers
import tokenizers
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup,get_cosine_with_hard_restarts_schedule_with_warmup
from tqdm.autonotebook import tqdm
import utils

In [None]:
# Directory holding the pretrained RoBERTa files; vocab.json and merges.txt are read from it
# below, and the model config/weights are loaded from it later in the file.
ROBERTA_PATH = "../input/roberta-base"
# Byte-level BPE tokenizer built from those two files. Text is lowercased and a
# prefix space is added before encoding.
TOKENIZER = tokenizers.ByteLevelBPETokenizer(
    vocab_file=f"{ROBERTA_PATH}/vocab.json", 
    merges_file=f"{ROBERTA_PATH}/merges.txt", 
    lowercase=True,
    add_prefix_space=True
)
# CSV with the training rows; `run` filters it on its `kfold` column.
TRAINING_FILE = "../input/tweet-train-folds-v2/train_folds.csv"
# Length every encoded example is padded up to in `process_data`.
MAX_LEN = 192
# Per-process batch sizes for the training and validation data loaders.
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 8
# Number of passes over the training split in `run`.
EPOCHS = 5

# Data Processing

In [None]:
def process_data(tweet, selected_text, sentiment, tokenizer, max_len):
    """
    Convert one (tweet, selected_text, sentiment) row into model features.

    Note: there are some differences between this and the BERT version
    (bert-base-uncased), mostly due to differences in token codes and special tokens.

    Args:
        tweet: full tweet text (converted with `str()` and whitespace-normalised).
        selected_text: the span of `tweet` to predict (same normalisation).
        sentiment: one of 'positive', 'negative', 'neutral'.
        tokenizer: object whose `encode(text)` result exposes `.ids` and `.offsets`.
        max_len: length the sequences are padded up to. Longer sequences are
            NOT truncated, only padding is applied.

    Returns:
        dict with keys 'ids', 'mask', 'token_type_ids', 'targets_start',
        'targets_end', 'orig_tweet', 'orig_selected', 'sentiment', 'offsets'.
        `targets_start` / `targets_end` are token indices into 'ids'.

    Raises:
        ValueError: if no token of the tweet overlaps the selected text
            (e.g. the selected text does not occur in the tweet).
        KeyError: if `sentiment` is not one of the three known labels.
    """
    # Collapse whitespace runs and add a leading space so that character offsets
    # line up with a tokenizer that encodes words together with their prefix space.
    tweet = " " + " ".join(str(tweet).split())
    selected_text = " " + " ".join(str(selected_text).split())

    # First occurrence of the selected text (without its leading space) in the tweet.
    needle = selected_text[1:]
    idx0 = tweet.find(needle) if needle else -1

    # Per-character flags: 1 for characters that belong to the selected span.
    char_targets = [0] * len(tweet)
    if idx0 >= 0:
        char_targets[idx0: idx0 + len(needle)] = [1] * len(needle)

    tok_tweet = tokenizer.encode(tweet)
    input_ids_orig = tok_tweet.ids
    tweet_offsets = tok_tweet.offsets

    # Tokens whose character range touches at least one selected character.
    target_idx = [
        j for j, (offset1, offset2) in enumerate(tweet_offsets)
        if any(char_targets[offset1: offset2])
    ]
    if not target_idx:
        raise ValueError(
            f"selected_text {selected_text!r} could not be aligned with any token of tweet {tweet!r}"
        )

    targets_start = target_idx[0]
    targets_end = target_idx[-1]

    # Token ids used to represent each sentiment label in the input sequence
    # (presumably the vocabulary ids of the sentiment words - TODO confirm against vocab.json).
    sentiment_id = {
        'positive': 1313,
        'negative': 2430,
        'neutral': 7974
    }

    # Layout: [0] sentiment [2] [2] tweet tokens [2]
    # (assumes the standard RoBERTa special ids 0=<s>, 2=</s>, 1=<pad> - TODO confirm).
    input_ids = [0] + [sentiment_id[sentiment]] + [2] + [2] + input_ids_orig + [2]
    token_type_ids = [0, 0, 0, 0] + [0] * (len(input_ids_orig) + 1)
    mask = [1] * len(token_type_ids)
    tweet_offsets = [(0, 0)] * 4 + tweet_offsets + [(0, 0)]
    # Four tokens were prepended, so the targets shift by the same amount.
    targets_start += 4
    targets_end += 4

    # Right-pad up to max_len (pad id 1, mask 0, empty offsets).
    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + ([1] * padding_length)
        mask = mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)
        tweet_offsets = tweet_offsets + ([(0, 0)] * padding_length)

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets_start': targets_start,
        'targets_end': targets_end,
        'orig_tweet': tweet,
        'orig_selected': selected_text,
        'sentiment': sentiment,
        'offsets': tweet_offsets
    }

# Data loader

In [None]:
class TweetDataset:
    """
    Indexable dataset that turns raw tweet rows into tensors for the model.

    Each item is produced on demand by `process_data`, using the module-level
    `TOKENIZER` and `MAX_LEN`.
    """
    def __init__(self, tweet, sentiment, selected_text):
        # Three parallel, index-aligned sequences (one entry per example).
        self.tweet = tweet
        self.sentiment = sentiment
        self.selected_text = selected_text
        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN

    def __len__(self):
        """Number of examples (length of the tweet sequence)."""
        return len(self.tweet)

    def __getitem__(self, item):
        """
        Return the features of example `item` as a dict of tensors and strings.

        Besides the index targets, one-hot float vectors of length `self.max_len`
        are returned under 'targets_start_oh' / 'targets_end_oh'.
        """
        data = process_data(
            self.tweet[item],
            self.selected_text[item],
            self.sentiment[item],
            self.tokenizer,
            self.max_len
        )
        # Size the one-hot targets from max_len (previously a hard-coded 192) so
        # they stay consistent with the padded sequence length if MAX_LEN changes.
        targets_start_oh = torch.zeros(self.max_len)
        targets_start_oh[data["targets_start"]] = 1

        targets_end_oh = torch.zeros(self.max_len)
        targets_end_oh[data["targets_end"]] = 1

        # Return the processed data where the lists are converted to `torch.tensor`s
        return {
            'ids': torch.tensor(data["ids"], dtype=torch.long),
            'mask': torch.tensor(data["mask"], dtype=torch.long),
            'token_type_ids': torch.tensor(data["token_type_ids"], dtype=torch.long),
            'targets_start': torch.tensor(data["targets_start"], dtype=torch.long),
            'targets_end': torch.tensor(data["targets_end"], dtype=torch.long),
            'orig_tweet': data["orig_tweet"],
            'orig_selected': data["orig_selected"],
            'sentiment': data["sentiment"],
            'offsets': torch.tensor(data["offsets"], dtype=torch.long),
            'targets_start_oh': targets_start_oh,
            'targets_end_oh': targets_end_oh
        }

# The Model

In [None]:
class TweetModel(transformers.BertPreTrainedModel):
    """
    Pretrained RoBERTa backbone followed by a linear layer that scores every
    token as a possible span start and span end.

    The config passed in must have `output_hidden_states = True`, because the
    forward pass reads the per-layer hidden states.
    """
    def __init__(self, conf):
        super(TweetModel, self).__init__(conf)
        # Load the pretrained RoBERTa backbone
        self.roberta = transformers.RobertaModel.from_pretrained(ROBERTA_PATH, config=conf)
        # Five independent 10% dropouts whose outputs are averaged in forward()
        self.dropouts = nn.ModuleList([nn.Dropout(0.1) for _ in range(5)])
        # Input width is twice the backbone's hidden size (taken from the config rather
        # than a hard-coded 768) because forward() concatenates the last two hidden layers.
        # The output has two channels: start logit and end logit.
        self.l0 = nn.Linear(conf.hidden_size * 2, 2)
        torch.nn.init.normal_(self.l0.weight, std=0.02)

    def forward(self, ids, mask, token_type_ids):
        """
        Args:
            ids, mask, token_type_ids: integer tensors, presumably (bs x SL) - TODO confirm.

        Returns:
            (start_logits, end_logits), each (bs x SL).
        """
        backbone_out = self.roberta(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids
        )
        # Third output = per-layer hidden states, each bs x SL x hidden.
        # Indexing (instead of unpacking exactly three values) also works when the
        # backbone returns a dict-like output object rather than a plain tuple.
        hidden_states = backbone_out[2]

        # Concatenate the last two hidden states
        # This is done since experiments have shown that just getting the last layer
        # gives out vectors that may be too tailored to the original pretraining objective
        # Sample explanation: https://bert-as-service.readthedocs.io/en/latest/section/faq.html#why-not-the-last-hidden-layer-why-second-to-last
        out = torch.cat((hidden_states[-1], hidden_states[-2]), dim=-1) # bs x SL x (hidden * 2)

        # Average of several independently dropped-out copies of the same features
        out_sum = None
        for dropout in self.dropouts:
            dropped = dropout(out)
            out_sum = dropped if out_sum is None else out_sum + dropped
        out = out_sum / len(self.dropouts)

        # The averaged hidden vectors are fed into the linear layer to output two scores
        logits = self.l0(out) # bs x SL x 2

        # Splits the tensor into start_logits and end_logits
        # (bs x SL x 2) -> (bs x SL x 1), (bs x SL x 1)
        start_logits, end_logits = logits.split(1, dim=-1)

        start_logits = start_logits.squeeze(-1) # (bs x SL)
        end_logits = end_logits.squeeze(-1) # (bs x SL)

        return start_logits, end_logits

# Loss Function

In [None]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    """
    Cross-entropy of the start logits against the start positions plus
    cross-entropy of the end logits against the end positions.
    """
    criterion = nn.CrossEntropyLoss()
    # The same criterion (default mean reduction) is applied to both heads.
    return criterion(start_logits, start_positions) + criterion(end_logits, end_positions)

In [None]:
def jac_transform(start_logits,end_logits,softmax=True,T=1):
    """
    Turn per-position start/end scores into a soft span mask.

    With `softmax=True` both inputs are first divided by `T` and softmax-ed over
    dim 1. The mask at position i is relu(cumsum(start)[i] - sum of end scores
    strictly before i).
    """
    start_scores, end_scores = start_logits, end_logits
    if softmax:
        start_scores = (start_logits / T).softmax(dim=1)
        end_scores = (end_logits / T).softmax(dim=1)
    # Exclusive cumulative sum of the end scores
    ended_before = end_scores.cumsum(dim=1) - end_scores
    return F.relu(start_scores.cumsum(dim=1) - ended_before)

In [None]:
def dice_loss(pred, target):
    """
    Soft Dice loss: 1 - (2 * sum(p * t) + s) / (sum(p * p) + sum(t * t) + s).

    This definition generalizes to real-valued pred and target tensors and is
    differentiable. Both tensors are flattened, so the loss is computed over
    the whole batch at once.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch (same number of elements as pred)
    """
    smooth = 1.

    # have to use contiguous since they may come from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    # Squared magnitude of each operand. The prediction term previously used
    # iflat * tflat (i.e. the intersection again), which made the denominator wrong.
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))

In [None]:
def jaccard_loss(output_start,output_end,target_start,target_end, smooth=1e-10):
    """
    Differentiable span-overlap loss computed from per-position start/end scores.

    Both the prediction and the target are turned into soft span masks
    (relu of cumulative start minus exclusive cumulative end, along dim 1).
    The loss is 1 minus [IoU of the two masks minus an enclosing-extent penalty
    (called `giou` in this function)], averaged over the batch.
    """
    # Soft mask of the predicted span
    pred_start_cum = torch.cumsum(output_start, dim=1)
    pred_end_cum = torch.cumsum(output_end, dim=1)
    pred_mask = F.relu(pred_start_cum - (pred_end_cum - output_end))

    # Soft mask of the target span
    true_start_cum = torch.cumsum(target_start, dim=1)
    true_end_cum = torch.cumsum(target_end, dim=1)
    true_mask = F.relu(true_start_cum - (true_end_cum - target_end))

    # Per-example intersection / union of the two masks
    overlap = (pred_mask * true_mask).sum(dim=1, keepdim=True)
    pred_area = pred_mask.sum(dim=1, keepdim=True)
    true_area = true_mask.sum(dim=1, keepdim=True)
    union = pred_area + true_area - overlap
    iou = (overlap + smooth) / (union + smooth)

    # Extent term built from the element-wise max of the start cumsums and the
    # element-wise min of the end cumsums
    upper = torch.max(pred_start_cum, true_start_cum)
    lower = F.relu(torch.min(pred_end_cum, true_end_cum) - output_end - target_end)
    extent = F.relu((upper - lower).sum(dim=1, keepdim=True))
    giou = iou - (extent - union + smooth) / (extent + smooth)

    return (1 - giou).mean()

In [None]:
def reduce_fn(vals):
    """Arithmetic mean of `vals`; used as the reducer for `xm.mesh_reduce`."""
    total = sum(vals)
    count = len(vals)
    return total / count

# Training Function

In [None]:
def train_fn(data_loader, model, optimizer, device,epoch,num_batches,scheduler=None):
    """
    Run one training pass of `model` over `data_loader`.

    Args:
        data_loader: iterable of batch dicts as produced by `TweetDataset`.
        model: module returning (start_logits, end_logits).
        optimizer: optimizer stepped through `xm.optimizer_step`.
        device: device the batch tensors are moved to.
        epoch: current epoch index (not used inside this function).
        num_batches: number of batches, used only as the progress-bar total.
        scheduler: optional LR scheduler, stepped once per batch when given.
    """
    # Set model to training mode (dropout is active)
    model.train()

    # Progress bar, shown only on the master process
    tk0 = tqdm(data_loader, total=num_batches, desc="Training", disable=not xm.is_master_ordinal())
    for d in tk0:
        # Move inputs and targets to the device with the dtypes the losses expect
        ids = d["ids"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        targets_start = d["targets_start"].to(device, dtype=torch.long)
        targets_end = d["targets_end"].to(device, dtype=torch.long)
        targets_start_oh = d["targets_start_oh"].to(device, dtype=torch.float32)
        targets_end_oh = d["targets_end_oh"].to(device, dtype=torch.float32)

        # Reset gradients
        model.zero_grad()
        # Predict logits for each of the input tokens
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        ) # (bs x SL), (bs x SL)

        # Cross-entropy on the raw logits ...
        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        # ... plus the span-overlap loss on the softmax-ed scores vs. one-hot targets
        outputs_start = torch.softmax(outputs_start, dim=1)
        outputs_end = torch.softmax(outputs_end, dim=1)
        loss2 = jaccard_loss(outputs_start, outputs_end, targets_start_oh, targets_end_oh)
        total_loss = 0.2*loss + loss2
        total_loss.backward()

        # Adjust weights based on calculated gradients (XLA-aware optimizer step)
        xm.optimizer_step(optimizer)
        # The scheduler is optional (defaults to None), so only step it when provided
        if scheduler is not None:
            scheduler.step()

        # Average both loss terms across processes for display
        print_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
        print_loss2 = xm.mesh_reduce('loss_reduce', loss2, reduce_fn)
        tk0.set_postfix(loss=print_loss.item(),loss2=print_loss2.item())

# Evaluation Functions

In [None]:
def calculate_jaccard_score(
    original_tweet, 
    target_string, 
    sentiment_val, 
    idx_start, 
    idx_end, 
    offsets,
    verbose=False):
    """
    Rebuild the predicted text from token indices and score it against the target.

    `sentiment_val` and `verbose` are accepted but not used by this function.

    Returns:
        (jaccard score from `utils.jaccard`, predicted text)
    """
    # A span's end index may not precede its start index;
    # if it does, collapse the span to the single start token.
    if idx_end < idx_start:
        idx_end = idx_start

    # Collect the text of every token in the predicted span
    pieces = []
    for ix in range(idx_start, idx_end + 1):
        tok_begin, tok_stop = offsets[ix][0], offsets[ix][1]
        pieces.append(original_tweet[tok_begin: tok_stop])
        # Insert a space when there is a following token and it starts after the
        # current one ends (i.e. there is a gap between the two tokens)
        has_next = (ix + 1) < len(offsets)
        if has_next and tok_stop < offsets[ix + 1][0]:
            pieces.append(" ")
    filtered_output = "".join(pieces)

    # Tweets with fewer than two words are returned whole
    if len(original_tweet.split()) < 2:
        filtered_output = original_tweet

    # Score the prediction against the target with the utils module's `jaccard` function:
    # https://www.kaggle.com/abhishek/utils
    jac = utils.jaccard(target_string.strip(), filtered_output.strip())
    return jac, filtered_output


def eval_fn(data_loader, model, device,num_batches,out_word=None):
    """
    Run `model` over `data_loader` without gradients and score the predicted spans.

    Args:
        data_loader: iterable of batch dicts as produced by `TweetDataset`.
        model: module returning (start_logits, end_logits).
        device: device the batch tensors are moved to.
        num_batches: number of batches, used only as the progress-bar total.
        out_word: when None (default) the average Jaccard score is returned and the
            batch loss is mesh-reduced for display; when not None, the list of
            predicted strings is returned instead and no mesh reduction happens.

    Returns:
        `jaccards.avg` (running average of per-batch mean Jaccard) when
        `out_word` is None, otherwise a list with one predicted string per example.
    """
    # Set model to evaluation mode (turns off dropout)
    # Reference: https://github.com/pytorch/pytorch/issues/5406
    model.eval()
    losses = utils.AverageMeter()
    jaccards = utils.AverageMeter()
    collect_text = out_word is not None
    outputs = []
    # Turns off gradient calculations
    with torch.no_grad():
        tk0 = tqdm(data_loader, total=num_batches, desc="evalutation", disable=not xm.is_master_ordinal())
        for d in tk0:
            sentiment = d["sentiment"]
            orig_selected = d["orig_selected"]
            orig_tweet = d["orig_tweet"]
            # Offsets are only used on the host to slice strings
            offsets = d["offsets"].cpu().numpy()

            # Move ids, masks, and targets to the device as torch.long
            ids = d["ids"].to(device, dtype=torch.long)
            token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            targets_start = d["targets_start"].to(device, dtype=torch.long)
            targets_end = d["targets_end"].to(device, dtype=torch.long)

            # Predict logits for start and end indexes
            outputs_start, outputs_end = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            # Calculate loss for the batch
            loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
            # Convert the logits to probability-like scores and bring them to the host
            outputs_start = torch.softmax(outputs_start, dim=1).cpu().detach().numpy()
            outputs_end = torch.softmax(outputs_end, dim=1).cpu().detach().numpy()

            # Calculate jaccard scores for each tweet in the batch
            jaccard_scores = []
            for px, tweet in enumerate(orig_tweet):
                jaccard_score, output = calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=orig_selected[px],
                    sentiment_val=sentiment[px],
                    idx_start=np.argmax(outputs_start[px, :]),
                    idx_end=np.argmax(outputs_end[px, :]),
                    offsets=offsets[px]
                )
                jaccard_scores.append(jaccard_score)
                if collect_text:
                    outputs.append(output)

            # Only reduce across processes in scoring mode; the text-collecting mode
            # is called outside the multi-process setup in this file
            if collect_text:
                print_loss = 'none'
            else:
                print_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)

            tk0.set_postfix(loss=print_loss)
            # Update running jaccard score and loss, weighted by batch size
            jaccards.update(np.mean(jaccard_scores), ids.size(0))
            losses.update(loss.item(), ids.size(0))
    if collect_text:
        return outputs
    return jaccards.avg

# Training

In [None]:
# model_config = transformers.RobertaConfig.from_pretrained(ROBERTA_PATH)
# # Output hidden states
# # This is important to set since we want to concatenate the hidden states from the last 2 BERT layers
# model_config.output_hidden_states = True
# # Instantiate our model with `model_config`

# MX = TweetModel(conf=model_config)
# MX

In [None]:
# MX.roberta.encoder.layer[0]

In [None]:
import torch_xla.distributed.parallel_loader as pl
# Load the RoBERTa configuration stored under ROBERTA_PATH
model_config = transformers.RobertaConfig.from_pretrained(ROBERTA_PATH)
# Output hidden states
# This is important to set since TweetModel.forward concatenates the last 2 hidden-state layers
model_config.output_hidden_states = True
# Instantiate our model with `model_config`

# Built once at module level; `run` and `final_evaluate_fn` both move this same instance to the XLA device
MX = TweetModel(conf=model_config)
# NOTE(review): SGD is not referenced anywhere in the visible code
from torch.optim import SGD
# Read training csv

def run(fold):
    """
    Train the shared model `MX` with `fold` held out for validation.

    Rows of TRAINING_FILE whose `kfold` differs from `fold` are used for training,
    the rest for validation. After every epoch the validation Jaccard score is
    averaged over all processes, and the weights are written to `model_{fold}.bin`
    whenever that score improves on the best one seen so far.

    Args:
        fold: value of the `kfold` column to hold out.

    Returns:
        The mesh-reduced validation Jaccard score of the last epoch.
    """
    dfx = pd.read_csv(TRAINING_FILE)
    # Set train validation set split
    df_train = dfx[dfx.kfold != fold].reset_index(drop=True)
    df_valid = dfx[dfx.kfold == fold].reset_index(drop=True)
    device = xm.xla_device()
    model = MX.to(device)

    # Training data: each process sees its own shard, reshuffled every epoch
    train_dataset = TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=2,
        sampler=train_sampler,
        drop_last=True
    )

    # Validation data: sharded the same way, without shuffling
    valid_dataset = TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        num_workers=2,
        sampler=valid_sampler
    )

    # Total number of optimizer steps per process, used by the LR schedule
    num_train_steps = int(
        len(df_train) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS
    )
    # Get the list of named parameters
    param_optimizer = list(model.named_parameters())
    # Specify parameters where weight decay shouldn't be applied
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    # Define two sets of parameters: those with weight decay, and those without
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    # AdamW with a base learning rate of 0.2 * 3e-5, scaled by the number of processes
    optimizer = AdamW(
        optimizer_parameters,
        lr=0.2*3e-5 * xm.xrt_world_size()
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=85,
        num_training_steps=num_train_steps
    )

    xm.master_print(f"Training is Starting for fold={fold}")
    # Progress-bar totals; the validation bar gets its own batch count
    num_batches = int(len(df_train) / (TRAIN_BATCH_SIZE * xm.xrt_world_size()))
    num_valid_batches = len(valid_data_loader)

    best_jac = 0
    for epoch in range(EPOCHS):
        # Without this the DistributedSampler replays the same shuffled order every epoch
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(para_loader.per_device_loader(device), model, optimizer,device, epoch,num_batches,scheduler=scheduler)
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        jac = eval_fn(
            para_loader.per_device_loader(device),
            model,
            device,
            num_valid_batches
        )
        # Every process receives the same averaged score, so they all take the same branch below
        jac = xm.mesh_reduce('jac_reduce', jac, reduce_fn)
        if jac > best_jac:
            # Remember the best score so that a worse later epoch cannot overwrite the checkpoint
            best_jac = jac
            xm.save(model.state_dict(),f'model_{fold}.bin')
        xm.master_print(f"Jaccard Score = {jac}")
    return jac

In [None]:
def _mp_fn(rank, flags,fold):
    """
    Per-process entry point for `xmp.spawn`: trains the given fold.

    `rank` and `flags` are accepted to match the spawn call but are not used;
    the value returned by `run` is discarded.
    """
    # Make float32 CPU tensors the default tensor type in this process
    torch.set_default_tensor_type('torch.FloatTensor')
    run(fold)

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Empty flags dict; forwarded to _mp_fn as its `flags` argument
FLAGS={}
# Start 8 forked processes running _mp_fn with fold=0 (fold 0 is held out for validation)
xmp.spawn(_mp_fn, args=(FLAGS,0), nprocs=8, start_method='fork')

In [None]:
# Empty flags dict; forwarded to _mp_fn as its `flags` argument
FLAGS={}
# Start 8 forked processes running _mp_fn with fold=1 (fold 1 is held out for validation)
xmp.spawn(_mp_fn, args=(FLAGS,1), nprocs=8, start_method='fork')

In [None]:
# Empty flags dict; forwarded to _mp_fn as its `flags` argument
FLAGS={}
# Start 8 forked processes running _mp_fn with fold=2 (fold 2 is held out for validation)
xmp.spawn(_mp_fn, args=(FLAGS,2), nprocs=8, start_method='fork')

In [None]:
# Empty flags dict; forwarded to _mp_fn as its `flags` argument
FLAGS={}
# Start 8 forked processes running _mp_fn with fold=3 (fold 3 is held out for validation)
xmp.spawn(_mp_fn, args=(FLAGS,3), nprocs=8, start_method='fork')

In [None]:
# Empty flags dict; forwarded to _mp_fn as its `flags` argument
FLAGS={}
# Start 8 forked processes running _mp_fn with fold=4 (fold 4 is held out for validation)
xmp.spawn(_mp_fn, args=(FLAGS,4), nprocs=8, start_method='fork')

In [None]:
def final_evaluate_fn(fold):
    """
    Predict the selected text for every validation row of `fold` and save it.

    Loads the weights from `model_{fold}.bin` into the shared model, runs
    `eval_fn` in its text-returning mode over the rows whose `kfold` equals
    `fold`, stores the predictions in a `selected_text_out` column and writes
    the frame to `fold_{fold}.csv`.
    """
    # Shared model on the XLA device, with this fold's saved weights
    device = xm.xla_device()
    model = MX.to(device)
    model.load_state_dict(torch.load(f'model_{fold}.bin'))

    # Validation rows for this fold
    all_rows = pd.read_csv(TRAINING_FILE)
    df_valid = all_rows[all_rows.kfold == fold].reset_index(drop=True)

    # Plain (non-distributed) loader, so predictions come back in row order
    valid_data_loader = torch.utils.data.DataLoader(
        TweetDataset(
            tweet=df_valid.text.values,
            sentiment=df_valid.sentiment.values,
            selected_text=df_valid.selected_text.values
        ),
        batch_size=VALID_BATCH_SIZE,
        num_workers=1,
    )

    # out_word=True makes eval_fn return the predicted strings
    predictions = eval_fn(
        valid_data_loader,
        model,
        device,
        len(valid_data_loader),
        out_word=True
    )
    df_valid.loc[:, 'selected_text_out'] = predictions
    df_valid.to_csv(f'fold_{fold}.csv')
    

In [None]:
# Write the out-of-fold predictions for each of the five folds (fold_0.csv ... fold_4.csv)
for i in range(5):
    final_evaluate_fn(i)

In [None]:
print('done')