In [1]:
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm.auto import tqdm

In [2]:
#Generate Binary mask to find sub-text within text
def get_mask(text, subtext):
    """Return a per-word 0/1 mask marking the words of ``text`` touched by ``subtext``.

    Words are the whitespace-separated pieces of ``text`` (``text.split()``), and
    the returned list always has exactly one entry per word. A word is marked
    with 1 when the first occurrence of ``subtext`` in ``text`` overlaps it, even
    partially (e.g. when ``subtext`` begins or ends in the middle of a word).

    Args:
        text: The full string.
        subtext: The substring to locate inside ``text``.

    Returns:
        A list of ints (0 or 1) of length ``len(text.split())``. The list is all
        zeros when ``subtext`` does not occur in ``text`` or contains no
        non-whitespace characters.
    """
    mask = [0] * len(text.split())
    loc = text.find(subtext)
    # No match (find() == -1) or nothing but whitespace: there is no word to mark.
    if loc < 0 or not subtext.strip():
        return mask

    # Character span of the match, shrunk so it starts and ends on a
    # non-whitespace character.
    lo = loc + (len(subtext) - len(subtext.lstrip()))
    hi = loc + len(subtext.rstrip())

    # Index of the word containing text[lo]: count the words before it, and step
    # back by one when the span begins in the middle of a word, because the
    # partial word in the prefix has already been counted.
    start = len(text[:lo].split())
    if lo > 0 and not text[lo - 1].isspace():
        start -= 1
    # Exclusive end: number of words that begin before the span ends.
    end = len(text[:hi].split())

    mask[start:end] = [1] * (end - start)
    return mask
#Get Data
class TweetaSet(Dataset):
    """Training dataset pairing tokenized (question, context) inputs with span labels.

    Rows are read from a CSV file, rows with missing values are dropped, and the
    remainder is sampled (and thereby shuffled) with ``DataFrame.sample``.
    Each row is unpacked positionally as ``(textID, context, answer, question)``,
    so the file must contain exactly four columns in that order.

    The span labels are *word* indices into ``context.split()`` as computed by
    ``get_mask``; they are not offsets into the tokenizer output.
    NOTE(review): any change to this labelling scheme must be mirrored wherever
    the model's predictions are decoded.

    Args:
        filename: Path of the CSV file to load.
        tokenizer: Object exposing ``encode_plus`` (a Hugging Face tokenizer).
        max_len: Length every encoded sequence is padded/truncated to.
        frac: Fraction of the non-null rows to keep.
    """
    def __init__(self, filename, tokenizer, max_len, frac=1.0):
        self.df = pd.read_csv(filename,header=0).dropna().sample(frac=frac)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of rows kept after dropping nulls and sampling."""
        return len(self.df)

    def __getitem__(self, item):
        """Encode row ``item`` and return model inputs plus start/end labels.

        Returns:
            A dict with 1-D ``input_ids`` and ``attention_mask`` tensors and 0-d
            int64 ``start_positions`` / ``end_positions`` tensors. ``end`` is
            exclusive: ``start`` plus the number of marked words.
        """
        textID, context, answer, question = self.df.iloc[item]

        # Word-level labels: index of the first marked word and one past the last.
        mask = get_mask(context, answer)
        if 1 in mask:
            first = mask.index(1)
            start = torch.tensor(first)
            end = torch.tensor(first + sum(mask))
        else:
            # No word is marked (e.g. whitespace-only answer). Emit a position
            # equal to the sequence length: per the Hugging Face documentation,
            # positions outside the sequence are ignored by the QA loss, so the
            # example contributes nothing instead of raising an IndexError.
            start = torch.tensor(self.max_len)
            end = torch.tensor(self.max_len)

        #encode question and context together
        encoding = self.tokenizer.encode_plus(
            question,context,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'start_positions' : start,
            'end_positions'   : end
        }

In [3]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
#standard huggingface finetuning training loop (see website for more details)
def train(filename, max_len, batch_size, epochs, frac=1.0, device='cuda' if torch.cuda.is_available() else 'cpu',
          model_path='/kaggle/input/distilbert-cased-for-qa'):
    """Fine-tune a DistilBERT question-answering model on a CSV of examples.

    The parameters of ``model.base_model`` are frozen, so only the remaining
    parameters receive gradients. The learning-rate schedule is stepped once per
    epoch, with 5 warm-up epochs when ``epochs >= 10`` and none otherwise.

    Args:
        filename: Path of the training CSV, passed to ``TweetaSet``.
        max_len: Padded/truncated sequence length used by the dataset.
        batch_size: Number of examples per optimisation step.
        epochs: Number of passes over the dataset.
        frac: Fraction of the CSV rows to train on.
        device: Torch device string; the default is chosen once, when the
            function is defined.
        model_path: Directory or identifier from which both the tokenizer and
            the model weights are loaded.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    model = DistilBertForQuestionAnswering.from_pretrained(model_path).to(device)
    for param in model.base_model.parameters():
        param.requires_grad = False
    model.train()

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
    # The scheduler counts epochs (not batches): it is stepped once per epoch below.
    scheduler = get_linear_schedule_with_warmup(optimizer, 5 if epochs >= 10 else 0, epochs)

    ds = TweetaSet(filename, tokenizer, max_len, frac=frac)
    dl = DataLoader(ds, batch_size=batch_size)
    for _ in range(epochs):
        for batch in tqdm(dl):
            batch = {k:v.to(device) for k,v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            # Index instead of tuple-unpacking: works both for plain tuple
            # returns and for ModelOutput objects, whose iteration yields keys.
            loss = outputs[0]
            loss.backward()
            optimizer.step()
        scheduler.step()
    model.eval()
    return model,tokenizer

In [4]:
# Fine-tune on a 15% sample of the training CSV: sequence length 35, batches of 100, 50 epochs.
model, tokenizer = train(
    filename="/kaggle/input/tweet-sentiment-extraction/train.csv",
    max_len=35,
    batch_size=100,
    epochs=50,
    frac=0.15,
)

HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))




