In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.metrics import  f1_score
from charformer_pytorch import GBST

torch.Size([1, 232, 512])

In [None]:
!pip install charformer-pytorch

In [2]:
import random
import numpy as np
import torch 
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [3]:
def eval_conversation(text):
    input_ids = torch.tensor([list(text.encode("utf-8"))]) + 3 # add 3 for special tokens
    return input_ids


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class MetNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dim = 232
        self.tokenizer = GBST(
            num_tokens = 257,             # number of tokens, should be 256 for byte encoding (+ 1 special token for padding in this example)
            dim = 512,                    # dimension of token and intra-block positional embedding
            max_block_size = 4,           # maximum block size
            downsample_factor = 4,        # the final downsample factor by which the sequence length will decrease by
            score_consensus_attn = True   # whether to do the cheap score consensus (aka attention) as in eq. 5 in the paper
        ).to(device)
        self.positionEmbeddings = nn.Embedding(512 ,64)
        self.transformerLayer = nn.TransformerEncoderLayer(576,8) 
        self.linear1 = nn.Linear(576,  64)
        self.linear2 = nn.Linear(64,  1)
        self.linear3 = nn.Linear(232,  16)
        self.linear4 = nn.Linear(16,  1)

    def forward(self, x):
        x = x.type(torch.LongTensor).to(device)
        tokenize = self.tokenizer(x)
        positions = (torch.arange(0,self.dim ).reshape(1,self.dim ) + torch.zeros(x.shape[0],self.dim )).to(device)
        sentence = torch.cat((tokenize[0],self.positionEmbeddings(positions.long())),axis=2)
        attended = self.transformerLayer(sentence)
        linear1 = F.relu(self.linear1(attended))
        linear2 = F.relu(self.linear2(linear1))
        linear2 = linear2.view(-1,232) # reshaping the layer as the transformer outputs a 2d tensor (or 3d considering the batch size)
        linear3 = F.relu(self.linear3(linear2))
        out = torch.sigmoid(self.linear4(linear3))
        return out
    

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MetNet().to(device)
tokens = torch.randint(0, 257, (1, 927))# uneven number of tokens (1023)
print(tokens.shape)
print(model(tokens.to(device)).shape)

torch.Size([1, 927])
torch.Size([1, 1])


In [7]:
import numpy as np
import pandas as pd
train=pd.read_json('Sarcasm_Headlines_Dataset_v2.json',lines=True)
train.head()

Unnamed: 0,is_sarcastic,headline,article_link
0,1,thirtysomething scientists unveil doomsday clo...,https://www.theonion.com/thirtysomething-scien...
1,0,dem rep. totally nails why congress is falling...,https://www.huffingtonpost.com/entry/donna-edw...
2,0,eat your veggies: 9 deliciously different recipes,https://www.huffingtonpost.com/entry/eat-your-...
3,1,inclement weather prevents liar from getting t...,https://local.theonion.com/inclement-weather-p...
4,1,mother comes pretty close to using word 'strea...,https://www.theonion.com/mother-comes-pretty-c...


In [8]:
train=train.drop(['article_link'],axis=1)
train.head()

Unnamed: 0,is_sarcastic,headline
0,1,thirtysomething scientists unveil doomsday clo...
1,0,dem rep. totally nails why congress is falling...
2,0,eat your veggies: 9 deliciously different recipes
3,1,inclement weather prevents liar from getting t...
4,1,mother comes pretty close to using word 'strea...


In [9]:
import random

train = train.sample(frac = 1)
train_data = train[:17171]
test_data = train[17171:22895]
val_data = train[22895:]



In [10]:
X_train=list(train_data['headline'])

y_train=list(train_data['is_sarcastic'])
len(X_train)

17171

In [11]:
X_val=list(val_data['headline'])

y_val=list(val_data['is_sarcastic'])
len(X_val)

5724

In [12]:
X_test=list(test_data['headline'])

y_test=list(test_data['is_sarcastic'])


In [13]:
train_data=[]
mi=10000000000000000
ma=-1
for i in range(0,len(X_train)):
    train_data.append(eval_conversation(X_train[i]));
    mi=min(train_data[i].shape[1],mi)
    ma=max(train_data[i].shape[1],ma)

new_train_data=[]

for i in range(0,len(train_data)):
    z = torch.zeros(1, 926-train_data[i].shape[1])
    s = torch.cat((train_data[i],z),1)
    new_train_data.append(s)

new_train_data = torch.stack(new_train_data)

y_train=torch.Tensor(y_train)

new_train_data = torch.squeeze(new_train_data)
print(new_train_data.shape)


torch.Size([17171, 926])


In [14]:
val_data=[]
mi=10000000000000000
ma=-1
for i in range(0,len(X_val)):
    val_data.append(eval_conversation(X_val[i]));
    mi=min(val_data[i].shape[1],mi)
    ma=max(val_data[i].shape[1],ma)

new_val_data=[]

