In [None]:
import os
import pandas as pd
import numpy as np
import shutil
import sys
import tqdm.notebook as tq
from collections import defaultdict
import argparse

from sklearn.preprocessing import OneHotEncoder

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import log_normal, log_normal_mixture

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

ModuleNotFoundError: No module named 'torch'

In [None]:
"""
Dataset preparation. Apply for train and test dataset
"""
# Step 1: Load the pre-processed complaints table.
dataset = pd.read_csv('train_processed_data.csv', low_memory=False)

# Step 2: Extract the unique labels from the 'sub_category' column.
# Only the number of labels is used later (num_labels); note that the order of
# `unique()` (first appearance) is not the column order used by the encoder.
categories = dataset['sub_category'].unique()

# Step 3: Apply one-hot encoding.
# OneHotEncoder is used with its defaults here, so fit_transform returns a
# sparse matrix; it is densified further down with .toarray().
encoder = OneHotEncoder()
one_hot_vectors = encoder.fit_transform(dataset[['sub_category']])

# Step 4: Check the shape of the one-hot encoded array and the column names
print(f"One-hot encoded array shape: {one_hot_vectors.shape}")
print(f"Columns from OneHotEncoder: {encoder.get_feature_names_out(['sub_category'])}")

# Step 5: Convert the sparse matrix to a dense NumPy array.
one_hot_vectors = one_hot_vectors.toarray()

# Step 6: Store one-hot vectors as lists in a new column in the original DataFrame
# (this mutates `dataset` in place; re-running the cell reloads it from disk first).
dataset['one_hot_vectors'] = one_hot_vectors.tolist()

