In [None]:
import tokenizers
import transformers
import string
from tqdm.autonotebook import tqdm
from transformers import BertTokenizer,BertConfig,RobertaConfig
import re,gc

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.model_selection import train_test_split


import os
# Print the full path of every file shipped under the Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# Any results you write to the current directory are saved as output.

In [None]:
# Location of the competition CSV files.
data_path = r"/kaggle/input/tweet-sentiment-extraction/"

# Training rows with any missing field are discarded; the test set is used as-is.
train_data = pd.read_csv(os.path.join(data_path, 'train.csv')).dropna()
test_data = pd.read_csv(os.path.join(data_path, 'test.csv'))

In [None]:
# Trim leading/trailing whitespace from every tweet in both splits.
for tweets_df in (train_data, test_data):
    tweets_df['text'] = tweets_df['text'].apply(lambda tweet: tweet.strip())

In [None]:
# Directory holding the pretrained RoBERTa files (config.json, vocab.json, merges.txt).
model_path = r"/kaggle/input/robertabase/"
cf = transformers.RobertaConfig.from_json_file(os.path.join(model_path, 'config.json'))

# Training hyper-parameters.
MAX_LEN = 192             # fixed token length of every encoded example
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 8
EPOCHS = 3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths are built with os.path.join (as for config.json above) so the trailing
# slash in model_path does not produce ".../robertabase//vocab.json".
# NOTE(review): lowercase=True lower-cases the input before BPE; confirm this is
# intended for the vocabulary shipped in model_path.
TOKENIZER = tokenizers.ByteLevelBPETokenizer(
    vocab_file=os.path.join(model_path, "vocab.json"),
    merges_file=os.path.join(model_path, "merges.txt"),
    lowercase=True,
    add_prefix_space=True)

## Tweet Dataset