for i in range(0,len(val_data)):
    z = torch.zeros(1, 926-val_data[i].shape[1])
    s = torch.cat((val_data[i],z),1)
    new_val_data.append(s)

new_val_data = torch.stack(new_val_data)

y_val=torch.Tensor(y_val)

new_val_data = torch.squeeze(new_val_data)
print(new_val_data.shape)

torch.Size([5724, 926])


In [15]:
import torch.optim as  optim

In [16]:
import torch.utils.data as data_utils
train=data_utils.TensorDataset(new_train_data ,y_train)


train_loader = torch.utils.data.DataLoader(train, batch_size=32,
               shuffle=True, num_workers=0, pin_memory=True)

In [17]:
import torch.utils.data as data_utils
val=data_utils.TensorDataset(new_val_data ,y_val)


val_loader = torch.utils.data.DataLoader(val, batch_size=32,
               shuffle=True, num_workers=0, pin_memory=True)

In [18]:
lr =  0.0045
optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=0.9,
                            weight_decay=1e-4)

In [19]:
def calculateMetrics(ypred,ytrue):
    acc  = accuracy_score(ytrue,ypred)
    f1  = f1_score(ytrue,ypred)
    f1_average  = f1_score(ytrue,ypred,average="macro")
    return " f1 score: "+str(round(f1,3))+" f1 average: "+str(round(f1_average,3))+" accuracy: "+str(round(acc,3))

In [20]:
def train(train_loader, model, optimizer, epoch):

    model.train()
    trainpreds = torch.tensor([])
    traintrues = torch.tensor([])
    valpreds = torch.tensor([])
    valtrues = torch.tensor([])
    for i, (input, target) in enumerate(train_loader):
        input_var = input.cuda()
        target_var = target.cuda()
        optimizer.zero_grad()
        output = model(input_var)
        output = torch.squeeze(output, 1)
        trainpreds = torch.cat((trainpreds,output.cpu().detach()))
        traintrues = torch.cat((traintrues,target_var.cpu().detach()))
        
        loss = F.binary_cross_entropy(output,target_var)
        loss.backward()
        optimizer.step()
        
    err = F.binary_cross_entropy(trainpreds,traintrues)    
    print("train BCE loss: ",err.item(),calculateMetrics(torch.round(trainpreds).numpy(),traintrues.numpy()))
    
    model.eval()
    for i, (input, target) in enumerate(val_loader):
        input_var = input.cuda()
        target_var = target.cuda()
        optimizer.zero_grad()
        output = model(input_var)
        output = torch.squeeze(output, 1)

        valpreds = torch.cat((valpreds,output.cpu().detach()))
        valtrues = torch.cat((valtrues,target_var.cpu().detach()))

        loss = F.binary_cross_entropy(output,target_var)

    print("test BCE loss: ",calculateMetrics(torch.round(valpreds).numpy(),valtrues.numpy()))



