In [None]:
!pip install transformers

In [None]:
# Show the GPUs visible on this machine (informational only; no later cell reads this output).
!nvidia-smi

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import sklearn
from sklearn.model_selection import train_test_split
import re

import transformers
from transformers import BertTokenizer, BertModel

In [None]:
# Raw training tweets; later cells use the `text` and `target` columns.
train = pd.read_csv(filepath_or_buffer='train.csv')

In [None]:
# Peek at the first rows instead of rendering the whole frame.
train.head()

In [None]:
# Lower-case every training tweet (the BERT checkpoint used later is uncased).
train['text'] = train['text'].str.lower()

In [None]:
import nltk

# Fetch the corpus behind `stopwords.words('english')` used in the next cell.
# `nltk` must be imported here: the next cell's import runs too late, so this
# line raised NameError on a fresh kernel.
nltk.download('stopwords')

In [None]:
import string
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import TweetTokenizer

from nltk.tokenize.casual import remove_handles

tknzr = TweetTokenizer(strip_handles=True)
stop_words = set(stopwords.words('english'))  # only used by the (disabled) stop-word step
corpus = []  # not populated anywhere at the moment

# (pattern, replacement) pairs for mojibake fragments found in the tweets, applied in order.
# Compiled case-insensitively: the notebook lower-cases `text` *before* calling
# `clean_data`, so case-sensitive patterns containing e.g. 'Û', 'Ì' or 'China'
# could never match and were dead code.
_MOJIBAKE_FIXES = [
    (re.compile(pattern, re.IGNORECASE), replacement)
    for pattern, replacement in (
        (r"\x89Û_", ""),
        (r"\x89ÛÒ", ""),
        (r"\x89ÛÓ", ""),
        (r"\x89ÛÏWhen", "When"),
        (r"\x89ÛÏ", ""),
        (r"China\x89Ûªs", "China's"),
        (r"let\x89Ûªs", "let's"),
        (r"\x89Û÷", ""),
        (r"\x89Ûª", ""),
        (r"\x89Û\x9d", ""),
        (r"å_", ""),
        (r"\x89Û¢", ""),
        (r"\x89Û¢åÊ", ""),
        (r"fromåÊwounds", "from wounds"),
        (r"åÊ", ""),
        (r"åÈ", ""),
        (r"JapÌ_n", "Japan"),
        (r"Ì©", "e"),
        (r"å¨", ""),
        (r"SuruÌ¤", "Suruc"),
        (r"åÇ", ""),
        (r"å£3million", "3 million"),
        (r"åÀ", ""),
    )
]

# Compiled once instead of on every call.
# NOTE(review): the last range (U+24C2..U+1F251) also spans CJK and many other
# non-emoji blocks, so it removes far more than emoji.
_EMOJI_PATTERN = re.compile("["
                            u"\U0001F600-\U0001F64F"  # emoticons
                            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                            u"\U0001F680-\U0001F6FF"  # transport & map symbols
                            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                            u"\U00002702-\U000027B0"
                            u"\U000024C2-\U0001F251"
                            "]+", flags=re.UNICODE)

_URL_PATTERN = re.compile(r'http\S+')

# Deletes every punctuation character except '!'.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation.replace('!', ''))


def clean_data(text):
    """Normalise one tweet into space-separated tokens.

    Steps, in order:
      1. repair known mojibake fragments (case-insensitive);
      2. remove emoji / pictograph ranges;
      3. remove URLs and @handles, which must happen before ':' '/' '@' are
         stripped or they survive as ordinary words;
      4. remove digits;
      5. remove punctuation except '!';
      6. tokenize with ``TweetTokenizer`` and re-join with single spaces.

    Args:
        text: tweet text (lower-cased or not).

    Returns:
        The cleaned tweet as one string.
    """
    for pattern, replacement in _MOJIBAKE_FIXES:
        text = pattern.sub(replacement, text)

    text = _EMOJI_PATTERN.sub('', text)

    # Previously these ran after punctuation removal: '@' was already gone, so
    # `strip_handles=True` never fired and handles leaked into the model input.
    text = _URL_PATTERN.sub('', text)
    text = remove_handles(text)

    text = re.sub(r'[0-9]', '', text)
    text = text.translate(_PUNCT_TABLE)

    return ' '.join(tknzr.tokenize(text))

