In [None]:
# Download the PyTorch/XLA env-setup script, then run it with `--version nightly` and the apt packages listed below.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
%%writefile Model.py
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F

class TweetModel(transformers.BertPreTrainedModel):
    """RoBERTa encoder with a linear head that scores each token as span start / span end.

    The head consumes the concatenation of the last two entries of the encoder's
    per-layer hidden states, so the config passed in must have
    ``output_hidden_states = True``.
    """

    def __init__(self, model_path, conf):
        """
        Args:
            model_path: Name or path handed to ``RobertaModel.from_pretrained``.
            conf: Model config; also forwarded to the pretrained-model base class.
        """
        super(TweetModel, self).__init__(conf)
        self.roberta = transformers.RobertaModel.from_pretrained(model_path, config=conf)
        self.drop_out = nn.Dropout(0.1)
        # Two hidden-state layers are concatenated, so the head input is twice the
        # encoder width. Reading it from the config (768 for roberta-base) keeps the
        # head valid for other encoder sizes.
        self.l0 = nn.Linear(conf.hidden_size * 2, 2)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """Return ``(start_logits, end_logits)``, each with the trailing size-1 axis squeezed away."""
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )

        # Index instead of tuple-unpacking: positional indexing works both for plain
        # tuple returns and for dict-like ModelOutput returns, whereas unpacking a
        # ModelOutput would yield its keys. Position 2 holds the per-layer hidden
        # states (after the sequence output and the pooled output).
        hidden_states = outputs[2]

        out = torch.cat((hidden_states[-1], hidden_states[-2]), dim=-1)
        out = self.drop_out(out)
        logits = self.l0(out)

        start_logits, end_logits = logits.split(1, dim=-1)

        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        return start_logits, end_logits

In [None]:
%%writefile dataset.py

import torch


def find_start_and_end(tweet, selected_text):
    """Locate the first occurrence of ``selected_text`` inside ``tweet``.

    Args:
        tweet: The full text to search in.
        selected_text: The substring to look for.

    Returns:
        ``(start, end)`` as inclusive character indices of the first match, or
        ``(None, None)`` when ``selected_text`` is empty or does not occur.
    """
    # An empty needle has no meaningful span (and used to raise IndexError).
    if not selected_text:
        return None, None

    start = tweet.find(selected_text)
    if start == -1:
        return None, None
    return start, start + len(selected_text) - 1

def process_with_offsets(args, tweet, selected_text, sentiment, tokenizer):
    """Encode a (sentiment, tweet) pair and attach token-level span labels.

    Args:
        args: Object exposing ``max_seq_len`` (the padded sequence length).
        tweet: Full tweet text.
        selected_text: Target substring of ``tweet``.
        sentiment: Sentiment string, encoded as the first segment of the pair.
        tokenizer: Tokenizer whose ``encode_plus`` can return an offset mapping.

    Returns:
        The tokenizer's encoding extended with ``start_position``, ``end_position``,
        ``tweet``, ``selected_text`` and ``sentiment``. Both positions are ``-1``
        when no token overlaps the selected text.
    """
    start_index, end_index = find_start_and_end(tweet, selected_text)

    # Character-level mask: 1 for every character covered by the selected text.
    char_targets = [0] * len(tweet)
    if start_index is not None and end_index is not None:
        for ct in range(start_index, end_index + 1):
            char_targets[ct] = 1

    encoded = tokenizer.encode_plus(
                    sentiment,
                    tweet,
                    max_length=args.max_seq_len,
                    pad_to_max_length=True,
                    return_token_type_ids=True,
                    return_offsets_mapping=True
                )

    # Tokens at index 0-3 are skipped; presumably they are the
    # `<s> sentiment </s></s>` prefix of the pair encoding, whose offsets refer to
    # the sentiment string rather than the tweet -- TODO confirm.
    target_idx = [
        j
        for j, (offset1, offset2) in enumerate(encoded["offset_mapping"])
        if j > 3 and sum(char_targets[offset1:offset2]) > 0
    ]

    if target_idx:
        encoded["start_position"] = target_idx[0]
        encoded["end_position"] = target_idx[-1]
    else:
        # No token overlaps the span (text not found, or no overlapping token in the
        # encoding). Indexing target_idx here used to raise IndexError; -1 matches the
        # ignore_index=-1 used by loss_fn in main_tpu.py, so such rows are skipped.
        encoded["start_position"] = -1
        encoded["end_position"] = -1

    encoded["tweet"] = tweet
    encoded["selected_text"] = selected_text
    encoded["sentiment"] = sentiment

    return encoded


