In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.metrics import  f1_score
from charformer_pytorch import GBST

In [None]:
!pip install charformer-pytorch

In [2]:
import random
import numpy as np
import torch 
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [3]:
def eval_conversation(text):
    input_ids = torch.tensor([list(text.encode("utf-8"))]) + 3 # add 3 for special tokens
    return input_ids


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class MetNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dim = 512
        self.tokenizer = GBST(
            num_tokens = 257,             # number of tokens, should be 256 for byte encoding (+ 1 special token for padding in this example)
            dim = 512,                    # dimension of token and intra-block positional embedding
            max_block_size = 4,           # maximum block size
            downsample_factor = 4,        # the final downsample factor by which the sequence length will decrease by
            score_consensus_attn = True   # whether to do the cheap score consensus (aka attention) as in eq. 5 in the paper
        ).to(device)
        self.positionEmbeddings = nn.Embedding(512 ,64)
        self.transformerLayer = nn.TransformerEncoderLayer(576,8) 
        self.linear1 = nn.Linear(576,  64)
        self.linear2 = nn.Linear(64,  1)
        self.linear3 = nn.Linear(512,  16)
        self.linear4 = nn.Linear(16,  1)

    def forward(self, x):
        x = x.type(torch.LongTensor).to(device)
        tokenize = self.tokenizer(x)
        positions = (torch.arange(0,self.dim ).reshape(1,self.dim ) + torch.zeros(x.shape[0],self.dim )).to(device)
        sentence = torch.cat((tokenize[0],self.positionEmbeddings(positions.long())),axis=2)
        attended = self.transformerLayer(sentence)
        linear1 = F.relu(self.linear1(attended))
        linear2 = F.relu(self.linear2(linear1))
        linear2 = linear2.view(-1,512) # reshaping the layer as the transformer outputs a 2d tensor (or 3d considering the batch size)
        linear3 = F.relu(self.linear3(linear2))
        out = torch.sigmoid(self.linear4(linear3))
        return out
    

In [6]:
model = MetNet().to(device)

In [8]:
path_to_json_file = '/scratch/sm9669/sarcasm_detection_shared_task_twitter_training.jsonl'
import json 

with open(path_to_json_file, 'r') as j:
     json_data = [json.loads(line) for line in j]

In [9]:
from tqdm import tqdm
accurate = 0
tot_tweets = 0

tweets=[]
labels=[]
for tweet in tqdm(json_data):
    tot_cont = ""
    for con in tweet['context']:
        tot_cont+=con 
    tot_cont+=tweet['response']
    tweets.append(tot_cont)
    labels.append(tweet['label'])

100%|██████████| 5000/5000 [00:00<00:00, 436025.53it/s]


In [10]:
import pandas as pd
label_list=[]
for i in range(0,len(labels)):
    if labels[i]=="SARCASM":
        label_list.append(1)
    else:
        label_list.append(0)

X_train =tweets
y_train = label_list

In [11]:
train_data=[]
mi=10000000000000000
ma=-1
for i in range(0,len(X_train)):
    train_data.append(eval_conversation(X_train[i]));
    mi=min(train_data[i].shape[1],mi)
    ma=max(train_data[i].shape[1],ma)

new_train_data=[]
for i in range(0,len(train_data)):
    if train_data[i].shape[1]<2048:
        z = torch.zeros(1, 2048-train_data[i].shape[1])
        s = torch.cat((train_data[i],z),1)
        new_train_data.append(s)
    else:
        cur_train_data = torch.reshape(train_data[i][0][:2048], (1, 2048)).type(torch.FloatTensor)
        new_train_data.append(torch.Tensor(cur_train_data))

new_train_data = torch.stack(new_train_data)
print(new_train_data.shape)

y_train=torch.Tensor(y_train)
print(y_train.shape)
new_train_data = torch.squeeze(new_train_data)


torch.Size([5000, 1, 2048])
torch.Size([5000])


In [12]:
import torch.optim as  optim

In [13]:
import torch.utils.data as data_utils
train=data_utils.TensorDataset(new_train_data ,y_train)


train_loader = torch.utils.data.DataLoader(train, batch_size=32,
               shuffle=True, num_workers=0, pin_memory=True)

In [14]:
path_to_json_file = '/scratch/sm9669/sarcasm_detection_shared_task_twitter_testing.jsonl'
import json 