One-hot encoded array shape: (93686, 36)
Columns from OneHotEncoder: ['sub_category_Against Interest of sovereignty or integrity of India'
 'sub_category_Business Email CompromiseEmail Takeover'
 'sub_category_Cheating by Impersonation'
 'sub_category_Cryptocurrency Fraud'
 'sub_category_Cyber Bullying  Stalking  Sexting'
 'sub_category_Cyber Terrorism'
 'sub_category_Damage to computer computer systems etc'
 'sub_category_Data Breach/Theft'
 'sub_category_DebitCredit Card FraudSim Swap Fraud'
 'sub_category_DematDepository Fraud'
 'sub_category_Denial of Service (DoS)/Distributed Denial of Service (DDOS) attacks'
 'sub_category_EMail Phishing' 'sub_category_EWallet Related Fraud'
 'sub_category_Email Hacking' 'sub_category_FakeImpersonating Profile'
 'sub_category_Fraud CallVishing' 'sub_category_Hacking/Defacement'
 'sub_category_Impersonating Email'
 'sub_category_Internet Banking Related Fraud'
 'sub_category_Intimidating Email' 'sub_category_Malware Attack'
 'sub_category_Online G

In [15]:
dataset

Unnamed: 0,category,sub_category,crimeaditionalinfo,crimeaditionalinfo_preprocessed,one_hot_vectors
0,Online and Social Media Related Crime,Cyber Bullying Stalking Sexting,I had continue received random calls and abusi...,continu receiv random call abus messag whatsap...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
1,Online Financial Fraud,Fraud CallVishing,The above fraudster is continuously messaging ...,fraudster continu messag ask pay money send fa...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,Online Gambling Betting,Online Gambling Betting,He is acting like a police and demanding for m...,act like polic demand money ad section text me...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,Online and Social Media Related Crime,Online Job Fraud,In apna Job I have applied for job interview f...,apna job appli job interview telecal resourc m...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,Online Financial Fraud,Fraud CallVishing,I received a call from lady stating that she w...,receiv call ladi state send new phone vivo rec...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
...,...,...,...,...,...
93681,Online Financial Fraud,Internet Banking Related Fraud,Identity theft Smishing SMS Fraud CreditDeb...,ident theft smish sm fraud creditdebit card fr...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
93682,Online Financial Fraud,EWallet Related Fraud,RECEIVED CALL FROM NUMBER ASKING ABOUT phone ...,receiv call number ask phone pay cash back off...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
93683,Online Financial Fraud,UPI Related Frauds,Cyber Stalking Blackmailing PhoneSMSVOIP C...,cyber stalk blackmail phonesmsvoip call victim...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
93684,Online and Social Media Related Crime,Online Matrimonial Fraud,Call karke bola ki aapka lotary laga ha aru AC...,call kark bola ki aapka lotari laga ha aru ac ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [None]:
# Hyperparameters
# The DataLoaders below are built from MAX_LEN and the *_BATCH_SIZE constants.
# EPOCHS and LEARNING_RATE are not read by the visible training code, which
# uses args.max_epoch and args.lr instead.
MAX_LEN = 256
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-05
num_labels=len(categories)

# Reproducibility: seed NumPy and PyTorch. The device itself (GPU when
# available, otherwise CPU) is selected in the import cell.
seed=900
# An empty argument list yields an empty Namespace that is then used as a
# plain attribute container for the run configuration.
paser = argparse.ArgumentParser()
args = paser.parse_args("")
np.random.seed(seed)
torch.manual_seed(seed)

# General Model variables
args.seed=seed
# Tag embedded in the checkpoint / metric file names written by experiment().
# It was previously never set, so the first torch.save raised AttributeError.
# NOTE(review): the value is only a file-name label; rename freely.
args.dataset='cybercrime'
args.lr=1e-3
args.lr_decay_ratio=0.9
args.lr_decay_times=4
args.nll_coeff=2
args.l2_coeff=1
args.c_coeff=0
args.class_weights=1  # Find the class weights depending on count
args.current_step=0 # For Plotting the loss and other variables w.r.t iterations

# Set Dataloader parameters
# NOTE(review): the DataLoader cell uses TRAIN/VALID/TEST_BATCH_SIZE and
# num_workers=0, not these two values.
args.BATCH_SIZE = 64
args.NUM_WORKERS = 8
args.max_epoch=300

# scheduler variables (CosineAnnealingWarmRestarts)
args.eta_min=2e-4
args.T_mult=2
args.T0=50
# Interactive prompt: anything other than the exact string 'True' means False.
args.retrain=input("Enter True or False if you want to retrain")=='True'

# VAE model variables
args.latent_dim=128
args.drop=0.1 # dropout value
args.feature_dim=256 # Output width of the BERTClass projection head (Linear(768, 256))
args.emb_size=128 # Width of the label embedding (label_lookup output / decoder output)
args.label_dim=num_labels # Label Dimension
args.param_setting = "lr-{}_lr-decay_{:.2f}_lr-times_{:.1f}_nll-{:.2f}_l2-{:.2f}_c-{:.2f}".format(args.lr, args.lr_decay_ratio, args.lr_decay_times, args.nll_coeff, args.l2_coeff, args.c_coeff)


In [None]:
# Use Bert tokenizer or any other tokenizer deemed suitable
# Smoke-test the tokenizer on a single sentence.
test_text = "We are testing BERT tokenizer."
# generate encodings

# `tokenizer` is reused by the CustomDataset instances created further down.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Pad/truncate to exactly 50 tokens and return PyTorch tensors; `encodings`
# is only for inspection and is not consumed by later cells shown here.
encodings = tokenizer.encode_plus(test_text, 
                                  add_special_tokens = True,
                                  max_length = 50,
                                  truncation = True,
                                  padding = "max_length", 
                                  return_attention_mask = True, 
                                  return_tensors = "pt")

In [None]:
"""
Custom Data Loader
"""
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes complaint text and pairs it with a label vector.

    Args:
        df: DataFrame holding a ``'crimeaditionalinfo_preprocessed'`` text column
            and the column named by ``target_column``.
        tokenizer: object exposing ``encode_plus`` (e.g. a ``BertTokenizer``).
        max_len: sequences are padded / truncated to this many tokens.
        target_column: name of the column holding one label vector per row.
    """

    def __init__(self, df, tokenizer, max_len, target_column):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['crimeaditionalinfo_preprocessed'])
        self.targets = self.df[target_column]
        self.max_len = max_len

    def __len__(self):
        """Number of rows (texts) in the dataset."""
        return len(self.title)

    def __getitem__(self, index):
        """Return the tokenized text and label tensor for the row at position ``index``.

        Returns:
            dict with 1-D ``input_ids``, ``attention_mask`` and ``token_type_ids``
            tensors, a float ``targets`` tensor and the whitespace-normalised
            ``title`` string.
        """
        title = str(self.title[index])
        # Collapse runs of whitespace into single spaces.
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # DataLoader samplers hand out positions, while `self.title` is a plain
        # list. Use positional access for the targets too, so both stay aligned
        # even when `df` has a non-default index (e.g. after a split or filter).
        target = self.targets.iloc[index]
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(target),
            'title': title
        }

In [None]:
# Wrap the prepared frame for training; labels come from 'one_hot_vectors'.
train_dataset = CustomDataset(dataset, tokenizer, MAX_LEN, 'one_hot_vectors')

# Change the test and val dataset 
# NOTE(review): validation and test currently wrap the very same frame as
# training, so any metric computed on them is measured on training data.
valid_dataset = CustomDataset(dataset, tokenizer, MAX_LEN, 'one_hot_vectors')
test_dataset = CustomDataset(dataset, tokenizer, MAX_LEN, 'one_hot_vectors')

In [None]:
# Data loaders
# Only the training loader shuffles; all loaders load in the main process
# (num_workers=0) and use the *_BATCH_SIZE constants, not args.BATCH_SIZE.
train_data_loader = torch.utils.data.DataLoader(train_dataset, 
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# NOTE(review): val_data_loader is built but experiment() evaluates on
# test_data_loader.
val_data_loader = torch.utils.data.DataLoader(valid_dataset, 
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_data_loader = torch.utils.data.DataLoader(test_dataset, 
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

In [None]:
"""
Define Model here
"""
class BERTClass(torch.nn.Module):
    """Pre-trained BERT encoder with dropout and a 768 -> 256 linear projection.

    The 256-wide output is what the VAE consumes as its input feature.
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(
            'bert-base-uncased',
            return_dict=True,
        )
        self.dropout = torch.nn.Dropout(0.3)
        # Keep the output width in sync with the VAE's expected feature size.
        self.linear = torch.nn.Linear(768, 256)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Encode a batch and return the projected pooled representation."""
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)

In [None]:
"""
VAE model
"""

class base_class(nn.Module):
    """Feature-extractor wrapper: owns a BERTClass and forwards calls to it."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = BERTClass()
        # Echo the architecture once at construction time.
        print(self.model)

    def forward(self, ids, mask, token_type_ids):
        """Return the wrapped model's output for the given token tensors."""
        features = self.model(ids, mask, token_type_ids)
        return features

class VAE(nn.Module):
    """Label-conditioned VAE head on top of the BERT feature extractor.

    Two encoders map into a shared latent space of size ``args.latent_dim``:
    a feature encoder (``fx``) fed with the ``args.feature_dim``-wide text
    feature, and a label encoder (``fe``) fed with label embeddings. A shared
    decoder (``fd``) maps ``[feature, z]`` to an L2-normalised embedding of size
    ``args.emb_size``, which is scored against the label-embedding matrix.

    Args:
        args: namespace providing ``drop``, ``feature_dim``, ``latent_dim``,
            ``emb_size`` and ``label_dim``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # One Dropout module instance is shared by all sub-networks below.
        self.dropout = nn.Dropout(p=args.drop)

        self.base_model=base_class(args)

        # Feature encoder. args.feature_dim must equal the feature extractor's
        # output width (256 for BERTClass).
        self.fx = nn.Sequential(
            nn.Linear(args.feature_dim, 256),
            nn.ReLU(),
            self.dropout,
            nn.Linear(256, 512,bias=True),
            nn.ReLU(),
            self.dropout,
            nn.Linear(512, 512,bias=True),
            nn.ReLU(),
            self.dropout,
            nn.Linear(512, 256,bias=True),
            nn.ReLU(),
            self.dropout
        )
        self.fx_mu = nn.Linear(256, args.latent_dim,bias=True)
        self.fx_logvar = nn.Linear(256, args.latent_dim,bias=True)

        # Label encoder
        self.label_lookup = nn.Linear(args.label_dim, args.emb_size)
        self.fe = nn.Sequential(
            nn.Linear(args.emb_size, 512,bias=True),
            nn.ReLU(),
            self.dropout,
            nn.Linear(512, 256,bias=True),
            nn.ReLU(),
            self.dropout
        )
        self.fe_mu = nn.Linear(256, args.latent_dim,bias=True)
        self.fe_logvar = nn.Linear(256, args.latent_dim,bias=True)

        # Decoder
        self.fd = nn.Sequential(
            nn.Linear(args.feature_dim + args.latent_dim, 512,bias=True),
            nn.ReLU(),
            nn.Linear(512, args.emb_size,bias=True),
            nn.LeakyReLU()
        )

        # Adaptive loss-weight head (only reachable through `adaptive`; the
        # call in `forward` is currently commented out).
        # The output width comes from args.label_dim rather than a notebook
        # global, and the softmax axis is explicit instead of the deprecated
        # implicit choice.
        self.linear_layer_weight = nn.Sequential(
            nn.Linear(256, 128,bias=True),
            nn.ReLU(),
            nn.Linear(128, args.label_dim,bias=True),
            nn.Softmax(dim=1),
        )

    def label_encode(self, x):
        """Encode label vectors ``x`` into latent mean / log-variance.

        Returns:
            dict with keys ``'fe_mu'`` and ``'fe_logvar'``.
        """
        h0 = self.dropout(F.relu(self.label_lookup(x)))
        h = self.fe(h0)
        mu = self.fe_mu(h)
        logvar = self.fe_logvar(h)
        fe_output = {
            'fe_mu': mu,
            'fe_logvar': logvar
        }
        return fe_output

    def feat_encode(self, x):
        """Encode features ``x`` into latent mean / log-variance.

        Returns:
            dict with keys ``'fx_mu'`` and ``'fx_logvar'``.
        """
        h = self.fx(x)
        mu = self.fx_mu(h)
        logvar = self.fx_logvar(h)
        fx_output = {
            'fx_mu': mu,
            'fx_logvar': logvar
        }
        return fx_output

    def decode(self, z):
        """Decode ``[feature, latent]`` rows into L2-normalised embeddings."""
        d = self.fd(z)
        d = F.normalize(d, dim=1)
        return d

    def label_forward(self, x, feat):
        """Label branch: encode every label once, then mix by the sample's labels.

        Args:
            x: label matrix with one row per sample and one column per label.
            feat: feature matrix with one row per sample.

        Returns:
            dict with ``'fe_mu'`` / ``'fe_logvar'`` (one row per label) and
            ``'label_emb'`` (one decoded embedding per sample).
        """
        n_label = x.shape[1]
        # Build the identity on the same device as the labels instead of
        # relying on a notebook-level `device` global.
        all_labels = torch.eye(n_label, device=x.device)
        fe_output = self.label_encode(all_labels)
        mu = fe_output['fe_mu']

        # Average of the latent means of the labels active in each row.
        # NOTE(review): a row with no active label divides by zero here.
        z = torch.matmul(x, mu) / x.sum(1, keepdim=True)
        label_emb = self.decode(torch.cat((feat, z), 1))

        fe_output['label_emb'] = label_emb
        return fe_output

    def adaptive(self,x):
        """Return per-label adaptive weights (softmax over labels) for ``x``."""
        x=self.linear_layer_weight(x)
        return x

    def feat_forward(self, x):
        """Feature branch: encode ``x`` and decode two latent draws.

        In eval mode both "draws" are the latent mean; in training mode two
        independent samples are taken.

        Returns:
            dict with ``'fx_mu'``, ``'fx_logvar'``, ``'feat_emb'`` and
            ``'feat_emb2'``.
        """
        fx_output = self.feat_encode(x)
        mu = fx_output['fx_mu']
        logvar = fx_output['fx_logvar']

        if not self.training:
            z = mu
            z2 = mu
        else:
            z = reparameterize(mu, logvar)
            z2 = reparameterize(mu, logvar)
        feat_emb = self.decode(torch.cat((x, z), 1))
        feat_emb2 = self.decode(torch.cat((x, z2), 1))
        fx_output['feat_emb'] = feat_emb
        fx_output['feat_emb2'] = feat_emb2
        return fx_output

    def forward(self, label, ids, mask, token_type_ids):
        """Run both branches and score their embeddings against the label embeddings.

        Args:
            label: label matrix for the batch (also required at evaluation time).
            ids, mask, token_type_ids: tokenizer outputs for the batch.

        Returns:
            dict merging the outputs of ``label_forward`` and ``feat_forward``
            plus ``'embs'`` (the ``label_lookup`` weight matrix),
            ``'label_out'``, ``'feat_out'``, ``'feat_out2'`` (per-label scores)
            and ``'feat'`` (the extracted text feature).
        """
        # Text feature from the BERT feature extractor.
        feature=self.base_model(ids, mask, token_type_ids)
        fe_output = self.label_forward(label, feature)
        label_emb = fe_output['label_emb']
        fx_output = self.feat_forward(feature)
        feat_emb, feat_emb2 = fx_output['feat_emb'], fx_output['feat_emb2']

        # label_lookup.weight is (emb_size, label_dim), so the products below
        # yield one score per label.
        embs = self.label_lookup.weight
        label_out = torch.matmul(label_emb, embs)
        feat_out = torch.matmul(feat_emb, embs)
        feat_out2 = torch.matmul(feat_emb2, embs)

        fe_output.update(fx_output)
        output = fe_output
        output['embs'] = embs
        output['label_out'] = label_out
        output['feat_out'] = feat_out
        output['feat_out2'] = feat_out2
        output['feat'] = feature
        return output


def reparameterize(mu, logvar):
    """Sample from N(mu, exp(logvar)) using the reparameterisation trick.

    The noise is drawn with the shape, dtype and device of ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

import torch.nn.functional as F

def compute_loss(input_label, output, args=None):
    """Combine classification, KL and contrastive terms into the training loss.

    Args:
        input_label: float label matrix for the batch (one column per class).
        output: dict produced by ``VAE.forward``.
        args: namespace providing ``nll_coeff``; when omitted a weight of 1.0 is
            used for the classification term.

    Returns:
        Tuple ``(total_loss, nll_loss, nll_loss_x, 0., 0., kl_loss, cpc_loss,
        pred_e, pred_x)`` where ``nll_loss`` / ``nll_loss_x`` are the
        cross-entropy of the label and feature branch, the two zeros are
        placeholders for ranking losses, and ``pred_e`` / ``pred_x`` are the
        softmax class probabilities of the label and feature branch.
    """
    fe_out, fe_mu, fe_logvar, label_emb = \
        output['label_out'], output['fe_mu'], output['fe_logvar'], output['label_emb']
    fx_out, fx_mu, fx_logvar, feat_emb = \
        output['feat_out'], output['fx_mu'], output['fx_logvar'], output['feat_emb']
    fx_out2 = output['feat_out2']
    embs = output['embs']

    # KL term estimated from one sample of the feature posterior; log_normal and
    # log_normal_mixture come from the project's `utils` module.
    fx_sample = reparameterize(fx_mu, fx_logvar)
    fx_var = torch.exp(fx_logvar)
    fe_var = torch.exp(fe_logvar)
    kl_loss = (log_normal(fx_sample, fx_mu, fx_var) - \
        log_normal_mixture(fx_sample, fe_mu, fe_var, input_label)).mean()

    # Class probabilities, returned to the caller for metric computation only.
    pred_e = F.softmax(fe_out, dim=1)
    pred_x = F.softmax(fx_out, dim=1)

    def compute_CE_loss(logits):
        # F.cross_entropy applies log-softmax itself, so it must receive the raw
        # scores. Feeding it softmax outputs (as before) applied softmax twice
        # and flattened the gradients.
        return F.cross_entropy(logits, input_label)

    def compute_BCE_and_RL_loss(E):
        # Multi-label alternative to the cross-entropy above; currently unused.
        # Expects probabilities with a leading sample axis (see dim=2 below).
        sample_nll = -(
            (torch.log(E) * input_label + torch.log(1 - E) * (1 - input_label))
        )
        logprob = -torch.sum(sample_nll, dim=2)

        # log-sum-exp trick to avoid float overflow
        maxlogprob = torch.max(logprob, dim=0)[0]
        Eprob = torch.mean(torch.exp(logprob - maxlogprob), axis=0)
        nll_loss = torch.mean(-torch.log(Eprob) - maxlogprob)
        return nll_loss

    def supconloss(label_emb, feat_emb, embs, temp=1.0):
        # Contrastive term: both embeddings of a sample should score highest
        # against the label embeddings of that sample's own labels.
        features = torch.cat((label_emb, feat_emb))
        labels = torch.cat((input_label, input_label)).float()
        # The positive mask is the label matrix itself (the former product with
        # an identity matrix was a no-op that also needed a global `device`).
        mask = labels

        anchor_dot_contrast = torch.div(
            torch.matmul(features, embs),
            temp)
        # Subtract the row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        loss = -mean_log_prob_pos
        loss = loss.mean()
        return loss

    nll_loss = compute_CE_loss(fe_out)
    nll_loss_x = compute_CE_loss(fx_out)
    nll_loss_x2 = compute_CE_loss(fx_out2)
    sum_nll_loss = nll_loss + nll_loss_x + nll_loss_x2
    cpc_loss = supconloss(label_emb, feat_emb, embs)
    # `args` defaults to None in the signature; fall back to a unit weight
    # instead of failing on attribute access.
    nll_coeff = 1.0 if args is None else args.nll_coeff
    total_loss = sum_nll_loss * nll_coeff + kl_loss * 6. + cpc_loss
    return total_loss, nll_loss, nll_loss_x, 0., 0., kl_loss, cpc_loss, pred_e, pred_x

In [None]:
# Build the BERT + VAE model from the run configuration and move it to `device`.
model=VAE(args).to(device)

In [None]:
# Define Optimizer
# Separate out the model parameters if required
# A single SGD group covers every parameter (BERT included) at args.lr; the
# LEARNING_RATE constant from the hyper-parameter cell is not used here.
args.optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=5e-4)
# experiment() steps this scheduler once per epoch, so T_0 is measured in epochs.
args.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(args.optimizer, eta_min=args.eta_min, T_0=args.T0, T_mult=args.T_mult)