In [None]:
# Thanks to https://www.kaggle.com/rftexas/text-only-kfold-bert
# Abbreviation / slang -> expansion. `convert_abbrev_in_text` looks words up by their
# lower-cased form, so every key must be lower-case (the former "IG" key never matched).
# NOTE: in this notebook `clean_data` runs first and strips punctuation other than '!',
# so keys containing punctuation (e.g. "a.m", "w/", "tl;dr", "$") currently never match.
abbreviations = {
    "$": " dollar ",
    "€": " euro ",
    "4ao": "for adults only",
    "a.m": "before midday",
    "a3": "anytime anywhere anyplace",
    "aamof": "as a matter of fact",
    "acct": "account",
    "adih": "another day in hell",
    "afaic": "as far as i am concerned",
    "afaict": "as far as i can tell",
    "afaik": "as far as i know",
    "afair": "as far as i remember",
    "afk": "away from keyboard",
    "app": "application",
    "approx": "approximately",
    "apps": "applications",
    "asap": "as soon as possible",
    "asl": "age, sex, location",
    "atk": "at the keyboard",
    "ave.": "avenue",
    "aymm": "are you my mother",
    "ayor": "at your own risk",
    "b&b": "bed and breakfast",
    "b+b": "bed and breakfast",
    "b.c": "before christ",
    "b2b": "business to business",
    "b2c": "business to customer",
    "b4": "before",
    "b4n": "bye for now",
    "b@u": "back at you",
    "bae": "before anyone else",
    "bak": "back at keyboard",
    "bbbg": "bye bye be good",
    "bbc": "british broadcasting corporation",
    "bbias": "be back in a second",
    "bbl": "be back later",
    "bbs": "be back soon",
    "be4": "before",
    "bfn": "bye for now",
    "blvd": "boulevard",
    "bout": "about",
    "brb": "be right back",
    "bros": "brothers",
    "brt": "be right there",
    "bsaaw": "big smile and a wink",
    "btw": "by the way",
    "bwl": "bursting with laughter",
    "c/o": "care of",
    "cet": "central european time",
    "cf": "compare",
    "cia": "central intelligence agency",
    "csl": "can not stop laughing",
    "cu": "see you",
    "cul8r": "see you later",
    "cv": "curriculum vitae",
    "cwot": "complete waste of time",
    "cya": "see you",
    "cyt": "see you tomorrow",
    "dae": "does anyone else",
    "dbmib": "do not bother me i am busy",
    "diy": "do it yourself",
    "dm": "direct message",
    "dwh": "during work hours",
    "e123": "easy as one two three",
    "eet": "eastern european time",
    "eg": "example",
    "embm": "early morning business meeting",
    "encl": "enclosed",
    "encl.": "enclosed",
    "etc": "and so on",
    "faq": "frequently asked questions",
    "fawc": "for anyone who cares",
    "fb": "facebook",
    "fc": "fingers crossed",
    "fig": "figure",
    "fimh": "forever in my heart",
    "ft.": "feet",
    "ft": "featuring",
    "ftl": "for the loss",
    "ftw": "for the win",
    "fwiw": "for what it is worth",
    "fyi": "for your information",
    "g9": "genius",
    "gahoy": "get a hold of yourself",
    "gal": "get a life",
    "gcse": "general certificate of secondary education",
    "gfn": "gone for now",
    "gg": "good game",
    "gl": "good luck",
    "glhf": "good luck have fun",
    "gmt": "greenwich mean time",
    "gmta": "great minds think alike",
    "gn": "good night",
    "g.o.a.t": "greatest of all time",
    "goat": "greatest of all time",
    "goi": "get over it",
    "gps": "global positioning system",
    "gr8": "great",
    "gratz": "congratulations",
    "gyal": "girl",
    "h&c": "hot and cold",
    "hp": "horsepower",
    "hr": "hour",
    "hrh": "his royal highness",
    "ht": "height",
    "ibrb": "i will be right back",
    "ic": "i see",
    "icq": "i seek you",
    "icymi": "in case you missed it",
    "idc": "i do not care",
    "idgadf": "i do not give a damn fuck",
    "idgaf": "i do not give a fuck",
    "idk": "i do not know",
    "ie": "that is",
    "i.e": "that is",
    "ifyp": "i feel your pain",
    "ig": "instagram",
    "iirc": "if i remember correctly",
    "ilu": "i love you",
    "ily": "i love you",
    "imho": "in my humble opinion",
    "imo": "in my opinion",
    "imu": "i miss you",
    "iow": "in other words",
    "irl": "in real life",
    "j4f": "just for fun",
    "jic": "just in case",
    "jk": "just kidding",
    "jsyk": "just so you know",
    "l8r": "later",
    "lb": "pound",
    "lbs": "pounds",
    "ldr": "long distance relationship",
    "lmao": "laugh my ass off",
    "lmfao": "laugh my fucking ass off",
    "lol": "laughing out loud",
    "ltd": "limited",
    "ltns": "long time no see",
    "m8": "mate",
    "mf": "motherfucker",
    "mfs": "motherfuckers",
    "mfw": "my face when",
    "mofo": "motherfucker",
    "mph": "miles per hour",
    "mr": "mister",
    "mrw": "my reaction when",
    "ms": "miss",
    "mte": "my thoughts exactly",
    "nagi": "not a good idea",
    "nbc": "national broadcasting company",
    "nbd": "not big deal",
    "nfs": "not for sale",
    "ngl": "not going to lie",
    "nhs": "national health service",
    "nrn": "no reply necessary",
    "nsfl": "not safe for life",
    "nsfw": "not safe for work",
    "nth": "nice to have",
    "nvr": "never",
    "nyc": "new york city",
    "oc": "original content",
    "og": "original",
    "ohp": "overhead projector",
    "oic": "oh i see",
    "omdb": "over my dead body",
    "omg": "oh my god",
    "omw": "on my way",
    "p.a": "per annum",
    "p.m": "after midday",
    "pm": "prime minister",
    "poc": "people of color",
    "pov": "point of view",
    "pp": "pages",
    "ppl": "people",
    "prw": "parents are watching",
    "ps": "postscript",
    "pt": "point",
    "ptb": "please text back",
    "pto": "please turn over",
    "qpsa": "what happens",  # "que pasa"
    "ratchet": "rude",
    "rbtl": "read between the lines",
    "rlrt": "real life retweet",
    "rofl": "rolling on the floor laughing",
    "roflol": "rolling on the floor laughing out loud",
    "rotflmao": "rolling on the floor laughing my ass off",
    "rt": "retweet",
    "ruok": "are you ok",
    "sfw": "safe for work",
    "sk8": "skate",
    "smh": "shake my head",
    "sq": "square",
    "srsly": "seriously",
    "ssdd": "same stuff different day",
    "tbh": "to be honest",
    "tbs": "tablespoonful",
    "tbsp": "tablespoonful",
    "tfw": "that feeling when",
    "thks": "thank you",
    "tho": "though",
    "thx": "thank you",
    "tia": "thanks in advance",
    "til": "today i learned",
    "tl;dr": "too long i did not read",
    "tldr": "too long i did not read",
    "tmb": "tweet me back",
    "tntl": "trying not to laugh",
    "ttyl": "talk to you later",
    "u": "you",
    "u2": "you too",
    "u4e": "yours for ever",
    "utc": "coordinated universal time",
    "w/": "with",
    "w/o": "without",
    "w8": "wait",
    "wassup": "what is up",
    "wb": "welcome back",
    "wtf": "what the fuck",
    "wtg": "way to go",
    "wtpa": "where the party at",
    "wuf": "where are you from",
    "wuzup": "what is up",
    "wywh": "wish you were here",
    "yd": "yard",
    "ygtr": "you got that right",
    "ynk": "you never know",
    "zzz": "sleeping bored and tired",
}