with open(path_to_json_file, 'r') as j:
     json_data = [json.loads(line) for line in j]

In [15]:
from tqdm import tqdm

tweets=[]
labels=[]
for tweet in tqdm(json_data):
    tot_cont = ""
    for con in tweet['context']:
        tot_cont+=con 
    tot_cont+=tweet['response']
    tweets.append(tot_cont)
    labels.append(tweet['label'])

100%|██████████| 1800/1800 [00:00<00:00, 421867.86it/s]


In [16]:
import pandas as pd
test_label_list=[]
for i in range(0,len(labels)):
    if labels[i]=="SARCASM":
        test_label_list.append(1)
    else:
        test_label_list.append(0)

X_test =tweets
y_test = test_label_list

In [17]:
test_data=[]
mi=10000000000000000
ma=-1
for i in range(0,len(X_test)):
    test_data.append(eval_conversation(X_test[i]));
    mi=min(test_data[i].shape[1],mi)
    ma=max(test_data[i].shape[1],ma)

print(mi, ma)
new_test_data=[]
for i in range(0,len(test_data)):
    if test_data[i].shape[1]<2048:
        z = torch.zeros(1, 2048-test_data[i].shape[1])
        s = torch.cat((test_data[i],z),1)
        new_test_data.append(s)
    else:
        cur_test_data = torch.reshape(test_data[i][0][:2048], (1, 2048)).type(torch.FloatTensor)
        new_test_data.append(torch.Tensor(cur_test_data))

new_test_data = torch.stack(new_test_data)
print(new_test_data.shape)

y_test=torch.Tensor(y_test)
print(y_test.shape)
new_test_data = torch.squeeze(new_test_data)


117 4140
torch.Size([1800, 1, 2048])
torch.Size([1800])


In [18]:
import torch.utils.data as data_utils
test=data_utils.TensorDataset(new_test_data ,y_test)


test_loader = torch.utils.data.DataLoader(test, batch_size=32,
               shuffle=True, num_workers=0, pin_memory=True)

In [19]:
lr =  0.0045
optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=0.9,
                            weight_decay=1e-4)

In [20]:
def calculateMetrics(ypred,ytrue):
    acc  = accuracy_score(ytrue,ypred)
    f1  = f1_score(ytrue,ypred)
    f1_average  = f1_score(ytrue,ypred,average="macro")
    return " f1 score: "+str(round(f1,3))+" f1 average: "+str(round(f1_average,3))+" accuracy: "+str(round(acc,3))

In [21]:
def train(train_loader, model, optimizer, epoch):

    model.train()
    trainpreds = torch.tensor([])
    traintrues = torch.tensor([])
    testpreds = torch.tensor([])
    testtrues = torch.tensor([])
    for i, (input, target) in enumerate(train_loader):
        input_var = input.cuda()
        target_var = target.cuda()
        optimizer.zero_grad()
        output = model(input_var)
        output = torch.squeeze(output, 1)
        trainpreds = torch.cat((trainpreds,output.cpu().detach()))
        traintrues = torch.cat((traintrues,target_var.cpu().detach()))
        
        loss = F.binary_cross_entropy(output,target_var)
        loss.backward()
        optimizer.step()
        
    err = F.binary_cross_entropy(trainpreds,traintrues)    
    print("train BCE loss: ",err.item(),calculateMetrics(torch.round(trainpreds).numpy(),traintrues.numpy()))
    
    model.eval()
    for i, (input, target) in enumerate(test_loader):
        input_var = input.cuda()
        target_var = target.cuda()
        output = model(input_var)
        output = torch.squeeze(output, 1)

        testpreds = torch.cat((testpreds,output.cpu().detach()))
        testtrues = torch.cat((testtrues,target_var.cpu().detach()))

        loss = F.binary_cross_entropy(output,target_var)

    print("test BCE loss: ",calculateMetrics(torch.round(testpreds).numpy(),testtrues.numpy()))