In [21]:
from tqdm import tqdm
for epoch in tqdm(range(0, 100)):
    lr = lr * (0.1 ** (epoch // 100))
    train(train_loader, model, optimizer, epoch)

  0%|          | 0/100 [00:00<?, ?it/s]

train BCE loss:  0.6919268369674683  f1 score: 0.0 f1 average: 0.345 accuracy: 0.526


  1%|          | 1/100 [00:30<50:46, 30.77s/it]

test BCE loss:   f1 score: 0.0 f1 average: 0.343 accuracy: 0.521
train BCE loss:  0.6914328336715698  f1 score: 0.0 f1 average: 0.345 accuracy: 0.526


  2%|▏         | 2/100 [01:01<50:15, 30.77s/it]

test BCE loss:   f1 score: 0.0 f1 average: 0.343 accuracy: 0.521
train BCE loss:  0.6865565776824951  f1 score: 0.271 f1 average: 0.472 accuracy: 0.549


  3%|▎         | 3/100 [01:32<49:43, 30.76s/it]

test BCE loss:   f1 score: 0.407 f1 average: 0.542 accuracy: 0.582
train BCE loss:  0.6545202136039734  f1 score: 0.453 f1 average: 0.567 accuracy: 0.597


  4%|▍         | 4/100 [02:03<49:12, 30.76s/it]

test BCE loss:   f1 score: 0.579 f1 average: 0.63 accuracy: 0.636
train BCE loss:  0.6082344651222229  f1 score: 0.631 f1 average: 0.639 accuracy: 0.639


  5%|▌         | 5/100 [02:33<48:43, 30.77s/it]

test BCE loss:   f1 score: 0.679 f1 average: 0.646 accuracy: 0.649
train BCE loss:  0.5897772908210754  f1 score: 0.656 f1 average: 0.658 accuracy: 0.658


  6%|▌         | 6/100 [03:04<48:14, 30.79s/it]

test BCE loss:   f1 score: 0.613 f1 average: 0.656 accuracy: 0.662
train BCE loss:  0.5722116231918335  f1 score: 0.676 f1 average: 0.677 accuracy: 0.677


  7%|▋         | 7/100 [03:35<47:45, 30.82s/it]

test BCE loss:   f1 score: 0.51 f1 average: 0.613 accuracy: 0.64
train BCE loss:  0.5609079599380493  f1 score: 0.691 f1 average: 0.691 accuracy: 0.691


  8%|▊         | 8/100 [04:06<47:17, 30.84s/it]

test BCE loss:   f1 score: 0.714 f1 average: 0.695 accuracy: 0.697
train BCE loss:  0.547664225101471  f1 score: 0.704 f1 average: 0.706 accuracy: 0.706


  9%|▉         | 9/100 [04:37<46:47, 30.85s/it]

test BCE loss:   f1 score: 0.656 f1 average: 0.688 accuracy: 0.692
train BCE loss:  0.5404995679855347  f1 score: 0.709 f1 average: 0.712 accuracy: 0.712


 10%|█         | 10/100 [05:08<46:18, 30.87s/it]

test BCE loss:   f1 score: 0.731 f1 average: 0.701 accuracy: 0.704
train BCE loss:  0.5218433737754822  f1 score: 0.724 f1 average: 0.729 accuracy: 0.729


 11%|█         | 11/100 [05:39<45:48, 30.88s/it]

test BCE loss:   f1 score: 0.716 f1 average: 0.72 accuracy: 0.72
train BCE loss:  0.5174424052238464  f1 score: 0.726 f1 average: 0.732 accuracy: 0.732


 12%|█▏        | 12/100 [06:10<45:18, 30.89s/it]

test BCE loss:   f1 score: 0.707 f1 average: 0.715 accuracy: 0.716
train BCE loss:  0.5051499009132385  f1 score: 0.736 f1 average: 0.743 accuracy: 0.743


 13%|█▎        | 13/100 [06:40<44:47, 30.89s/it]

test BCE loss:   f1 score: 0.74 f1 average: 0.727 accuracy: 0.728
train BCE loss:  0.4927729368209839  f1 score: 0.743 f1 average: 0.75 accuracy: 0.751


 14%|█▍        | 14/100 [07:11<44:17, 30.90s/it]

test BCE loss:   f1 score: 0.737 f1 average: 0.732 accuracy: 0.732
train BCE loss:  0.48511695861816406  f1 score: 0.749 f1 average: 0.756 accuracy: 0.757


 15%|█▌        | 15/100 [07:42<43:47, 30.91s/it]

test BCE loss:   f1 score: 0.729 f1 average: 0.738 accuracy: 0.738
train BCE loss:  0.47499415278434753  f1 score: 0.757 f1 average: 0.766 accuracy: 0.766


 16%|█▌        | 16/100 [08:13<43:16, 30.91s/it]

test BCE loss:   f1 score: 0.743 f1 average: 0.734 accuracy: 0.735
train BCE loss:  0.46832993626594543  f1 score: 0.767 f1 average: 0.773 accuracy: 0.773


 17%|█▋        | 17/100 [08:44<42:45, 30.91s/it]

test BCE loss:   f1 score: 0.757 f1 average: 0.74 accuracy: 0.741
train BCE loss:  0.4577774107456207  f1 score: 0.772 f1 average: 0.779 accuracy: 0.78


 18%|█▊        | 18/100 [09:15<42:15, 30.92s/it]

test BCE loss:   f1 score: 0.741 f1 average: 0.749 accuracy: 0.749
train BCE loss:  0.45113661885261536  f1 score: 0.774 f1 average: 0.781 accuracy: 0.781


 19%|█▉        | 19/100 [09:46<41:43, 30.91s/it]

test BCE loss:   f1 score: 0.728 f1 average: 0.746 accuracy: 0.747
train BCE loss:  0.43930426239967346  f1 score: 0.783 f1 average: 0.791 accuracy: 0.791


 20%|██        | 20/100 [10:17<41:12, 30.91s/it]

test BCE loss:   f1 score: 0.6 f1 average: 0.677 accuracy: 0.695
train BCE loss:  0.43249499797821045  f1 score: 0.787 f1 average: 0.794 accuracy: 0.794


 21%|██        | 21/100 [10:48<40:41, 30.91s/it]

test BCE loss:   f1 score: 0.748 f1 average: 0.748 accuracy: 0.748
train BCE loss:  0.4239487648010254  f1 score: 0.792 f1 average: 0.799 accuracy: 0.799


 22%|██▏       | 22/100 [11:19<40:11, 30.91s/it]

test BCE loss:   f1 score: 0.744 f1 average: 0.747 accuracy: 0.747
train BCE loss:  0.4178540110588074  f1 score: 0.797 f1 average: 0.804 accuracy: 0.804


 23%|██▎       | 23/100 [11:50<39:39, 30.91s/it]

test BCE loss:   f1 score: 0.747 f1 average: 0.764 accuracy: 0.765
train BCE loss:  0.40949589014053345  f1 score: 0.801 f1 average: 0.808 accuracy: 0.808


 24%|██▍       | 24/100 [12:20<39:08, 30.91s/it]

test BCE loss:   f1 score: 0.729 f1 average: 0.751 accuracy: 0.752
train BCE loss:  0.40144971013069153  f1 score: 0.802 f1 average: 0.809 accuracy: 0.809


 25%|██▌       | 25/100 [12:51<38:38, 30.91s/it]

test BCE loss:   f1 score: 0.733 f1 average: 0.754 accuracy: 0.756
train BCE loss:  0.3925043046474457  f1 score: 0.812 f1 average: 0.82 accuracy: 0.82


 26%|██▌       | 26/100 [13:22<38:07, 30.91s/it]

test BCE loss:   f1 score: 0.728 f1 average: 0.754 accuracy: 0.756
train BCE loss:  0.3796336054801941  f1 score: 0.818 f1 average: 0.824 accuracy: 0.825


 27%|██▋       | 27/100 [13:53<37:37, 30.92s/it]

test BCE loss:   f1 score: 0.758 f1 average: 0.768 accuracy: 0.768
train BCE loss:  0.37502217292785645  f1 score: 0.825 f1 average: 0.832 accuracy: 0.832


 28%|██▊       | 28/100 [14:24<37:05, 30.91s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.764 accuracy: 0.764
train BCE loss:  0.36976075172424316  f1 score: 0.827 f1 average: 0.834 accuracy: 0.834


 29%|██▉       | 29/100 [14:55<36:35, 30.93s/it]

test BCE loss:   f1 score: 0.728 f1 average: 0.751 accuracy: 0.753
train BCE loss:  0.3581070899963379  f1 score: 0.834 f1 average: 0.841 accuracy: 0.841


 30%|███       | 30/100 [15:26<36:04, 30.92s/it]

test BCE loss:   f1 score: 0.755 f1 average: 0.768 accuracy: 0.769
train BCE loss:  0.3457251787185669  f1 score: 0.842 f1 average: 0.848 accuracy: 0.848


 31%|███       | 31/100 [15:57<35:33, 30.92s/it]

test BCE loss:   f1 score: 0.763 f1 average: 0.772 accuracy: 0.772
train BCE loss:  0.33940955996513367  f1 score: 0.841 f1 average: 0.848 accuracy: 0.848


 32%|███▏      | 32/100 [16:28<35:02, 30.92s/it]

test BCE loss:   f1 score: 0.777 f1 average: 0.774 accuracy: 0.774
train BCE loss:  0.3268377184867859  f1 score: 0.85 f1 average: 0.857 accuracy: 0.857


 33%|███▎      | 33/100 [16:59<34:32, 30.93s/it]

test BCE loss:   f1 score: 0.755 f1 average: 0.757 accuracy: 0.757
train BCE loss:  0.3198870122432709  f1 score: 0.855 f1 average: 0.861 accuracy: 0.861


 34%|███▍      | 34/100 [17:30<34:00, 30.91s/it]

test BCE loss:   f1 score: 0.753 f1 average: 0.768 accuracy: 0.769
train BCE loss:  0.30223312973976135  f1 score: 0.864 f1 average: 0.87 accuracy: 0.871


 35%|███▌      | 35/100 [18:01<33:29, 30.91s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.768 accuracy: 0.768
train BCE loss:  0.29709017276763916  f1 score: 0.867 f1 average: 0.873 accuracy: 0.874


 36%|███▌      | 36/100 [18:31<32:57, 30.90s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.772 accuracy: 0.773
train BCE loss:  0.28570619225502014  f1 score: 0.872 f1 average: 0.878 accuracy: 0.878


 37%|███▋      | 37/100 [19:02<32:27, 30.91s/it]

test BCE loss:   f1 score: 0.736 f1 average: 0.763 accuracy: 0.766
train BCE loss:  0.27068382501602173  f1 score: 0.88 f1 average: 0.886 accuracy: 0.886


 38%|███▊      | 38/100 [19:33<31:56, 30.90s/it]

test BCE loss:   f1 score: 0.751 f1 average: 0.766 accuracy: 0.767
train BCE loss:  0.2634357213973999  f1 score: 0.882 f1 average: 0.888 accuracy: 0.888


 39%|███▉      | 39/100 [20:04<31:24, 30.89s/it]

test BCE loss:   f1 score: 0.767 f1 average: 0.767 accuracy: 0.767
train BCE loss:  0.2552218437194824  f1 score: 0.887 f1 average: 0.892 accuracy: 0.892


 40%|████      | 40/100 [20:35<30:53, 30.89s/it]

test BCE loss:   f1 score: 0.715 f1 average: 0.746 accuracy: 0.75
train BCE loss:  0.24154654145240784  f1 score: 0.892 f1 average: 0.897 accuracy: 0.897


 41%|████      | 41/100 [21:06<30:23, 30.90s/it]

test BCE loss:   f1 score: 0.746 f1 average: 0.764 accuracy: 0.766
train BCE loss:  0.22883950173854828  f1 score: 0.9 f1 average: 0.905 accuracy: 0.905


 42%|████▏     | 42/100 [21:37<29:51, 30.90s/it]

test BCE loss:   f1 score: 0.775 f1 average: 0.772 accuracy: 0.772
train BCE loss:  0.21802768111228943  f1 score: 0.906 f1 average: 0.91 accuracy: 0.91


 43%|████▎     | 43/100 [22:08<29:20, 30.89s/it]

test BCE loss:   f1 score: 0.746 f1 average: 0.764 accuracy: 0.765
train BCE loss:  0.21389111876487732  f1 score: 0.907 f1 average: 0.911 accuracy: 0.912


 44%|████▍     | 44/100 [22:39<28:49, 30.89s/it]

test BCE loss:   f1 score: 0.745 f1 average: 0.763 accuracy: 0.764
train BCE loss:  0.1990872621536255  f1 score: 0.916 f1 average: 0.92 accuracy: 0.92


 45%|████▌     | 45/100 [23:09<28:19, 30.90s/it]

test BCE loss:   f1 score: 0.751 f1 average: 0.762 accuracy: 0.763
train BCE loss:  0.1909104436635971  f1 score: 0.916 f1 average: 0.92 accuracy: 0.921


 46%|████▌     | 46/100 [23:40<27:48, 30.90s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.77 accuracy: 0.77
train BCE loss:  0.17605476081371307  f1 score: 0.925 f1 average: 0.929 accuracy: 0.929


 47%|████▋     | 47/100 [24:11<27:17, 30.89s/it]

test BCE loss:   f1 score: 0.756 f1 average: 0.761 accuracy: 0.761
train BCE loss:  0.16544754803180695  f1 score: 0.93 f1 average: 0.933 accuracy: 0.933


 48%|████▊     | 48/100 [24:42<26:46, 30.89s/it]

test BCE loss:   f1 score: 0.768 f1 average: 0.776 accuracy: 0.776
train BCE loss:  0.1605014055967331  f1 score: 0.932 f1 average: 0.935 accuracy: 0.935


 49%|████▉     | 49/100 [25:13<26:16, 30.91s/it]

test BCE loss:   f1 score: 0.752 f1 average: 0.765 accuracy: 0.765
train BCE loss:  0.15906864404678345  f1 score: 0.934 f1 average: 0.938 accuracy: 0.938


 50%|█████     | 50/100 [25:44<25:44, 30.90s/it]

test BCE loss:   f1 score: 0.746 f1 average: 0.767 accuracy: 0.769
train BCE loss:  0.14289015531539917  f1 score: 0.938 f1 average: 0.941 accuracy: 0.942


 51%|█████     | 51/100 [26:15<25:14, 30.90s/it]

test BCE loss:   f1 score: 0.728 f1 average: 0.754 accuracy: 0.756
train BCE loss:  0.13780473172664642  f1 score: 0.943 f1 average: 0.946 accuracy: 0.946


 52%|█████▏    | 52/100 [26:46<24:42, 30.89s/it]

test BCE loss:   f1 score: 0.743 f1 average: 0.766 accuracy: 0.769
train BCE loss:  0.12917183339595795  f1 score: 0.944 f1 average: 0.947 accuracy: 0.947


 53%|█████▎    | 53/100 [27:17<24:11, 30.89s/it]

test BCE loss:   f1 score: 0.766 f1 average: 0.775 accuracy: 0.775
train BCE loss:  0.11477455496788025  f1 score: 0.953 f1 average: 0.955 accuracy: 0.955


 54%|█████▍    | 54/100 [27:48<23:40, 30.89s/it]

test BCE loss:   f1 score: 0.738 f1 average: 0.757 accuracy: 0.759
train BCE loss:  0.12148213386535645  f1 score: 0.949 f1 average: 0.952 accuracy: 0.952


 55%|█████▌    | 55/100 [28:18<23:09, 30.89s/it]

test BCE loss:   f1 score: 0.765 f1 average: 0.769 accuracy: 0.77
train BCE loss:  0.10150983929634094  f1 score: 0.959 f1 average: 0.961 accuracy: 0.961


 56%|█████▌    | 56/100 [28:50<22:42, 30.96s/it]

test BCE loss:   f1 score: 0.766 f1 average: 0.769 accuracy: 0.769
train BCE loss:  0.10728916525840759  f1 score: 0.957 f1 average: 0.959 accuracy: 0.959


 57%|█████▋    | 57/100 [29:20<22:10, 30.94s/it]

test BCE loss:   f1 score: 0.768 f1 average: 0.776 accuracy: 0.777
train BCE loss:  0.10325795412063599  f1 score: 0.958 f1 average: 0.96 accuracy: 0.96


 58%|█████▊    | 58/100 [29:51<21:38, 30.92s/it]

test BCE loss:   f1 score: 0.762 f1 average: 0.767 accuracy: 0.767
train BCE loss:  0.08996625244617462  f1 score: 0.963 f1 average: 0.965 accuracy: 0.965


 59%|█████▉    | 59/100 [30:22<21:07, 30.91s/it]

test BCE loss:   f1 score: 0.754 f1 average: 0.77 accuracy: 0.771
train BCE loss:  0.09082117676734924  f1 score: 0.963 f1 average: 0.965 accuracy: 0.965


 60%|██████    | 60/100 [30:53<20:38, 30.97s/it]

test BCE loss:   f1 score: 0.762 f1 average: 0.768 accuracy: 0.769
train BCE loss:  0.08178670704364777  f1 score: 0.967 f1 average: 0.969 accuracy: 0.969


 61%|██████    | 61/100 [31:24<20:06, 30.95s/it]

test BCE loss:   f1 score: 0.757 f1 average: 0.764 accuracy: 0.764
train BCE loss:  0.09331148862838745  f1 score: 0.963 f1 average: 0.965 accuracy: 0.965


 62%|██████▏   | 62/100 [31:55<19:35, 30.93s/it]

test BCE loss:   f1 score: 0.761 f1 average: 0.77 accuracy: 0.77
train BCE loss:  0.07260821014642715  f1 score: 0.971 f1 average: 0.973 accuracy: 0.973


 63%|██████▎   | 63/100 [32:26<19:03, 30.92s/it]

test BCE loss:   f1 score: 0.769 f1 average: 0.767 accuracy: 0.767
train BCE loss:  0.0824771374464035  f1 score: 0.966 f1 average: 0.968 accuracy: 0.968


 64%|██████▍   | 64/100 [32:57<18:33, 30.92s/it]

test BCE loss:   f1 score: 0.756 f1 average: 0.767 accuracy: 0.768
train BCE loss:  0.06973360478878021  f1 score: 0.973 f1 average: 0.974 accuracy: 0.974


 65%|██████▌   | 65/100 [33:28<18:01, 30.91s/it]

test BCE loss:   f1 score: 0.751 f1 average: 0.769 accuracy: 0.77
train BCE loss:  0.07006888091564178  f1 score: 0.973 f1 average: 0.974 accuracy: 0.974


 66%|██████▌   | 66/100 [33:59<17:30, 30.90s/it]

test BCE loss:   f1 score: 0.766 f1 average: 0.768 accuracy: 0.769
train BCE loss:  0.06343726068735123  f1 score: 0.976 f1 average: 0.977 accuracy: 0.977


 67%|██████▋   | 67/100 [34:30<16:59, 30.89s/it]

test BCE loss:   f1 score: 0.754 f1 average: 0.765 accuracy: 0.765
train BCE loss:  0.06328216940164566  f1 score: 0.975 f1 average: 0.976 accuracy: 0.976


 68%|██████▊   | 68/100 [35:00<16:28, 30.90s/it]

test BCE loss:   f1 score: 0.745 f1 average: 0.763 accuracy: 0.765
train BCE loss:  0.061685554683208466  f1 score: 0.976 f1 average: 0.977 accuracy: 0.977


 69%|██████▉   | 69/100 [35:31<15:57, 30.89s/it]

test BCE loss:   f1 score: 0.767 f1 average: 0.771 accuracy: 0.771
train BCE loss:  0.047573335468769073  f1 score: 0.981 f1 average: 0.982 accuracy: 0.982


 70%|███████   | 70/100 [36:02<15:26, 30.88s/it]

test BCE loss:   f1 score: 0.738 f1 average: 0.764 accuracy: 0.766
train BCE loss:  0.06049111858010292  f1 score: 0.977 f1 average: 0.978 accuracy: 0.978


 71%|███████   | 71/100 [36:33<14:55, 30.88s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.779 accuracy: 0.78
train BCE loss:  0.057192787528038025  f1 score: 0.977 f1 average: 0.978 accuracy: 0.978


 72%|███████▏  | 72/100 [37:04<14:25, 30.90s/it]

test BCE loss:   f1 score: 0.758 f1 average: 0.766 accuracy: 0.766
train BCE loss:  0.04355427995324135  f1 score: 0.983 f1 average: 0.984 accuracy: 0.984


 73%|███████▎  | 73/100 [37:35<13:54, 30.89s/it]

test BCE loss:   f1 score: 0.765 f1 average: 0.765 accuracy: 0.765
train BCE loss:  0.053268205374479294  f1 score: 0.981 f1 average: 0.982 accuracy: 0.982


 74%|███████▍  | 74/100 [38:06<13:22, 30.88s/it]

test BCE loss:   f1 score: 0.746 f1 average: 0.767 accuracy: 0.768
train BCE loss:  0.04725223034620285  f1 score: 0.981 f1 average: 0.982 accuracy: 0.982


 75%|███████▌  | 75/100 [38:37<12:51, 30.88s/it]

test BCE loss:   f1 score: 0.74 f1 average: 0.762 accuracy: 0.765
train BCE loss:  0.045321203768253326  f1 score: 0.982 f1 average: 0.983 accuracy: 0.983


 76%|███████▌  | 76/100 [39:08<12:21, 30.90s/it]

test BCE loss:   f1 score: 0.769 f1 average: 0.762 accuracy: 0.763
train BCE loss:  0.048438988626003265  f1 score: 0.98 f1 average: 0.981 accuracy: 0.981


 77%|███████▋  | 77/100 [39:38<11:50, 30.88s/it]

test BCE loss:   f1 score: 0.773 f1 average: 0.767 accuracy: 0.767
train BCE loss:  0.04371669143438339  f1 score: 0.983 f1 average: 0.984 accuracy: 0.984


 78%|███████▊  | 78/100 [40:09<11:19, 30.88s/it]

test BCE loss:   f1 score: 0.74 f1 average: 0.761 accuracy: 0.762
train BCE loss:  0.043777432292699814  f1 score: 0.982 f1 average: 0.983 accuracy: 0.983


 79%|███████▉  | 79/100 [40:40<10:48, 30.88s/it]

test BCE loss:   f1 score: 0.766 f1 average: 0.77 accuracy: 0.77
train BCE loss:  0.04797523468732834  f1 score: 0.981 f1 average: 0.982 accuracy: 0.982


 80%|████████  | 80/100 [41:11<10:17, 30.89s/it]

test BCE loss:   f1 score: 0.764 f1 average: 0.776 accuracy: 0.777
train BCE loss:  0.03714672476053238  f1 score: 0.986 f1 average: 0.987 accuracy: 0.987


 81%|████████  | 81/100 [41:42<09:46, 30.88s/it]

test BCE loss:   f1 score: 0.751 f1 average: 0.77 accuracy: 0.772
train BCE loss:  0.033320240676403046  f1 score: 0.987 f1 average: 0.988 accuracy: 0.988


 82%|████████▏ | 82/100 [42:13<09:15, 30.88s/it]

test BCE loss:   f1 score: 0.76 f1 average: 0.775 accuracy: 0.776
train BCE loss:  0.03009747341275215  f1 score: 0.988 f1 average: 0.989 accuracy: 0.989


 83%|████████▎ | 83/100 [42:44<08:44, 30.88s/it]

test BCE loss:   f1 score: 0.765 f1 average: 0.77 accuracy: 0.77
train BCE loss:  0.03299731761217117  f1 score: 0.987 f1 average: 0.988 accuracy: 0.988


 84%|████████▍ | 84/100 [43:15<08:14, 30.90s/it]

test BCE loss:   f1 score: 0.769 f1 average: 0.764 accuracy: 0.764
train BCE loss:  0.042350225150585175  f1 score: 0.982 f1 average: 0.983 accuracy: 0.983


 85%|████████▌ | 85/100 [43:45<07:43, 30.89s/it]

test BCE loss:   f1 score: 0.763 f1 average: 0.773 accuracy: 0.773
train BCE loss:  0.029017580673098564  f1 score: 0.989 f1 average: 0.989 accuracy: 0.989


 86%|████████▌ | 86/100 [44:16<07:12, 30.88s/it]

test BCE loss:   f1 score: 0.76 f1 average: 0.77 accuracy: 0.77
train BCE loss:  0.022684883326292038  f1 score: 0.992 f1 average: 0.992 accuracy: 0.992


 87%|████████▋ | 87/100 [44:47<06:41, 30.87s/it]

test BCE loss:   f1 score: 0.761 f1 average: 0.774 accuracy: 0.774
train BCE loss:  0.02700663171708584  f1 score: 0.989 f1 average: 0.99 accuracy: 0.99


 88%|████████▊ | 88/100 [45:18<06:10, 30.90s/it]

test BCE loss:   f1 score: 0.767 f1 average: 0.775 accuracy: 0.775
train BCE loss:  0.03053904138505459  f1 score: 0.988 f1 average: 0.989 accuracy: 0.989


 89%|████████▉ | 89/100 [45:49<05:39, 30.89s/it]

test BCE loss:   f1 score: 0.77 f1 average: 0.771 accuracy: 0.771
train BCE loss:  0.03236107528209686  f1 score: 0.987 f1 average: 0.988 accuracy: 0.988


 90%|█████████ | 90/100 [46:20<05:08, 30.88s/it]

test BCE loss:   f1 score: 0.781 f1 average: 0.775 accuracy: 0.775
train BCE loss:  0.03775683417916298  f1 score: 0.985 f1 average: 0.986 accuracy: 0.986


 91%|█████████ | 91/100 [46:51<04:38, 30.90s/it]

test BCE loss:   f1 score: 0.772 f1 average: 0.769 accuracy: 0.769
train BCE loss:  0.022721217945218086  f1 score: 0.991 f1 average: 0.992 accuracy: 0.992


 92%|█████████▏| 92/100 [47:22<04:07, 30.89s/it]

test BCE loss:   f1 score: 0.759 f1 average: 0.766 accuracy: 0.766
train BCE loss:  0.03005635365843773  f1 score: 0.989 f1 average: 0.99 accuracy: 0.99


 93%|█████████▎| 93/100 [47:53<03:36, 30.88s/it]

test BCE loss:   f1 score: 0.765 f1 average: 0.774 accuracy: 0.774
train BCE loss:  0.02578273043036461  f1 score: 0.99 f1 average: 0.99 accuracy: 0.99


 94%|█████████▍| 94/100 [48:23<03:05, 30.88s/it]

test BCE loss:   f1 score: 0.763 f1 average: 0.769 accuracy: 0.769
train BCE loss:  0.02658899873495102  f1 score: 0.99 f1 average: 0.99 accuracy: 0.99


 95%|█████████▌| 95/100 [48:54<02:34, 30.89s/it]

test BCE loss:   f1 score: 0.765 f1 average: 0.775 accuracy: 0.776
train BCE loss:  0.023756619542837143  f1 score: 0.991 f1 average: 0.992 accuracy: 0.992


 96%|█████████▌| 96/100 [49:25<02:03, 30.89s/it]

test BCE loss:   f1 score: 0.762 f1 average: 0.774 accuracy: 0.774
train BCE loss:  0.02210243232548237  f1 score: 0.992 f1 average: 0.992 accuracy: 0.992


 97%|█████████▋| 97/100 [49:56<01:32, 30.88s/it]

test BCE loss:   f1 score: 0.768 f1 average: 0.775 accuracy: 0.775
train BCE loss:  0.021954752504825592  f1 score: 0.992 f1 average: 0.993 accuracy: 0.993


 98%|█████████▊| 98/100 [50:27<01:01, 30.87s/it]

test BCE loss:   f1 score: 0.772 f1 average: 0.765 accuracy: 0.765
train BCE loss:  0.032881442457437515  f1 score: 0.987 f1 average: 0.988 accuracy: 0.988


 99%|█████████▉| 99/100 [50:58<00:30, 30.90s/it]

test BCE loss:   f1 score: 0.773 f1 average: 0.774 accuracy: 0.774
train BCE loss:  0.025012295693159103  f1 score: 0.99 f1 average: 0.99 accuracy: 0.99


100%|██████████| 100/100 [51:29<00:00, 30.89s/it]

test BCE loss:   f1 score: 0.757 f1 average: 0.776 accuracy: 0.777





In [23]:
torch.save(model.state_dict(), "charformer-100.pt")

In [None]:
!nvidia-smi

In [26]:
test_data=[]
mi=10000000000000000
ma=-1
for i in range(0,len(X_test)):
    test_data.append(eval_conversation(X_test[i]));
    mi=min(test_data[i].shape[1],mi)
    ma=max(test_data[i].shape[1],ma)

print(mi,ma)
new_test_data=[]

for i in range(0,len(test_data)):
    z = torch.zeros(1, 926-test_data[i].shape[1])
    s = torch.cat((test_data[i],z),1)
    new_test_data.append(s)

new_test_data = torch.stack(new_test_data)

y_test=torch.Tensor(y_test)

new_test_data = torch.squeeze(new_test_data)
print(new_test_data.shape)

9 228
torch.Size([5724, 926])


In [27]:
model = MetNet().to(device)
model.load_state_dict(torch.load("charformer-100.pt"))
model.eval()

MetNet(
  (tokenizer): GBST(
    (token_emb): Embedding(257, 512)
    (pos_conv): Sequential(
      (0): Pad()
      (1): Rearrange('b n d -> b d n')
      (2): DepthwiseConv1d(
        (conv): Conv1d(512, 512, kernel_size=(4,), stride=(1,), groups=512)
        (proj_out): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
      )
      (3): Rearrange('b d n -> b n d')
    )
    (score_fn): Sequential(
      (0): Linear(in_features=512, out_features=1, bias=True)
      (1): Rearrange('... () -> ...')
    )
  )
  (positionEmbeddings): Embedding(512, 64)
  (transformerLayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=576, out_features=576, bias=True)
    )
    (linear1): Linear(in_features=576, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=576, bias=True)
    (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)
    (nor

In [28]:
import torch.utils.data as data_utils
test=data_utils.TensorDataset(new_test_data ,y_val)


test_loader = torch.utils.data.DataLoader(test, batch_size=32,
               shuffle=True, num_workers=0, pin_memory=True)

In [30]:
testpreds = torch.tensor([])
testtrues = torch.tensor([])
for i, (input, target) in enumerate(val_loader):
    input_var = input.cuda()
    target_var = target.cuda()
    optimizer.zero_grad()
    output = model(input_var)
    output = torch.squeeze(output, 1)

    testpreds = torch.cat((testpreds,output.cpu().detach()))
    testtrues = torch.cat((testtrues,target_var.cpu().detach()))

    loss = F.binary_cross_entropy(output,target_var)

print("test BCE loss: ",calculateMetrics(torch.round(testpreds).numpy(),testtrues.numpy()))   

test BCE loss:   f1 score: 0.756 f1 average: 0.775 accuracy: 0.776