In [None]:
def convert_abbrev_in_text(text, mapping=None):
    """Expand known abbreviations / slang in `text`, word by word.

    Words are split on whitespace and looked up by their lower-cased form; words
    without an entry are kept unchanged. Expansions are stripped so padded values
    such as ' dollar ' do not introduce double spaces.

    Args:
        text: input string.
        mapping: dict from lower-case abbreviation to expansion. Defaults to the
            module-level ``abbreviations`` dict.

    Returns:
        The (expanded) words re-joined with single spaces.
    """
    if mapping is None:
        mapping = abbreviations
    return ' '.join(mapping.get(word.lower(), word).strip() for word in text.split())

In [None]:
# Clean each tweet, then expand abbreviations on the cleaned words.
train['text'] = train['text'].apply(clean_data).apply(convert_abbrev_in_text)

In [None]:
# Character length of each cleaned tweet.
train['text_len'] = train['text'].str.len()

In [None]:
import seaborn as sns

# Distribution of cleaned tweet lengths (characters); trailing ';' hides the FacetGrid repr.
sns.displot(train.text_len).set(xlabel='cleaned tweet length (characters)', ylabel='count');

In [None]:
# Stratified 80/20 split keeps the target ratio the same in both parts.
# A fixed random_state makes the split, and therefore the validation F1, reproducible.
# NOTE: this rebinds `train`, so re-running the cell would split the already-split frame.
train, valid = train_test_split(train, test_size=0.2, stratify=train.target, random_state=42)