In [None]:
class Tweet_Dataset:
    """Map-style dataset turning (tweet, sentiment, selected_text) rows into model inputs.

    Each item is encoded as ``<s> sentiment </s> </s> tweet-tokens </s>`` and
    padded (or truncated) to ``max_len`` tokens. The start/end targets are the
    indices, inside that sequence, of the first and last tweet token that
    overlap the selected text.
    """

    # RoBERTa special-token ids used when assembling the sequence.
    BOS_ID = 0  # <s>
    PAD_ID = 1  # <pad>
    EOS_ID = 2  # </s>
    # Vocabulary ids of the sentiment words placed in front of every tweet.
    SENTIMENT_IDS = {'positive': 1313, 'negative': 2430, 'neutral': 7974}
    # Number of tokens placed before the tweet: <s> sentiment </s> </s>
    PREFIX_LEN = 4

    def __init__(self, raw_text, sentiment, selected_text):
        """Store three parallel, index-aligned sequences of strings."""
        self.raw_text = raw_text
        self.sentiment = sentiment
        self.selected_text = selected_text

        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN

    def __len__(self):
        return len(self.raw_text)

    def __getitem__(self, idx):
        return self.preprocess_text(self.raw_text[idx], self.sentiment[idx], self.selected_text[idx])

    def preprocess_text(self, tweet_text, sent, s_text):
        """Encode one example.

        Args:
            tweet_text: the full tweet.
            sent: sentiment label; must be a key of ``SENTIMENT_IDS``.
            s_text: the span of ``tweet_text`` to be extracted.

        Returns:
            dict with ``ids``, ``mask``, ``token_type_ids`` (long tensors of
            length ``max_len``), scalar ``target_start`` / ``target_end``,
            ``offsets`` (``max_len`` x 2 character offsets into the normalised
            tweet, ``(0, 0)`` for special and padding tokens) and the
            normalised ``tweet`` / ``selected_text`` strings plus ``sentiment``.

        Raises:
            KeyError: if ``sent`` is not a known sentiment label.
        """
        # Collapse whitespace runs so character positions in both strings agree.
        tweet_text = " ".join(str(tweet_text).split())
        s_text = " ".join(str(s_text).split())

        # Character span [idx_start, idx_end) of the first occurrence of the
        # selected text; -1 when it is empty or does not occur in the tweet.
        idx_start = tweet_text.find(s_text) if s_text else -1
        idx_end = idx_start + len(s_text)

        encode_tweet = self.tokenizer.encode(tweet_text)
        input_ids_ori = list(encode_tweet.ids)
        input_offsets = [tuple(off) for off in encode_tweet.offsets]

        # Tokens whose character range intersects the selected span.
        target_ids = []
        if idx_start >= 0:
            target_ids = [
                tok_idx for tok_idx, (o1, o2) in enumerate(input_offsets)
                if max(o1, idx_start) < min(o2, idx_end)
            ]
        if target_ids:
            tar_st, tar_end = target_ids[0], target_ids[-1]
        else:
            # No usable span (empty / unmatched selection or empty tweet):
            # fall back to the whole tweet instead of raising IndexError.
            tar_st, tar_end = 0, max(len(input_ids_ori) - 1, 0)

        # <s> sentiment </s> </s> tweet </s>
        input_ids = (
            [self.BOS_ID, self.SENTIMENT_IDS[sent], self.EOS_ID, self.EOS_ID]
            + input_ids_ori
            + [self.EOS_ID]
        )
        input_offsets = [(0, 0)] * self.PREFIX_LEN + input_offsets + [(0, 0)]
        tar_st += self.PREFIX_LEN
        tar_end += self.PREFIX_LEN

        # Truncate over-long sequences so every item has exactly max_len tokens
        # (ragged items cannot be stacked into a batch). The closing </s> is kept
        # and targets are clamped onto the last remaining content position.
        if len(input_ids) > self.max_len:
            input_ids = input_ids[:self.max_len - 1] + [self.EOS_ID]
            input_offsets = input_offsets[:self.max_len - 1] + [(0, 0)]
            last_content = self.max_len - 2
            tar_st = min(tar_st, last_content)
            tar_end = min(tar_end, last_content)

        # 1 marks real tokens, 0 marks padding.
        input_mask = [1] * len(input_ids)
        input_type_ids = [0] * len(input_ids)

        padding_len = self.max_len - len(input_ids)
        if padding_len > 0:
            input_ids = input_ids + [self.PAD_ID] * padding_len
            input_mask = input_mask + [0] * padding_len
            input_type_ids = input_type_ids + [0] * padding_len
            input_offsets = input_offsets + [(0, 0)] * padding_len

        return {
            "ids": torch.tensor(input_ids, dtype=torch.long),
            "mask": torch.tensor(input_mask, dtype=torch.long),
            "token_type_ids": torch.tensor(input_type_ids, dtype=torch.long),
            "target_start": torch.tensor(tar_st, dtype=torch.long),
            "target_end": torch.tensor(tar_end, dtype=torch.long),
            "tweet": tweet_text,
            "sentiment": sent,
            "selected_text": s_text,
            "offsets": torch.tensor(input_offsets, dtype=torch.long),
        }
    
# Split Dataset 
def create_data_loader(dataset, use_gpu=True, seed=42):
    """Split ``dataset`` 90/10 (stratified by sentiment) and wrap both parts in loaders.

    Args:
        dataset: DataFrame with ``text``, ``sentiment`` and ``selected_text`` columns.
        use_gpu: request pinned host memory; only honoured when CUDA is available.
        seed: random state of the split. A fixed value makes repeated calls return
            the same partition, so validation rows never end up in a later
            training split. Pass ``None`` for a fresh random split.

    Returns:
        ``(train_loader, validation_loader)``.
    """
    tr_df, val_df = train_test_split(
        dataset,
        test_size=0.1,
        stratify=dataset['sentiment'],
        random_state=seed,
    )

    # Generate Dataset
    tr_data = Tweet_Dataset(raw_text=tr_df['text'].values,
                            sentiment=tr_df['sentiment'].values,
                            selected_text=tr_df['selected_text'].values)
    val_data = Tweet_Dataset(raw_text=val_df['text'].values,
                             sentiment=val_df['sentiment'].values,
                             selected_text=val_df['selected_text'].values)

    # Pinning memory is only useful (and only warning-free) with a CUDA device.
    pin = bool(use_gpu) and torch.cuda.is_available()

    # Generate Dataloader; validation order is kept stable (no shuffling).
    tr_ = torch.utils.data.DataLoader(tr_data, batch_size=TRAIN_BATCH_SIZE,
                                      pin_memory=pin, shuffle=True)
    val_ = torch.utils.data.DataLoader(val_data, batch_size=VALID_BATCH_SIZE,
                                       pin_memory=pin, shuffle=False)
    return tr_, val_