In [None]:
### Load pretrained model here
#model=torch.load()

In [None]:
# Define the losses

# smooth means average. Every batch has a mean loss value w.r.t. different losses
# NOTE(review): train() and test() define locals with these same names, so the
# module-level copies below are not updated by them.
smooth_nll_loss=0.0 # label encoder decoder cross entropy loss
smooth_nll_loss_x=0.0 # feature encoder decoder cross entropy loss
smooth_c_loss = 0.0 # label encoder decoder ranking loss
smooth_c_loss_x=0.0 # feature encoder decoder ranking loss
smooth_kl_loss = 0.0 # kl divergence
smooth_total_loss=0.0 # total loss
smooth_macro_f1 = 0.0 # macro_f1 score
smooth_micro_f1 = 0.0 # micro_f1 score

# Best-so-far trackers; not read by the training code visible in this notebook.
best_loss = 1e10
best_iter = 0
best_macro_f1 = 0.0 # best macro f1 for ckpt selection in validation
best_micro_f1 = 0.0 # best micro f1 for ckpt selection in validation
best_acc = 0.0 # best subset acc for ckpt selction in validation

temp_label=[]
temp_pred_x=[]


best_test_metrics = None

# Keys of the per-epoch history dictionaries filled in by experiment().
# NOTE(review): experiment() never appends to the 'roc' key.
loss_list = ["nll_loss",
             "total_loss",
             "acc",
             "hamm_acc",
             "macro_f1",
             "micro_f1",
             'auc',
             'roc']