In [None]:
# Class balance and size of the training split.
(train['target'].value_counts(), train.shape)

In [None]:
# Class balance and size of the validation split.
(valid['target'].value_counts(), valid.shape)

In [None]:
class MyDataset(data.Dataset):
    """Map-style dataset that BERT-tokenizes one text per item.

    Each item is a dict with:
        'attention_mask', 'ids': tensors returned by ``encode_plus(..., return_tensors='pt')``,
            padded/truncated to ``max_len`` (presumably shaped (1, max_len) -- the training
            loop flattens them with ``.view(batch, -1)``);
        'targets': the label as a float32 tensor.

    Args:
        texts: indexable sequence of texts (each is passed through ``str``).
        targets: indexable sequence of labels aligned with ``texts``.
        max_len: token length every item is padded/truncated to.
    """

    def __init__(self, texts, targets, max_len):
        self.texts = texts
        self.targets = targets
        self.max_len = max_len
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        tokens = self.tokenizer.encode_plus(
            str(self.texts[item]),
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        label = torch.tensor(self.targets[item], dtype=torch.float)
        return {
            'attention_mask': tokens['attention_mask'],
            'ids': tokens['input_ids'],
            'targets': label,
        }

In [None]:
class MyBertModel(nn.Module):
    """BERT encoder + dropout + linear head producing one logit per example.

    Args:
        model_name: Hugging Face checkpoint to load (default keeps 'bert-base-uncased').
        dropout: dropout probability applied to the pooled output.

    Attribute names (`model`, `dropout`, `linear`) are unchanged, so previously saved
    state dicts still load.
    """

    def __init__(self, model_name='bert-base-uncased', dropout=0.3):
        super().__init__()
        self.model = BertModel.from_pretrained(model_name, return_dict=False)
        self.dropout = nn.Dropout(dropout)
        # Take the width from the checkpoint config instead of hard-coding 768
        # (the bert-base value), so other checkpoint sizes also work.
        self.linear = nn.Linear(self.model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return raw logits (last dim 1); apply a sigmoid to get probabilities."""
        # With return_dict=False the backbone returns a tuple; the second element
        # is the pooled output that feeds the classification head.
        _, pooled_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return self.linear(self.dropout(pooled_output))

In [None]:
def train_function(data_loader, model, optimizer, device):
    """Run one training epoch with BCE-with-logits loss.

    Args:
        data_loader: iterable of dicts with 'ids', 'attention_mask' and 'targets'
            tensors (e.g. a DataLoader over MyDataset).
        model: module called as ``model(input_ids=..., attention_mask=...)`` that
            returns logits.
        optimizer: optimizer over the model's trainable parameters.
        device: device the batch tensors are moved to.

    Returns:
        Mean training loss over the batches as a float (nan for an empty loader).
        The value is also printed.
    """
    model.train()
    # The loss module holds no state, so build it once instead of per batch.
    criterion = nn.BCEWithLogitsLoss()
    losses = []

    for batch in data_loader:
        batch_size = batch['ids'].shape[0]
        # Collapse every dimension after the batch one; targets become
        # (batch, 1) to line up with the single-logit output.
        ids = batch['ids'].view(batch_size, -1).to(device)
        mask = batch['attention_mask'].view(batch_size, -1).to(device)
        targets = batch['targets'].view(batch_size, -1).to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=ids, attention_mask=mask)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    mean_loss = float(np.mean(losses)) if losses else float('nan')
    print('mean train loss', mean_loss)
    return mean_loss

In [None]:
def valid_function(data_loader, model, device):
    """Score every batch of `data_loader` in eval mode without gradients.

    Args:
        data_loader: iterable of dicts with 'ids', 'attention_mask' and 'targets' tensors.
        model: module called as ``model(input_ids=..., attention_mask=...)`` returning logits.
        device: device the batch tensors are moved to.

    Returns:
        (val_targets, val_outputs): nested Python lists with one row per example --
        the targets and the sigmoid probabilities, respectively.
    """
    model.eval()
    val_targets, val_outputs = [], []

    with torch.no_grad():
        for batch in data_loader:
            n = batch['ids'].shape[0]
            labels = batch['targets'].view(n, -1).to(device)
            logits = model(
                input_ids=batch['ids'].view(n, -1).to(device),
                attention_mask=batch['attention_mask'].view(n, -1).to(device),
            )
            val_outputs += torch.sigmoid(logits).cpu().numpy().tolist()
            val_targets += labels.cpu().numpy().tolist()

    return val_targets, val_outputs

In [None]:
MAX_LEN = 256      # tokens per tweet after padding/truncation
BATCH_SIZE = 128

train_dataset = MyDataset(train.text.values, train.target.values, MAX_LEN)

# shuffle=True re-orders the training set every epoch; without it every epoch
# visited the batches in the same fixed order.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
)

valid_dataset = MyDataset(valid.text.values, valid.target.values, MAX_LEN)

# Order does not affect the validation metrics, so no shuffling here.
valid_data_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    num_workers=1,
)


In [None]:
# Pretrained BERT backbone plus a freshly initialised linear head (see MyBertModel).
model = MyBertModel()

In [None]:
# Freeze the embeddings and every encoder layer except the last N_TRAINABLE_LAYERS;
# those layers, the pooler and the classification head (model.linear) stay trainable.
# For the 12-layer bert-base model this freezes layers 0-10, exactly what the previous
# regex r'encoder\.layer\.[0-9][^1]*\.' did (only layer 11 escaped it).
N_TRAINABLE_LAYERS = 1
first_trainable_layer = model.model.config.num_hidden_layers - N_TRAINABLE_LAYERS

for name, param in model.model.named_parameters():
    layer_match = re.search(r'encoder\.layer\.(\d+)\.', name)
    in_frozen_layer = layer_match is not None and int(layer_match.group(1)) < first_trainable_layer
    if in_frozen_layer or 'embeddings.' in name:
        param.requires_grad = False

# One summary line instead of printing every parameter name twice.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f'trainable parameters: {n_trainable:,} / {n_total:,}')

In [None]:
# Optimise only parameters left trainable by the freezing cell; frozen ones never
# get gradients, so handing them to Adam only adds bookkeeping.
# NOTE(review): Adam's default lr (1e-3) is high for BERT fine-tuning; consider ~2e-5.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad])

In [None]:
# Prefer the GPU whenever CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Show which device training and inference will use.
device

In [None]:
# `import sklearn` alone does not guarantee the metrics submodule is loaded.
import sklearn.metrics

N_EPOCHS = 5
THRESHOLD = 0.5  # probability cut-off for predicting the positive class

model = model.to(device)

max_f1 = 0

for epoch in range(N_EPOCHS):
    train_function(train_data_loader, model, optimizer, device)
    targets, outputs = valid_function(valid_data_loader, model, device)

    predictions = np.array(outputs) >= THRESHOLD
    # Named val_f1 so it no longer shadows sklearn's `f1_score` function name.
    val_f1 = sklearn.metrics.f1_score(targets, predictions, average="macro")
    print(f'for epoch {epoch} validation f1 score is {val_f1}')

    # Keep the weights of the best epoch so far (by validation macro F1).
    if max_f1 < val_f1:
        max_f1 = val_f1
        torch.save(model.state_dict(), 'my_model')

In [None]:
# Best validation macro-F1 reached during training.
max_f1

In [None]:
# Test tweets to predict; later cells use the `id` and `text` columns.
test = pd.read_csv(filepath_or_buffer='test.csv')

In [None]:
# First rows of the test set.
test.head()

In [None]:
# Same lower-casing as applied to the training text.
test['text'] = test['text'].str.lower()

In [None]:
# Same cleaning + abbreviation expansion as the training text.
test['text'] = test['text'].apply(clean_data).apply(convert_abbrev_in_text)

In [None]:
import collections
from collections import defaultdict

# Same checkpoint as the tokenizer inside MyDataset; used for row-by-row test inference below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Fresh model instance; its weights are replaced by the saved best checkpoint in the next cell.
model = MyBertModel()

In [None]:
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load('my_model', map_location=device))

In [None]:
# Move the model to the inference device and switch to eval mode (disables dropout).
model = model.to(device).eval()
model

In [None]:
# Predict one test tweet at a time and collect (id, 0/1 target) pairs for the submission.
# The tokenizer output is named `encoding`, not `data`: the old name overwrote the
# `data` alias of torch.utils.data that MyDataset subclasses, breaking any re-run.
l = defaultdict(list)

with torch.no_grad():
    for idx in test.index:
        encoding = tokenizer.encode_plus(
            str(test.loc[idx, 'text']),
            max_length=256,  # keep in sync with the max_len used for training
            add_special_tokens=True,
            return_attention_mask=True,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        logit = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device),
        )

        l['id'].append(test.loc[idx, 'id'])
        l['target'].append(int(torch.sigmoid(logit).item() >= 0.5))
    

In [None]:
# One row per test tweet with the columns 'id' and 'target' collected above.
df_sub = pd.DataFrame(l)

In [None]:
# Write the submission file without the DataFrame index column.
df_sub.to_csv('submission.csv',index=False)

In [None]:
# Sanity-check the written file without rendering every row.
pd.read_csv('submission.csv').head()

In [None]:
# Free cached GPU memory held by PyTorch (does not affect any result above).
torch.cuda.empty_cache()


## Current score: 0.82500 (F1)

## Future work
1. Add a RoBERTa model.
2. Evaluate the model with different decision thresholds (currently fixed at 0.5).