class TweetDataset:
    """Map-style dataset yielding encoded (sentiment, tweet) pairs with span labels.

    In ``"train"`` mode the rows whose ``kfold`` equals ``fold`` are excluded; in
    ``"valid"`` mode only those rows are kept. Rows containing missing values are
    dropped in both modes.
    """

    def __init__(self, args, tokenizer, df, mode="train", fold=0):
        """
        Args:
            args: Options object forwarded to ``process_with_offsets``.
            tokenizer: Tokenizer forwarded to ``process_with_offsets``.
            df: DataFrame with ``text``, ``sentiment``, ``selected_text`` and ``kfold`` columns.
            mode: ``"train"`` or ``"valid"``.
            fold: Fold id that is held out for validation.

        Raises:
            ValueError: If ``mode`` is neither ``"train"`` nor ``"valid"``.
        """
        # Fail fast: any other mode used to leave the object without data and only
        # surfaced later as an AttributeError in __len__/__getitem__.
        if mode not in ("train", "valid"):
            raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

        self.mode = mode

        in_fold = df.kfold.isin([fold])
        rows = df[in_fold] if mode == "valid" else df[~in_fold]
        rows = rows.dropna()

        self.tweet = rows.text.values
        self.sentiment = rows.sentiment.values
        self.selected_text = rows.selected_text.values

        self.tokenizer = tokenizer
        self.args = args

    def __len__(self):
        """Number of rows kept for this mode/fold."""
        return len(self.tweet)

    def __getitem__(self, item):
        """Encode row ``item`` and return tensors plus the raw strings needed for scoring."""

        tweet = str(self.tweet[item])
        selected_text = str(self.selected_text[item])
        sentiment = str(self.sentiment[item])

        features = process_with_offsets(
                        args=self.args,
                        tweet=tweet,
                        selected_text=selected_text,
                        sentiment=sentiment,
                        tokenizer=self.tokenizer
                    )

        return {
            "input_ids":torch.tensor(features["input_ids"], dtype=torch.long),
            "token_type_ids":torch.tensor(features["token_type_ids"], dtype=torch.long),
            "attention_mask":torch.tensor(features["attention_mask"], dtype=torch.long),
            "start_position":torch.tensor(features["start_position"],dtype=torch.long),
            "end_position":torch.tensor(features["end_position"], dtype=torch.long),

            "offsets":features["offset_mapping"],
            "tweet":features["tweet"],
            "selected_text":features["selected_text"],
            "sentiment":features["sentiment"]
        }

In [None]:
%%writefile metric.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings (case-insensitive, whitespace-split).

    Returns 0.5 when neither string contains any word.
    """
    words_a = set(str1.lower().split())
    words_b = set(str2.lower().split())
    if not words_a and not words_b:
        return 0.5
    shared = words_a & words_b
    union_size = len(words_a) + len(words_b) - len(shared)
    return float(len(shared)) / union_size


def to_list(tensor):
    """Return the tensor's values as (nested) Python lists, detached and copied to the CPU."""
    host_copy = tensor.detach().cpu()
    return host_copy.tolist()