# One list per metric; experiment() appends one value per epoch.
train_losses = {i:[] for i in loss_list}
test_losses = {i:[] for i in loss_list}

In [None]:
# Train Function
import tqdm as tqdm
import evals

def train(model,train_loader, args):
    """Run one optimisation pass over ``train_loader``.

    The caller is expected to have put the model in training mode.

    Args:
        model: the VAE model; called as ``model(targets, ids, mask, token_type_ids)``.
        train_loader: iterable of batches shaped like ``CustomDataset`` items.
        args: namespace providing ``optimizer`` (and whatever ``compute_loss`` reads).

    Returns:
        Tuple ``(model, train_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x,
        kl_loss, total_loss, macro_f1, micro_f1)``. ``train_metrics`` is the
        metrics dict of the final batch only; every other value is the mean
        over batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Imported locally so the loop does not depend on which notebook cell last
    # bound the global name `tqdm` (one cell binds the module, another the
    # callable).
    from tqdm import tqdm as progress_bar

    def _detached(value):
        # Drop the autograd graph from tensors; pass plain numbers through.
        return value.detach() if torch.is_tensor(value) else value

    print("Training Started")
    counter = 0
    model = model.to(device)
    # Running sums; divided by the batch count after the loop.
    smooth_nll_loss = 0.0    # label encoder decoder cross entropy loss
    smooth_nll_loss_x = 0.0  # feature encoder decoder cross entropy loss
    smooth_c_loss = 0.0      # label encoder decoder ranking loss
    smooth_c_loss_x = 0.0    # feature encoder decoder ranking loss
    smooth_kl_loss = 0.0     # kl divergence
    smooth_total_loss = 0.0  # total loss
    smooth_macro_f1 = 0.0    # macro_f1 score
    smooth_micro_f1 = 0.0    # micro_f1 score
    train_metrics = None

    for data in progress_bar(train_loader, desc='Training'):
        ids = data['input_ids'].to(device, dtype = torch.long)
        mask = data['attention_mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float)

        args.optimizer.zero_grad()
        output = model(targets, ids, mask, token_type_ids)
        total_loss, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, cpc_loss, _, pred_x = \
                    compute_loss(targets, output, args)
        total_loss.backward()
        args.optimizer.step()

        train_metrics = evals.compute_metrics(
            pred_x.detach().cpu().numpy(), targets.cpu().numpy(), 0.5, all_metrics=False)
        macro_f1, micro_f1 = train_metrics['maF1'], train_metrics['miF1']

        # Accumulate detached values: summing graph-attached tensors keeps
        # every batch's autograd history referenced until the epoch ends.
        smooth_nll_loss += _detached(nll_loss)
        smooth_nll_loss_x += _detached(nll_loss_x)
        smooth_c_loss += _detached(c_loss)
        smooth_c_loss_x += _detached(c_loss_x)
        smooth_kl_loss += _detached(kl_loss)
        smooth_total_loss += _detached(total_loss)
        smooth_macro_f1 += macro_f1
        smooth_micro_f1 += micro_f1

        counter += 1

    if counter == 0:
        raise ValueError("train_loader yielded no batches")

    nll_loss = smooth_nll_loss / counter
    nll_loss_x = smooth_nll_loss_x / counter
    c_loss = smooth_c_loss / counter
    c_loss_x = smooth_c_loss_x / counter
    kl_loss = smooth_kl_loss / counter
    total_loss = smooth_total_loss / counter
    macro_f1 = smooth_macro_f1 / counter
    micro_f1 = smooth_micro_f1 / counter

    return model, train_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, total_loss, macro_f1, micro_f1
       

In [None]:
# Test function
from sklearn.metrics import roc_auc_score
test_temp_label = []
test_temp_pred_x = []
def test(model, test_loader, args):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Puts the model in eval mode; the caller must switch back to training mode.

    Args:
        model: the VAE model; called as ``model(targets, ids, mask, token_type_ids)``.
        test_loader: iterable of batches shaped like ``CustomDataset`` items.
        args: namespace forwarded to ``compute_loss``.

    Returns:
        Tuple ``(test_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss,
        total_loss, macro_f1, micro_f1, auROC)``. ``test_metrics`` is the
        metrics dict of the final batch only; the losses and F1 values are
        means over batches; ``auROC`` is the macro-averaged ROC-AUC over the
        whole loader (``nan`` when no class can be scored).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # Imported locally so the loop does not depend on which notebook cell last
    # bound the global name `tqdm` (one cell binds the module, another the
    # callable).
    from tqdm import tqdm as progress_bar

    counter = 0
    model.eval()
    # Running sums; divided by the batch count after the loop.
    smooth_nll_loss = 0.0    # label encoder decoder cross entropy loss
    smooth_nll_loss_x = 0.0  # feature encoder decoder cross entropy loss
    smooth_c_loss = 0.0      # label encoder decoder ranking loss
    smooth_c_loss_x = 0.0    # feature encoder decoder ranking loss
    smooth_kl_loss = 0.0     # kl divergence
    smooth_total_loss = 0.0  # total loss
    smooth_macro_f1 = 0.0    # macro_f1 score
    smooth_micro_f1 = 0.0    # micro_f1 score
    test_metrics = None
    all_targets = []
    all_preds = []
    with torch.no_grad():
        for data in progress_bar(test_loader, desc='Testing'):
            ids = data['input_ids'].to(device, dtype = torch.long)
            mask = data['attention_mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)

            output = model(targets, ids, mask, token_type_ids)

            total_loss, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, cpc_loss, _, pred_x = \
                        compute_loss(targets, output, args)

            # scikit-learn needs host-side NumPy arrays, not (CUDA) tensors.
            preds_np = pred_x.detach().cpu().numpy()
            targets_np = targets.cpu().numpy()
            all_preds.append(preds_np)
            all_targets.append(targets_np)

            test_metrics = evals.compute_metrics(preds_np, targets_np, 0.5, all_metrics=False)
            macro_f1, micro_f1 = test_metrics['maF1'], test_metrics['miF1']

            smooth_nll_loss += nll_loss
            smooth_nll_loss_x += nll_loss_x
            smooth_c_loss += c_loss
            smooth_c_loss_x += c_loss_x
            smooth_kl_loss += kl_loss
            smooth_total_loss += total_loss
            smooth_macro_f1 += macro_f1
            smooth_micro_f1 += micro_f1

            counter += 1

    if counter == 0:
        raise ValueError("test_loader yielded no batches")

    nll_loss = smooth_nll_loss / counter
    nll_loss_x = smooth_nll_loss_x / counter
    c_loss = smooth_c_loss / counter
    c_loss_x = smooth_c_loss_x / counter
    kl_loss = smooth_kl_loss / counter
    total_loss = smooth_total_loss / counter
    macro_f1 = smooth_macro_f1 / counter
    micro_f1 = smooth_micro_f1 / counter

    # ROC-AUC is undefined for a class with only positives or only negatives,
    # which is the norm inside a single batch. Score once over the whole loader
    # and average over the classes that have both.
    y_true = np.concatenate(all_targets, axis=0)
    y_score = np.concatenate(all_preds, axis=0)
    positives = y_true.sum(axis=0)
    scorable = np.flatnonzero((positives > 0) & (positives < y_true.shape[0]))
    per_class_auc = [roc_auc_score(y_true[:, j], y_score[:, j]) for j in scorable]
    auROC = float(np.mean(per_class_auc)) if per_class_auc else float('nan')

    return test_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, total_loss, macro_f1, micro_f1,auROC

In [None]:
# Best evaluation ROC-AUC seen so far; updated by experiment().
args.a_max=0
def experiment(model, train_loader, device, args, eval_loader=None):
    """Train for ``args.max_epoch`` epochs, evaluating and checkpointing after each.

    Per-epoch metrics are appended to the module-level ``train_losses`` and
    ``test_losses`` dictionaries, and the model plus both dictionaries are
    saved to disk after every epoch.

    Args:
        model: model to train.
        train_loader: loader passed to ``train``.
        device: accepted for call compatibility; not used inside this function.
        args: namespace providing ``max_epoch``, ``scheduler``, ``a_max`` and,
            optionally, ``dataset`` (file-name tag, default ``'run'``).
        eval_loader: loader passed to ``test``; defaults to the notebook-level
            ``test_data_loader``.

    Returns:
        ``args``, with ``args.a_max`` holding the best evaluation ROC-AUC seen.
    """
    if eval_loader is None:
        eval_loader = test_data_loader
    # `args.dataset` is not guaranteed to exist; fall back to a neutral tag
    # instead of failing at the first checkpoint save.
    run_tag = getattr(args, 'dataset', 'run')

    for epoch in range(args.max_epoch):
        model.train()
        model, train_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, total_loss, macro_f1, micro_f1= train(model, train_loader, args)
        # NOTE(review): train_metrics holds the last batch's metrics only.
        # The 'ACC' / 'HA' / 'meanAUC' keys are assumed to be provided by
        # evals.compute_metrics - confirm.
        train_acc=train_metrics['ACC']
        train_ha_acc=train_metrics['HA']
        print('- Epoch :', epoch+1)
        print('*** Training Metrics ***')
        # Each value is formatted with %; previously the format string and the
        # value were printed as two separate arguments, so '%.5f' appeared
        # literally in the log.
        print('- NLL Loss : %.5f' % nll_loss,
              '- Total Loss : %.5f' % total_loss,
              '- Total Accuracy : %.5f' % train_acc,
              '- Hamming Accuracy : %.3f' % train_ha_acc,
              '- Micro F1 : %.3f' % train_metrics['miF1'],
              '- Macro F1 : %.3f' % train_metrics['maF1'])

        train_losses["nll_loss"].append(nll_loss)
        train_losses["total_loss"].append(total_loss)
        train_losses["acc"].append(train_acc)
        train_losses["hamm_acc"].append(train_ha_acc)
        # macro / micro F1 were stored under each other's key.
        train_losses["macro_f1"].append(train_metrics['maF1'])
        train_losses["micro_f1"].append(train_metrics['miF1'])
        train_losses["auc"].append(train_metrics['meanAUC'])

        # Validation
        test_metrics, nll_loss, nll_loss_x, c_loss, c_loss_x, kl_loss, total_loss, macro_f1, micro_f1,aucROC=test(model,eval_loader,args)
        test_acc=test_metrics['ACC']
        test_ha_acc=test_metrics['HA']
        print('- Epoch :', epoch+1)
        print('*** Validation Metrics ***')
        print('- NLL Loss : %.5f' % nll_loss,
              '- Total Loss : %.5f' % total_loss,
              '- Total Accuracy : %.5f' % test_acc,
              '- Hamming Accuracy : %.3f' % test_ha_acc,
              '- Micro F1 : %.3f' % test_metrics['miF1'],
              '- Macro F1 : %.3f' % test_metrics['maF1'])
        print("AUCROC",aucROC)

        args.a_max=max(args.a_max,aucROC)

        test_losses["nll_loss"].append(nll_loss)
        test_losses["total_loss"].append(total_loss)
        test_losses["acc"].append(test_acc)
        test_losses["hamm_acc"].append(test_ha_acc)
        test_losses["macro_f1"].append(test_metrics['maF1'])
        test_losses["micro_f1"].append(test_metrics['miF1'])
        test_losses["auc"].append(test_metrics['meanAUC'])

        # One scheduler step per epoch.
        args.scheduler.step()
        # The file names do not include the epoch, so each epoch overwrites
        # the previous checkpoint and metric files.
        torch.save(model, f'model_{run_tag}_{args.max_epoch}.pt')
        torch.save(train_losses, f'train_metrics_{run_tag}_{args.max_epoch}.pt')
        torch.save(test_losses, f'test_metrics_{run_tag}_{args.max_epoch}.pt')

    return args

In [None]:
# Call the train model here
# This import binds `tqdm` to the progress-bar callable; an earlier cell bound
# the same name to the module via `import tqdm as tqdm`.
from tqdm import tqdm
# Runs the full training loop; experiment() returns the (mutated) args namespace.
args = experiment(model, train_data_loader, device, args)

In [None]:
import matplotlib.pyplot as plt
# Per-epoch Hamming accuracy recorded by experiment() on its evaluation loader.
plt.plot([i.item() for i in test_losses['hamm_acc']])
# The previous title named the Tox-Cast dataset, which this notebook does not use.
plt.title("Hamming Accuracy on Evaluation Set")
plt.xlabel("Epochs")
plt.ylabel("Ha")
plt.show()

In [None]:
import matplotlib.pyplot as plt
# One row of six panels, one per tracked training metric.
fig, axes = plt.subplots(1, 6, figsize=(12, 4))  # Adjust figsize as needed
# NOTE(review): this replaces the in-memory history with a file named after a
# different run ('tox21', 200 epochs); experiment() writes
# train_metrics_<dataset>_<max_epoch>.pt - confirm the intended path.
train_losses=torch.load('train_metrics_tox21_200.pt')

# (history key, panel title) in left-to-right order.
metric_panels = [
    ('nll_loss', 'NLL LOSS'),
    ('acc', 'TOTAL ACCURACY'),
    ('hamm_acc', 'HAMMING ACCURACY'),
    ('macro_f1', 'MACRO F1'),
    ('micro_f1', 'MICRO F1'),
    ('total_loss', 'TOTAL LOSS'),
]
for panel_ax, (metric_key, panel_title) in zip(axes, metric_panels):
    panel_ax.plot([value.item() for value in train_losses[metric_key]])
    panel_ax.set_title(panel_title)

# Adjust layout for better visualization
plt.tight_layout()

# Show the plot
plt.show()