In [22]:
from tqdm import tqdm
for epoch in tqdm(range(0, 30)):
    lr = lr * (0.1 ** (epoch // 30))
    train(train_loader, model, optimizer, epoch)

  0%|          | 0/30 [00:00<?, ?it/s]

train BCE loss:  0.6935937404632568  f1 score: 0.373 f1 average: 0.485 accuracy: 0.509


  3%|▎         | 1/30 [00:21<10:22, 21.48s/it]

test BCE loss:   f1 score: 0.667 f1 average: 0.333 accuracy: 0.5
train BCE loss:  0.6913658976554871  f1 score: 0.446 f1 average: 0.518 accuracy: 0.529


  7%|▋         | 2/30 [00:42<10:01, 21.49s/it]

test BCE loss:   f1 score: 0.54 f1 average: 0.501 accuracy: 0.504
train BCE loss:  0.6866149306297302  f1 score: 0.536 f1 average: 0.552 accuracy: 0.552


 10%|█         | 3/30 [01:04<09:40, 21.51s/it]

test BCE loss:   f1 score: 0.41 f1 average: 0.491 accuracy: 0.503
train BCE loss:  0.6839903593063354  f1 score: 0.52 f1 average: 0.56 accuracy: 0.564


 13%|█▎        | 4/30 [01:26<09:20, 21.55s/it]

test BCE loss:   f1 score: 0.414 f1 average: 0.492 accuracy: 0.504
train BCE loss:  0.6842751502990723  f1 score: 0.513 f1 average: 0.559 accuracy: 0.564


 17%|█▋        | 5/30 [01:47<08:59, 21.59s/it]

test BCE loss:   f1 score: 0.587 f1 average: 0.504 accuracy: 0.518
train BCE loss:  0.6843687891960144  f1 score: 0.572 f1 average: 0.559 accuracy: 0.559


 20%|██        | 6/30 [02:09<08:39, 21.63s/it]

test BCE loss:   f1 score: 0.398 f1 average: 0.486 accuracy: 0.502
train BCE loss:  0.6827521920204163  f1 score: 0.523 f1 average: 0.56 accuracy: 0.563


 23%|██▎       | 7/30 [02:31<08:18, 21.66s/it]

test BCE loss:   f1 score: 0.444 f1 average: 0.5 accuracy: 0.507
train BCE loss:  0.6825318336486816  f1 score: 0.523 f1 average: 0.563 accuracy: 0.566


 27%|██▋       | 8/30 [02:52<07:56, 21.67s/it]

test BCE loss:   f1 score: 0.528 f1 average: 0.531 accuracy: 0.531
train BCE loss:  0.6802144646644592  f1 score: 0.548 f1 average: 0.568 accuracy: 0.569


 30%|███       | 9/30 [03:14<07:35, 21.69s/it]

test BCE loss:   f1 score: 0.315 f1 average: 0.459 accuracy: 0.498
train BCE loss:  0.6780626773834229  f1 score: 0.541 f1 average: 0.57 accuracy: 0.572


 33%|███▎      | 10/30 [03:36<07:13, 21.70s/it]

test BCE loss:   f1 score: 0.394 f1 average: 0.492 accuracy: 0.511
train BCE loss:  0.6728733777999878  f1 score: 0.581 f1 average: 0.585 accuracy: 0.585


 37%|███▋      | 11/30 [03:58<06:52, 21.70s/it]

test BCE loss:   f1 score: 0.302 f1 average: 0.462 accuracy: 0.509
train BCE loss:  0.6709381937980652  f1 score: 0.574 f1 average: 0.586 accuracy: 0.586


 40%|████      | 12/30 [04:19<06:30, 21.70s/it]

test BCE loss:   f1 score: 0.648 f1 average: 0.442 accuracy: 0.518
train BCE loss:  0.6674875020980835  f1 score: 0.605 f1 average: 0.601 accuracy: 0.601


 43%|████▎     | 13/30 [04:41<06:08, 21.70s/it]

test BCE loss:   f1 score: 0.62 f1 average: 0.521 accuracy: 0.542
train BCE loss:  0.6512551307678223  f1 score: 0.625 f1 average: 0.62 accuracy: 0.62


 47%|████▋     | 14/30 [05:03<05:47, 21.70s/it]

test BCE loss:   f1 score: 0.487 f1 average: 0.529 accuracy: 0.532
train BCE loss:  0.6414102911949158  f1 score: 0.636 f1 average: 0.635 accuracy: 0.635


 50%|█████     | 15/30 [05:24<05:25, 21.71s/it]

test BCE loss:   f1 score: 0.533 f1 average: 0.549 accuracy: 0.55
train BCE loss:  0.642330527305603  f1 score: 0.642 f1 average: 0.634 accuracy: 0.635


 53%|█████▎    | 16/30 [05:46<05:03, 21.70s/it]

test BCE loss:   f1 score: 0.478 f1 average: 0.53 accuracy: 0.536
train BCE loss:  0.615013837814331  f1 score: 0.671 f1 average: 0.667 accuracy: 0.667


 57%|█████▋    | 17/30 [06:08<04:42, 21.70s/it]

test BCE loss:   f1 score: 0.488 f1 average: 0.529 accuracy: 0.533
train BCE loss:  0.5900360345840454  f1 score: 0.696 f1 average: 0.689 accuracy: 0.689


 60%|██████    | 18/30 [06:29<04:20, 21.70s/it]

test BCE loss:   f1 score: 0.33 f1 average: 0.483 accuracy: 0.528
train BCE loss:  0.5882797837257385  f1 score: 0.689 f1 average: 0.683 accuracy: 0.684


 63%|██████▎   | 19/30 [06:51<03:58, 21.70s/it]

test BCE loss:   f1 score: 0.578 f1 average: 0.569 accuracy: 0.569
train BCE loss:  0.5555251836776733  f1 score: 0.718 f1 average: 0.713 accuracy: 0.713


 67%|██████▋   | 20/30 [07:13<03:36, 21.70s/it]

test BCE loss:   f1 score: 0.607 f1 average: 0.57 accuracy: 0.573
train BCE loss:  0.5361648201942444  f1 score: 0.735 f1 average: 0.732 accuracy: 0.732


 70%|███████   | 21/30 [07:35<03:15, 21.71s/it]

test BCE loss:   f1 score: 0.568 f1 average: 0.559 accuracy: 0.559
train BCE loss:  0.5149388909339905  f1 score: 0.743 f1 average: 0.739 accuracy: 0.739


 73%|███████▎  | 22/30 [07:56<02:53, 21.70s/it]

test BCE loss:   f1 score: 0.663 f1 average: 0.461 accuracy: 0.537
train BCE loss:  0.5085058808326721  f1 score: 0.746 f1 average: 0.742 accuracy: 0.742


 77%|███████▋  | 23/30 [08:18<02:31, 21.70s/it]

test BCE loss:   f1 score: 0.646 f1 average: 0.531 accuracy: 0.559
train BCE loss:  0.48672983050346375  f1 score: 0.766 f1 average: 0.761 accuracy: 0.761


 80%|████████  | 24/30 [08:40<02:10, 21.70s/it]

test BCE loss:   f1 score: 0.578 f1 average: 0.565 accuracy: 0.566
train BCE loss:  0.4662865102291107  f1 score: 0.778 f1 average: 0.772 accuracy: 0.772


 83%|████████▎ | 25/30 [09:01<01:48, 21.69s/it]

test BCE loss:   f1 score: 0.562 f1 average: 0.57 accuracy: 0.571
train BCE loss:  0.4398786723613739  f1 score: 0.793 f1 average: 0.791 accuracy: 0.791


 87%|████████▋ | 26/30 [09:23<01:26, 21.70s/it]

test BCE loss:   f1 score: 0.623 f1 average: 0.537 accuracy: 0.553
train BCE loss:  0.39745813608169556  f1 score: 0.822 f1 average: 0.82 accuracy: 0.82


 90%|█████████ | 27/30 [09:45<01:05, 21.69s/it]

test BCE loss:   f1 score: 0.632 f1 average: 0.556 accuracy: 0.569
train BCE loss:  0.4033062756061554  f1 score: 0.813 f1 average: 0.811 accuracy: 0.811


 93%|█████████▎| 28/30 [10:06<00:43, 21.70s/it]

test BCE loss:   f1 score: 0.539 f1 average: 0.564 accuracy: 0.566
train BCE loss:  0.3838340640068054  f1 score: 0.83 f1 average: 0.828 accuracy: 0.828


 97%|█████████▋| 29/30 [10:28<00:21, 21.69s/it]

test BCE loss:   f1 score: 0.589 f1 average: 0.573 accuracy: 0.574
train BCE loss:  0.3666486442089081  f1 score: 0.836 f1 average: 0.835 accuracy: 0.835


100%|██████████| 30/30 [10:50<00:00, 21.68s/it]

test BCE loss:   f1 score: 0.52 f1 average: 0.564 accuracy: 0.568





In [23]:
torch.save(model.state_dict(), "charformer-twitter-30.pt")