def calculate_jaccard_score(features_dict, start_logits, end_logits, tokenizer):
    """Mean word-level Jaccard between predicted and target spans for one batch.

    The predicted span is decoded from the input ids between the arg-max start and
    end positions. For rows whose sentiment is "neutral", or whose tweet has fewer
    than two words, the whole tweet is used as the prediction.
    """
    token_ids_batch = to_list(features_dict["input_ids"])
    tweets = features_dict["tweet"]
    targets = features_dict["selected_text"]
    sentiments = features_dict["sentiment"]
    offsets_batch = features_dict["offsets"]

    start_probs = F.softmax(start_logits, dim=1).cpu().data.numpy()
    end_probs = F.softmax(end_logits, dim=1).cpu().data.numpy()
    pred_starts = np.argmax(start_probs, axis=1)
    pred_ends = np.argmax(end_probs, axis=1)

    scores = []
    for row in range(len(tweets)):
        first = pred_starts[row]
        # An end that precedes the start collapses to a single-token span.
        last = max(pred_ends[row], first)
        # Offsets are looked up but are not needed for decoding.
        row_offsets = offsets_batch[row]
        row_ids = token_ids_batch[row]
        full_tweet = tweets[row]
        target = targets[row]

        prediction = tokenizer.decode(row_ids[first:last+1], skip_special_tokens=True)

        use_full_tweet = sentiments[row] == "neutral" or len(full_tweet.split()) < 2
        if use_full_tweet:
            prediction = full_tweet

        scores.append(jaccard(target.strip(), prediction.strip()))

    return np.mean(scores)