In [5]:
#get test data
class TweetaTestSet(Dataset):
    """Inference dataset yielding tokenized (question, context) pairs.

    The CSV is loaded, rows with missing values are removed and the rest is
    sampled with ``DataFrame.sample(frac=frac)``. Every row is unpacked
    positionally as ``(textID, context, question)``.
    """

    def __init__(self, filename, tokenizer, max_len, frac=1.0):
        raw_frame = pd.read_csv(filename)
        self.df = raw_frame.dropna().sample(frac=frac)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, item):
        """Return the id, raw context and flattened encoding of row ``item``."""
        row = self.df.iloc[item]
        textID, context, question = row

        # Tokenizer options: pad/truncate to max_len and return PyTorch tensors.
        encode_options = {
            'truncation': True,
            'add_special_tokens': True,
            'max_length': self.max_len,
            'return_token_type_ids': False,
            'padding': 'max_length',
            'return_attention_mask': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer.encode_plus(question, context, **encode_options)

        sample = {'id': textID, 'context': context}
        for key in ('input_ids', 'attention_mask'):
            sample[key] = encoded[key].flatten()
        return sample

In [6]:
#generate output data
# Runs the fine-tuned model over the test CSV and writes one predicted span per
# tweet. Predicted start/end indices are used as word positions into
# context.split(), with `end` exclusive.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ds = TweetaTestSet("/kaggle/input/tweet-sentiment-extraction/test.csv", tokenizer, 35)

batch_frames = []
# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for batch in DataLoader(ds, batch_size=100):
        outputs = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        # Index rather than tuple-unpack so both tuple and ModelOutput returns work.
        start_logits, end_logits = outputs[0], outputs[1]
        # argmax over logits equals argmax over log-softmax (monotonic transform).
        start = torch.argmax(start_logits, dim=1)
        end   = torch.argmax(end_logits, dim=1)
        batch_frames.append(pd.DataFrame({
            "textID": batch["id"],
            # .tolist() yields plain ints for slicing (one transfer per batch).
            "selected_text": [" ".join(t.split()[s:e]) for t,s,e in zip(batch["context"], start.tolist(), end.tolist())]
        }))

# Concatenate once; DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.
if batch_frames:
    df = pd.concat(batch_frames)
else:
    df = pd.DataFrame(columns=["textID", "selected_text"])
df.to_csv("/kaggle/working/submission.csv", index=False)