In [1]:
import os
import torch
import json
import pickle
from custom_data import CustomDataset
from tokenizer import Tokenizer
from torch.utils.data import DataLoader
from train import evaluate
from torch import nn
import numpy as np
import pandas as pd
import plotly.express  as px
from tqdm import tqdm
from sklearn.metrics import f1_score,accuracy_score

EPSILON = 1e-32

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint_weight = "/home/Ravikumar/Developer/Learning-NLP-with-PyTorch/Chapter - 7 : Detailed Understanding of Multi-Label Classification/dataset/checkpoints/Transformer_multilabel-classification_added_pos_weight_dropout1_dropout2_mean_bs_12_lr_0.0001.pt"
checkpoint_pos_weight = "/home/Ravikumar/Developer/Learning-NLP-with-PyTorch/Chapter - 7 : Detailed Understanding of Multi-Label Classification/dataset/checkpoints/Transformer_multilabel-classification_added_weight_and_pos_weight_update_bs_64_lr_0.0001.pt"

checkpoints_weight = torch.load(checkpoint_weight)
# checkpoints_pos_weight = torch.load(checkpoint_pos_weight)

In [3]:
args_dict = checkpoints_weight['params']

In [4]:
# args_dict =  {k:v for k,v in config.__dict__.items() if "__" not in k}

vocab = json.load(open(args_dict["VOCAB_DIR"],"r"))
word_to_token = {k:int(v) for k,v in vocab["vocabs"].items()}
args_dict["N_WORDS"] = len(word_to_token)

In [5]:
if os.path.isfile(args_dict["CHECKPOINT_NAME"]):
    checkpoints = torch.load(args_dict["CHECKPOINT_NAME"])


In [6]:
vocab = json.load(open(args_dict["VOCAB_DIR"],"r"))
vocab_len = len(vocab["vocabs"])  + 1

In [8]:
from model import ClassifierModel
model = ClassifierModel(vocab_len,args_dict).to(args_dict['DEVICE'])
model.load_state_dict(checkpoints["model_state_dict"])
model.eval()

ClassifierModel(
  (embs): Embedding(50940, 64)
  (linear1): Linear(in_features=64, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=512, bias=True)
  (dropout1): Dropout(p=0.4, inplace=False)
  (dropout2): Dropout(p=0.4, inplace=False)
  (relu): LeakyReLU(negative_slope=0.01)
  (transformer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.4, inplace=False)
    (dropout2): Dropout(p=0.4, inplace=False)
  )
  (fc): Linear(in_features=512, out_features=6, bias=True)
)

In [9]:
splitted_data = pickle.load(open(args_dict["SPLIT_DATA_DIR"],"rb"))
valid_dataset = splitted_data["valid_dataset"][args_dict["LABELS"] +["pairs"]]
train_dataset = splitted_data["train_dataset"][args_dict["LABELS"] +["pairs"]]

## data

In [149]:
train_dataset