In [None]:
# Model
class HighWay_Model(nn.Module):
    """Single highway layer: ``y = T(x) * H(x) + (1 - T(x)) * x``.

    ``H`` is a ReLU-activated linear transform and ``T`` is an element-wise
    sigmoid gate, so every sample is processed independently of the others in
    the batch.

    Args:
        input_size: feature dimension of the input (and of the output).
        gate_bias: initial value of the gate bias; a negative value starts the
            layer biased towards carrying the input through unchanged.
    """

    def __init__(self, input_size, gate_bias=-1):
        super().__init__()
        self.normal_layer = nn.Linear(input_size, input_size)
        self.gate_layer = nn.Linear(input_size, input_size)
        self.gate_layer.bias.data.fill_(gate_bias)

    def forward(self, x):
        """Apply the gated mix of transformed and original features to ``x``."""
        norm_x = F.relu(self.normal_layer(x))
        # Sigmoid keeps each gate value in (0, 1) per element; a softmax over
        # dim 0 would instead normalise across the batch and mix samples.
        gate_x = torch.sigmoid(self.gate_layer(x))
        gate_norm = torch.mul(norm_x, gate_x)
        gate_input = torch.mul((1 - gate_x), x)
        return torch.add(gate_norm, gate_input)
    
    



class roBerta_Model(transformers.BertPreTrainedModel):
    """RoBERTa encoder with a linear head producing start/end span logits.

    The encoder weights are loaded from ``model_path`` using the module-level
    config ``cf``.
    """

    def __init__(self):
        super().__init__(cf)
        self.bert = transformers.RobertaModel.from_pretrained(model_path, config=cf)
        self.dropout = nn.Dropout(0.3)

        # Two outputs per token: start logit and end logit.
        self.fc1 = nn.Linear(cf.hidden_size, 2)
        torch.nn.init.xavier_normal_(self.fc1.weight)

    def forward(self, ids, mask, token_type_ids):
        """Return ``(start_logits, end_logits)``, each shaped ``[batch_size, num_tokens]``.

        Args:
            ids: token ids, ``[batch_size, num_tokens]``.
            mask: attention mask with the same shape as ``ids``.
            token_type_ids: segment ids with the same shape as ``ids``.
        """
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Index rather than tuple-unpack: works for both tuple and ModelOutput returns.
        seq_output = outputs[0]  # [batch_size, num_tokens, hidden_size]

        # Regularise the features, not the logits: dropout after the head would
        # randomly zero (and rescale) the scores fed to the loss.
        h_vec = self.fc1(self.dropout(seq_output))  # [batch_size, num_tokens, 2]

        st_logits, end_logits = h_vec.split(1, dim=-1)

        # Squeeze only the trailing singleton axis so a batch of one keeps its
        # batch dimension.
        return st_logits.squeeze(-1), end_logits.squeeze(-1)
    
    

In [None]:
# Build the span-extraction model and move its parameters onto `device`.
model = roBerta_Model().to(device)


In [None]:
# Loss Function

def loss_fn(o1, o2, t1, t2):
    """Return the sum of the cross-entropy losses of the start and end heads.

    Args:
        o1, o2: start / end logits.
        t1, t2: start / end target class indices.
    """
    criterion = nn.CrossEntropyLoss()
    start_loss, end_loss = (
        criterion(logits, target) for logits, target in ((o1, t1), (o2, t2))
    )
    return start_loss + end_loss

def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings (case-insensitive).

    Both strings are lower-cased and split on whitespace; the score is
    ``|A & B| / |A | B|`` over the resulting word sets.

    Returns:
        float in ``[0, 1]``. Two strings that both contain no words are treated
        as identical and score ``1.0`` (instead of raising ZeroDivisionError).
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union_size = len(a | b)
    if union_size == 0:
        return 1.0
    return float(len(a & b)) / union_size