In [None]:
%%writefile utils.py

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if the monitored validation value doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, mode="min"):
        """
        Args:
            patience (int): How long to wait after last time the monitored value improved.
                            Default: 7
            verbose (bool): If True, prints a message for each improvement that saves a checkpoint.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            mode (str): "min" when lower values are better (e.g. a loss), "max" when
                            higher values are better (e.g. a Jaccard score).
                            Default: "min"

        Raises:
            ValueError: If ``mode`` is neither "min" nor "max".
        """
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.mode = mode
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf if mode == "min" else -np.inf
        self.delta = delta

    def __call__(self, val_loss, model=None):
        """Record one validation result; sets ``early_stop`` once patience runs out.

        Args:
            val_loss: Monitored value for this epoch.
            model: Optional model; when given, it is checkpointed on every improvement.
        """
        # Internally a larger score is always better.
        score = -val_loss if self.mode == "min" else val_loss

        if self.best_score is None:
            self.best_score = score
            # Explicit None check: module truthiness can be False (e.g. an empty container).
            if model is not None:
                self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            if model is not None:
                self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model's state dict to checkpoint.pt and remembers the value that triggered it.'''
        if self.verbose:
            if self.mode == "min":
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            else:
                print(f'Validation score improved ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

In [None]:
%%writefile main_tpu.py

import argparse
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset
import random
import re
import json
import transformers
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

from Model import TweetModel
from dataset import TweetDataset
from metric import calculate_jaccard_score
import utils


import warnings
warnings.filterwarnings('ignore')


def to_list(tensor):
    """Convert a tensor to plain Python lists after detaching it and moving it to the CPU."""
    detached = tensor.detach()
    return detached.cpu().tolist()

class AverageMeter(object):
    """Keeps the most recent value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the latest value, the average, the weighted sum and the weight count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def get_position_accuracy(logits, labels):
    """Fraction of rows whose arg-max prediction equals the label, skipping negative labels.

    Returns:
        ``(accuracy, counted)`` where ``counted`` is the number of rows with a
        non-negative label, or ``1e-7`` when there are none.
    """
    probabilities = F.softmax(logits, dim=1).cpu().data.numpy()
    predicted = np.argmax(probabilities, axis=1)
    gold = labels.cpu().data.numpy()

    # One hit/miss flag per row that carries a usable (non-negative) label.
    hits = [predicted[i] == gold[i] for i in range(len(gold)) if gold[i] >= 0]

    total_num = len(hits)
    sum_correct = sum(1 for hit in hits if hit)
    if total_num == 0:
        total_num = 1e-7
    return np.float32(sum_correct) / total_num, total_num

def reduce_fn(vals):
    """Arithmetic mean of ``vals``; passed as the reduction to ``xm.mesh_reduce`` in this file."""
    total = sum(vals)
    return total / len(vals)

def loss_fn(preds, labels):
    """Cross-entropy for the start and end heads; targets equal to -1 are ignored.

    Args:
        preds: ``(start_logits, end_logits)``.
        labels: ``(start_targets, end_targets)``.

    Returns:
        ``(start_loss, end_loss)``.
    """
    (start_logits, end_logits), (start_targets, end_targets) = preds, labels
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    return criterion(start_logits, start_targets), criterion(end_logits, end_targets)


def train(args, train_loader, model, device, optimizer,scheduler, epoch, f):
    """Run one training epoch and return the running average of the summed start/end loss.

    Per step: forward pass, loss, backward pass, XLA optimizer step, scheduler step,
    then a mesh-reduced loss/accuracy summary on the progress bar. At the end the
    epoch averages are written to the open log file ``f``.
    """
    loss_meter = AverageMeter()
    start_loss_meter = AverageMeter()
    end_loss_meter = AverageMeter()
    start_acc_meter = AverageMeter()
    end_acc_meter = AverageMeter()

    model.train()

    # Only the master ordinal renders the progress bar.
    progress = tqdm(train_loader, disable=not xm.is_master_ordinal())
    for step, batch in enumerate(progress):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        start_position = batch["start_position"].to(device)
        end_position = batch["end_position"].to(device)

        model.zero_grad()

        start_logits, end_logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None
        )

        start_loss, end_loss = loss_fn(
            (start_logits, end_logits),
            (start_position, end_position)
        )
        loss = start_loss + end_loss

        start_acc, n_start = get_position_accuracy(start_logits, start_position)
        end_acc, n_end = get_position_accuracy(end_logits, end_position)

        # The combined loss is weighted by the number of counted start positions.
        loss_meter.update(loss.item(), n_start)
        start_loss_meter.update(start_loss.item(), n_start)
        end_loss_meter.update(end_loss.item(), n_end)
        start_acc_meter.update(start_acc, n_start)
        end_acc_meter.update(end_acc, n_end)

        loss.backward()
        xm.optimizer_step(optimizer)
        scheduler.step()

        print_loss = xm.mesh_reduce("loss_reduce", loss_meter.avg, reduce_fn)
        print_acc1 = xm.mesh_reduce("acc1_reduce", start_acc_meter.avg, reduce_fn)
        print_acc2 = xm.mesh_reduce("acc2_reduce", end_acc_meter.avg, reduce_fn)
        progress.set_description(f"Train E:{epoch+1} - Loss:{print_loss:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}")

    log_ = f"Epoch : {epoch+1} - train_loss : {loss_meter.avg} - \n \
    train_loss1 : {start_loss_meter.avg} - train_loss2 : {end_loss_meter.avg} - \n \
    train_acc1 : {start_acc_meter.avg} - train_acc2 : {end_acc_meter.avg}"

    f.write(log_ + "\n\n")
    f.flush()

    return loss_meter.avg

def valid(args, valid_loader, model, device, tokenizer, epoch, f):
    """Evaluate the model for one epoch and return the average Jaccard score.

    Runs under ``torch.no_grad()`` with the model in eval mode. Loss and accuracy
    are tracked for the log file ``f``; the returned value is the running Jaccard
    average, weighted by batch size.
    """
    total_loss = AverageMeter()
    losses1 = AverageMeter() # start
    losses2 = AverageMeter() # end
    accuracies1 = AverageMeter() # start
    accuracies2 = AverageMeter() # end

    jaccard_scores = AverageMeter()

    model.eval()

    with torch.no_grad():
        t = tqdm(valid_loader, disable=not xm.is_master_ordinal())
        for step, d in enumerate(t):

            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            token_type_ids = d["token_type_ids"].to(device)
            start_position = d["start_position"].to(device)
            end_position = d["end_position"].to(device)

            logits1, logits2 = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=None,
                head_mask=None
            )

            loss1, loss2 = loss_fn((logits1, logits2), (start_position, end_position))
            loss = loss1 + loss2

            acc1, n_position1 = get_position_accuracy(logits1, start_position)
            acc2, n_position2 = get_position_accuracy(logits2, end_position)

            total_loss.update(loss.item(), n_position1)
            losses1.update(loss1.item(), n_position1)
            losses2.update(loss2.item(), n_position2)
            accuracies1.update(acc1, n_position1)
            accuracies2.update(acc2, n_position2)

            jac_score = calculate_jaccard_score(features_dict=d, start_logits=logits1, end_logits=logits2, tokenizer=tokenizer)

            # jac_score is a per-batch mean; weight it by the batch size so a smaller
            # final batch does not count as much as a full one.
            jaccard_scores.update(jac_score, len(d["tweet"]))

            print_loss = xm.mesh_reduce("vloss_reduce", total_loss.avg, reduce_fn)
            print_jac = xm.mesh_reduce("jac_reduce", jaccard_scores.avg, reduce_fn)

            t.set_description(f"Eval E:{epoch+1} - Loss:{print_loss:0.2f} - Jac:{print_jac:0.2f}")

    # The previous literal contained "\valid_loss2" / "\valid_acc2": Python reads "\v"
    # as a vertical-tab escape, so the log showed a control character followed by
    # "alid_loss2". Adjacent f-strings avoid backslash continuations altogether.
    log_ = (
        f"Epoch : {epoch+1} - valid_loss : {total_loss.avg} - \n"
        f"    valid_loss1 : {losses1.avg} - valid_loss2 : {losses2.avg} - \n"
        f"    valid_acc1 : {accuracies1.avg} - valid_acc2 : {accuracies2.avg} "
    )

    f.write(log_ + "\n\n")
    f.flush()

    return jaccard_scores.avg

def main():
    """Parse CLI options, build model and tokenizer, and train one fold on 8 forked XLA processes.

    The model, the fold DataFrame and the log file are created once in the parent
    process; the worker processes are started with the ``fork`` start method and
    use them through the enclosing scope. The best model (by mesh-reduced Jaccard)
    is saved to ``<output_dir>/<exp_name>/fold_<fold_index>``.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--max_seq_len", type=int, default=192)
    parser.add_argument("--fold_index", type=int, default=0)
    parser.add_argument("--learning_rate", type=float, default=0.00002)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--model_path", type=str, default="roberta-base")
    parser.add_argument("--output_dir", type=str, default="")
    parser.add_argument("--exp_name", type=str, default="")
    parser.add_argument("--spt_path", type=str, default="")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Setting seed
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    model_path = args.model_path
    config = transformers.RobertaConfig.from_pretrained(model_path)
    # TweetModel reads the per-layer hidden states, so they must be returned.
    config.output_hidden_states = True
    tokenizer = transformers.RobertaTokenizerFast.from_pretrained(model_path, do_lower_case=True)

    MX = TweetModel(model_path, config)

    train_df = pd.read_csv("../input/tweet-create-folds/train_5folds.csv")

    args.save_path = os.path.join(args.output_dir, args.exp_name)

    # With both --output_dir and --exp_name left empty the save path is "", which
    # os.makedirs rejects; in that case everything goes to the working directory.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    f = open(os.path.join(args.save_path, f"log_f_{args.fold_index}.txt"), "w")

    # Approximate size of the training split: 4 of the 5 folds.
    num_train_dpoints = int((len(train_df)/5) * 4)

    def run():
        """Per-process training loop: build loaders and optimizer, then train/validate each epoch."""

        torch.manual_seed(seed)

        device = xm.xla_device()
        model = MX.to(device)

        # DataLoaders
        train_dataset = TweetDataset(
            args=args,
            df=train_df,
            mode="train",
            fold=args.fold_index,
            tokenizer=tokenizer
        )
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            drop_last=False,
            num_workers=2
        )

        valid_dataset = TweetDataset(
            args=args,
            df=train_df,
            mode="valid",
            fold=args.fold_index,
            tokenizer=tokenizer
        )
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False
        )
        valid_loader = DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            sampler=valid_sampler,
            num_workers=1,
            drop_last=False
        )

        # Parameters whose name contains one of these substrings get no weight decay.
        param_optimizer = list(model.named_parameters())
        no_decay = [
            "bias",
            "LayerNorm.bias",
            "LayerNorm.weight"
        ]
        optimizer_parameters = [
            {
                'params': [
                    p for n, p in param_optimizer if not any(
                        nd in n for nd in no_decay
                    )
                ],
                'weight_decay': 0.001
            },
            {
                'params': [
                    p for n, p in param_optimizer if any(
                        nd in n for nd in no_decay
                    )
                ],
                'weight_decay': 0.0
            },
        ]

        # Scheduler steps per process: samples / batch size / number of processes, times epochs.
        num_train_steps = int(
            num_train_dpoints / args.batch_size / xm.xrt_world_size() * args.epochs
        )

        # The learning rate is scaled by the number of participating processes.
        optimizer = AdamW(
            optimizer_parameters,
            lr=args.learning_rate * xm.xrt_world_size()
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=num_train_steps
        )

        xm.master_print("Training is Starting ...... ")
        best_jac = 0
        #early_stopping = utils.EarlyStopping(patience=2, mode="max", verbose=True)

        for epoch in range(args.epochs):
            para_loader = pl.ParallelLoader(train_loader, [device])
            train_loss = train(
                args,
                para_loader.per_device_loader(device),
                model,
                device,
                optimizer,
                scheduler,
                epoch,
                f
            )

            para_loader = pl.ParallelLoader(valid_loader, [device])
            valid_jac = valid(
                args,
                para_loader.per_device_loader(device),
                model,
                device,
                tokenizer,
                epoch,
                f
            )

            jac = xm.mesh_reduce("jac_reduce", valid_jac, reduce_fn)
            xm.master_print(f"**** Epoch {epoch+1} **==>** Jaccard = {jac}")

            log_ = f"**** Epoch {epoch+1} **==>** Jaccard = {jac}"

            f.write(log_ + "\n\n")
            # Flush right away so the summary (in particular the one for the last
            # epoch) is not left sitting in this process's file buffer.
            f.flush()

            if jac > best_jac:
                xm.master_print("**** Model Improved !!!! Saving Model")
                xm.save(model.state_dict(), os.path.join(args.save_path, f"fold_{args.fold_index}"))
                best_jac = jac

            #early_stopping(jac)

            #if early_stopping.early_stop:
            #    print("Early stopping")
            #    break


    def _mp_fn(rank, flags):
        """Entry point executed in every spawned process."""
        torch.set_default_tensor_type('torch.FloatTensor')
        run()

    FLAGS={}
    try:
        xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
    finally:
        # Release the parent's handle on the log file even if training fails.
        f.close()