Unnamed: 0,Computer Science,Physics,Mathematics,Statistics,Quantitative Biology,Quantitative Finance,pairs
18146,0,0,1,0,0,0,"('polynomial functors in manifold calculus', '..."
8633,0,1,0,0,0,0,('monolithic ingaas nanowire array lasers on s...
17016,1,1,0,0,0,0,('kinetic simulation of collisional magnetized...
18277,0,0,0,1,0,0,('improving massive mimo belief propagation de...
1072,0,0,0,1,0,0,('few shot learning of neural networks from sc...
...,...,...,...,...,...,...,...
5941,1,0,0,1,0,0,('learning what matters sampling interesting p...
20090,1,0,1,0,0,0,('criteria for solar car optimized route estim...
14254,0,0,1,0,0,0,('angles between curves in metric measure spac...
3756,1,0,0,0,0,0,('json data model query languages and schema s...


In [21]:
train_dataset.iloc[:,:-1].sum(0)

Computer Science        6874
Physics                 4787
Mathematics             4503
Statistics              4189
Quantitative Biology     461
Quantitative Finance     205
dtype: int64

In [23]:
len(train_dataset)

16777

In [24]:
train_dataset.iloc[:,:-1].sum(0).max()

6874

In [25]:

import plotly.graph_objects as go

fig = go.Figure(data=[
    go.Bar(name='Train_True', y=train_dataset.iloc[:,:-1].sum(0)),
    go.Bar(name='Val_True', y=valid_dataset.iloc[:,:-1].sum(0))
])
fig.update_layout(barmode='group')
fig.show()

In [18]:

import plotly.graph_objects as go

fig = go.Figure(data=[
    go.Bar(name='Train_True', y=train_dataset.iloc[:,:-1].sum(0)/train_dataset.iloc[:,:-1].sum(0).max()),
    go.Bar(name='Val_True', y=valid_dataset.iloc[:,:-1].sum(0)/valid_dataset.iloc[:,:-1].sum(0).max())
])
fig.update_layout(barmode='group')
fig.show()

## confusion matrix

In [10]:
tokenizer_obj = Tokenizer("topic-modelling-research-articles")
validation_set = CustomDataset(
    valid_dataset,
    tokenizer=tokenizer_obj,
    args=args_dict,
)

validation_dataloader = DataLoader(
    validation_set,
    batch_size = args_dict["BATCH_SIZE"],
    drop_last = True)

thresholds = np.linspace(0,1,20).round(2)

In [11]:
# threshold = 0.95
score = list()
threshold_remaining = 0.5
for threshold in tqdm(thresholds):
    all_pred,all_true,all_th_pred = [],[],[]
    for batch_data in validation_dataloader:
        input_ids = batch_data['input_ids'].to(args_dict['DEVICE'])
        target = batch_data['target'].to(args_dict['DEVICE'])
        with torch.no_grad():
            logits = model(input_ids) 
            pred = torch.sigmoid(logits)
            pred_clone = pred.clone()
            pred[:,:-2] = (pred[:,:-2]>threshold_remaining).float()
            pred[:,-2:] = (pred[:,-2:]>threshold).float()
            all_pred.extend(pred_clone.round().tolist())
            all_th_pred.extend(pred.round().tolist())
            all_true.extend(target.tolist())
    score.append(f1_score(all_true,all_pred,average='micro'))
score

100%|██████████| 20/20 [01:02<00:00,  3.14s/it]


[0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405,
 0.7702954484961405]

In [33]:
# thresholds = list()
# k = 0
# for _ in range(10):
#     k =+ np.round(k+0.1,2)
#     thresholds.append(k)
# thresholds

In [12]:
weight_score = score.copy()

In [84]:
fig = px.line(
    x = thresholds,y = weight_score,
    markers=True,
    text = weight_score,
    labels={"x":"threshold","y":"f1-score"},
    title="F1-SCORE WITH RESPECT TO THRESHOLD - [added weight and pos_weight]"
)
fig.update_traces(textposition="top center")
fig.show()

In [15]:
from sklearn.metrics import f1_score,accuracy_score
f1_score(all_true,all_pred,average='micro'),accuracy_score(all_true,all_pred)

(0.7702954484961405, 0.5785577841451767)

In [16]:
f1_score(all_true,all_th_pred,average='micro'),accuracy_score(all_true,all_th_pred)

(0.7763472503218688, 0.5986150907354346)

In [228]:
all_pred = all_th_pred.copy()
#Confusion matrix
##pred || true
assert len(all_pred)==len(all_true) and len(all_pred[0])==len(all_true[0])

#pred/True>cls1,cls2
#       cls1   2   5
#       cls2   1   2
unique_cls = {tuple(x):None for x in all_pred+all_true}
unique_cls = {x:idx for idx,x in enumerate(unique_cls)}
print("Unique classes :",len(unique_cls))
labels = args_dict["LABELS"]
unique_labels = ["|"+"|".join([labels[i]+"|" for i,x in enumerate(key) if x!=0]) for key in unique_cls]
confusion_matrix = np.zeros((len(unique_cls),len(unique_cls)))

#loop
#1:cls1_pred|cls2_true
#2:cls2_pred|cls1_true
for pred,true in zip(all_pred,all_true):
    t_idx = unique_cls[tuple(true)]
    p_idx = unique_cls[tuple(pred)]
    confusion_matrix[t_idx,p_idx] +=1

normalized_cm = confusion_matrix/(confusion_matrix.sum(axis=1,keepdims=True)+EPSILON)

df = pd.DataFrame(data = normalized_cm.round(2),columns=unique_labels,index=unique_labels)
assert confusion_matrix.sum() == len(all_pred)

## plot 
fig = px.imshow(df,text_auto=True,aspect="auto")
fig.update_xaxes(side="top",title_font_family="Arial")
fig.update_layout(font=dict(
    size=8
))
fig.show()

Unique classes : 31


In [18]:
fig.write_html("added_weight_dropout_final_.html")