class AverageMeter(object):
    """Keeps the most recent value and a weighted running mean of all values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and sample count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count



def train_fn(data_loader, model, opt, device, scheduler):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: yields dict batches with ``ids``, ``token_type_ids``,
            ``mask``, ``target_start`` and ``target_end`` tensors.
        model: called with ``ids``, ``mask`` and ``token_type_ids``; returns
            start and end logits.
        opt: optimizer, stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: accepted but not stepped inside this function.
    """
    model.train()

    running_loss = AverageMeter()
    tensor_keys = ('ids', 'token_type_ids', 'mask', 'target_start', 'target_end')

    progress = tqdm(data_loader, total=len(data_loader))
    for batch in progress:
        ids, token_type_ids, mask, target_start, target_end = (
            batch[key].to(device) for key in tensor_keys
        )

        opt.zero_grad()
        start_logits, end_logits = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

        batch_loss = loss_fn(start_logits, end_logits, target_start, target_end)
        batch_loss.backward()
        opt.step()

        # Weight the running mean by the number of examples in this batch.
        running_loss.update(batch_loss.item(), ids.size(0))
        progress.set_postfix(loss=running_loss.avg)
        
def eval_fn(data_loader, model, device,test=False):
    """Evaluate ``model`` on ``data_loader`` and decode the predicted text spans.

    Args:
        data_loader: yields dict batches with ``ids``, ``token_type_ids``,
            ``mask``, ``target_start``, ``target_end``, ``offsets`` tensors and
            ``tweet``, ``sentiment``, ``selected_text`` string lists.
        model: called with ``ids``, ``mask`` and ``token_type_ids``; returns
            start and end logits.
        device: device the batch tensors are moved to.
        test: accepted for call compatibility; not used.

    Returns:
        ``(mean_jaccard, predictions)`` where ``mean_jaccard`` is the
        batch-size-weighted mean Jaccard score against ``selected_text`` and
        ``predictions`` is the list of decoded strings in loader order.
    """
    model.eval()
    losses = AverageMeter()
    jac_vals = AverageMeter()
    fin_selected_text = []
    tk0 = tqdm(data_loader, total=len(data_loader))

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for data in tk0:
            ids = data['ids'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)

            target_start = data['target_start'].to(device, dtype=torch.long)
            target_end = data['target_end'].to(device, dtype=torch.long)

            tweet = data['tweet']
            sentiment = data['sentiment']
            selected_text = data['selected_text']

            offsets = data['offsets'].numpy()
            batch_size = ids.size(0)

            o1, o2 = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            # Keep the logits 2-D ([batch, tokens]) even if a batch of one
            # arrives with its batch axis squeezed away.
            o1 = o1.reshape(batch_size, -1)
            o2 = o2.reshape(batch_size, -1)
            loss = loss_fn(o1, o2, target_start, target_end)

            # argmax over logits equals argmax over their softmax.
            start_indices = o1.argmax(dim=1).cpu().numpy()
            end_indices = o2.argmax(dim=1).cpu().numpy()

            jacc_scores = []
            batch_target_outputs = []

            for idx, batch_tweet in enumerate(tweet):
                offset_var = offsets[idx]
                idx_st = int(start_indices[idx])
                # An end before the start collapses to a single-token span.
                idx_end = max(int(end_indices[idx]), idx_st)

                if sentiment[idx] == 'neutral' or len(batch_tweet.split()) < 2:
                    # Neutral or one-word tweets: predict the whole tweet.
                    filter_output = batch_tweet
                else:
                    pieces = []
                    for ix in range(idx_st, idx_end + 1):
                        pieces.append(batch_tweet[offset_var[ix][0]:offset_var[ix][1]])
                        # Re-insert a space where consecutive tokens are not adjacent.
                        if (ix + 1) < len(offset_var) and offset_var[ix][1] < offset_var[ix + 1][0]:
                            pieces.append(" ")
                    filter_output = "".join(pieces)

                jacc_scores.append(jaccard(selected_text[idx].strip(), filter_output.strip()))
                batch_target_outputs.append(filter_output.strip())

            jac_vals.update(np.mean(jacc_scores), batch_size)
            losses.update(loss.item(), batch_size)
            tk0.set_postfix(loss=losses.avg, jaccard=jac_vals.avg)

            fin_selected_text += batch_target_outputs

    return jac_vals.avg, fin_selected_text
        