if __name__ == "__main__":
    main()

In [None]:
# Train fold 0: rows with kfold == 0 form the validation split, all other rows the training split.
!python main_tpu.py --fold_index=0 \
                  --model_path="roberta-base" \
                  --output_dir="roberta-base" \
                  --exp_name="base_seed42" \
                  --batch_size=64 \
                  --learning_rate=2e-5 \
                  --seed=42

In [None]:
# Train fold 1: rows with kfold == 1 form the validation split, all other rows the training split.
!python main_tpu.py --fold_index=1 \
                  --model_path="roberta-base" \
                  --output_dir="roberta-base" \
                  --exp_name="base_seed42" \
                  --batch_size=64 \
                  --learning_rate=2e-5 \
                  --seed=42

In [None]:
# Train fold 2: rows with kfold == 2 form the validation split, all other rows the training split.
!python main_tpu.py --fold_index=2 \
                  --model_path="roberta-base" \
                  --output_dir="roberta-base" \
                  --exp_name="base_seed42" \
                  --batch_size=64 \
                  --learning_rate=2e-5 \
                  --seed=42

In [None]:
# Train fold 3: rows with kfold == 3 form the validation split, all other rows the training split.
!python main_tpu.py --fold_index=3 \
                  --model_path="roberta-base" \
                  --output_dir="roberta-base" \
                  --exp_name="base_seed42" \
                  --batch_size=64 \
                  --learning_rate=2e-5 \
                  --seed=42

In [None]:
# Train fold 4: rows with kfold == 4 form the validation split, all other rows the training split.
!python main_tpu.py --fold_index=4 \
                  --model_path="roberta-base" \
                  --output_dir="roberta-base" \
                  --exp_name="base_seed42" \
                  --batch_size=64 \
                  --learning_rate=2e-5 \
                  --seed=42

## If you like it, upvote it. Thank you 🙂