In [None]:
# Parameters whose names contain one of these fragments go into the second group.
no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']
# NOTE(review): both groups use weight_decay=0.0, so the split currently has no
# effect; give the first group a non-zero value if weight decay is wanted.
opt_params = [
    {'params':[p for n,p in model.named_parameters() if not any(nd in n for nd in no_decay)],"weight_decay":0.0},
    {'params':[p for n,p in model.named_parameters() if any(nd in n for nd in no_decay)],"weight_decay":0.0}
]

# Optimizer steps over the whole run. create_data_loader holds out 10% for
# validation, so 90% of the rows are trained on. Only needed by a step-based
# schedule (e.g. linear warmup); the plateau scheduler below does not use it.
num_train_steps = int((train_data.shape[0]*0.9)/TRAIN_BATCH_SIZE * EPOCHS)

# torch's AdamW replaces the deprecated transformers.AdamW; eps=1e-6 keeps the
# epsilon that implementation used by default.
optimizer = torch.optim.AdamW(opt_params, lr=3e-5, eps=1e-6)

# The scheduler is stepped with the validation Jaccard score, which is better
# when higher, so it must run in 'max' mode (the default 'min' would cut the
# learning rate whenever the score stops decreasing).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=1)

In [None]:
best_jaccard = 0
best_ckpt_path = r'tweet_model_v3.pt'

# Split once, before the epoch loop: re-splitting every epoch would move rows
# the model has already trained on into the validation set and inflate the score.
tr_loader, val_loader = create_data_loader(train_data)

for epoch in range(EPOCHS):
    train_fn(tr_loader,model,optimizer,device,scheduler)
    jacc,val_text = eval_fn(val_loader,model,device)

    print(f"Epoch: {epoch+1} | Validation Jaccard score : {jacc:0.3f}")
    print(f"\tValidation text:{val_text[:3]}")
    del val_text; gc.collect()
    scheduler.step(jacc)
    # Checkpoint only when the validation score improves.
    if jacc>best_jaccard:
        torch.save(model.state_dict(), best_ckpt_path)
        best_jaccard = jacc

# `best_model` is the same object as `model`; reload the best checkpoint into it
# so it really holds the best-scoring weights rather than the last epoch's.
best_model = model
if best_jaccard > 0:
    best_model.load_state_dict(torch.load(best_ckpt_path, map_location=device))

## Test Part

# Overwrite the parameters of `model` with previously trained weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
order_dict_param = torch.load(
    r"/kaggle/input/trained-roberta-model/tweet_model_roBERTa_base.pt",
    map_location=device,
)
skipped_params = []
for name, param in model.named_parameters():
    if name in order_dict_param:
        param.data.copy_(order_dict_param[name])
    else:
        skipped_params.append(name)
# Report parameters absent from the checkpoint instead of skipping them silently.
if skipped_params:
    print(f"{len(skipped_params)} parameter(s) not found in checkpoint: {skipped_params[:5]}")

In [None]:
# Build the test dataset. The test frame's own text is passed as `selected_text`
# so the dataset can still compute start/end targets for every row.
ts_data = Tweet_Dataset(raw_text = test_data['text'].values, 
          sentiment=test_data['sentiment'].values, 
          selected_text = test_data['text'].values)

# shuffle=False keeps predictions aligned with the row order of test_data.
ts_dataloader = torch.utils.data.DataLoader(ts_data, batch_size=16, pin_memory=False,shuffle=False)
# `best_model` is bound by the training cell. `jacc` here is measured against the
# full tweet text passed above as `selected_text`, not against labelled spans.
jacc,fin_output = eval_fn(ts_dataloader,best_model,device)

In [None]:
def post_process(selected):
    """Lower-case ``selected`` and drop repeated words, keeping first-occurrence order.

    A ``dict`` preserves insertion order, so the output is deterministic and
    reads in the original word order (a ``set`` would scramble it and could
    differ between runs).
    """
    return " ".join(dict.fromkeys(selected.lower().split()))
# Fill the sample submission template with the post-processed predictions.
# `fin_output` must hold one prediction per row, in the template's row order.
sub = pd.read_csv(os.path.join(data_path,'sample_submission.csv'))
sub['selected_text'] = [post_process(text) for text in fin_output]
sub.to_csv("submission.csv",index=False)