In [1]:
import data_preprocess

data = data_preprocess.run(file_path = './ATP_rmsim.json',slidingwindow=False)


sequences = []
labels = []

MAX_LEN = 128

for item in data:
    while len(item['sequence']) > MAX_LEN:
        sequences.append(item['sequence'][:MAX_LEN])
        labels.append(item['label'][:MAX_LEN])
        item['sequence'] = item['sequence'][MAX_LEN:]
        item['label'] = item['label'][MAX_LEN:]
    sequences.append(item['sequence'])
    labels.append(item['label'])

In [2]:
import pandas as pd

zero_counter = 0
ones_counter = 0

for l in labels:
    zero_counter+=l.count(0)
    ones_counter+=l.count(1)

pd.DataFrame(columns=['sample amount'], data = [zero_counter, ones_counter])

Unnamed: 0,sample amount
0,542522
1,6250


In [3]:
from dataset import FADBindingDataset, split_dataset
from model import LCPLMforSequenceLabeling
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = LCPLMforSequenceLabeling("../../LCPLM/")
dataset = FADBindingDataset(tokenizer, sequences, labels, max_length=MAX_LEN)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at ../../LCPLM/ were not used when initializing LcPlmForMaskedLM: ['bimamba.backbone.layers.10.mixer.mamba_fwd.A_log', 'bimamba.backbone.layers.10.mixer.mamba_fwd.D', 'bimamba.backbone.layers.10.mixer.mamba_fwd.conv1d.bias', 'bimamba.backbone.layers.10.mixer.mamba_fwd.conv1d.weight', 'bimamba.backbone.layers.10.mixer.mamba_fwd.dt_proj.bias', 'bimamba.backbone.layers.10.mixer.mamba_fwd.dt_proj.weight', 'bimamba.backbone.layers.10.mixer.mamba_fwd.in_proj.weight', 'bimamba.backbone.layers.10.mixer.mamba_fwd.out_proj.weight', 'bimamba.backbone.layers.10.mixer.mamba_fwd.x_proj.weight', 'bimamba.backbone.layers.10.mixer.mamba_rev.A_log', 'bimamba.backbone.layers.10.mixer.mamba_rev.D', 'bimamba.backbone.layers.10.mixer.mamba_rev.conv1d.bias', 'bimamba.backbone.layers.10.mixer.mamba_rev.conv1d.weight', 'bimamba.backbone.layers.10.mixer.mamba_rev.dt_proj.bias', 'bimamba.backbone.layers.10.mixer.mamba_rev.dt_p

In [4]:
dataset[0]

{'input_ids': tensor([ 0, 20,  7,  7,  4, 10, 16,  4, 10,  4,  4,  4, 22, 15, 17, 19, 11,  4,
         15, 15, 10, 15,  7,  4,  7, 11,  7,  4,  9,  4, 18,  4, 14,  4,  4, 18,
          8,  6, 12,  4, 12, 22,  4, 10,  4, 15, 12, 16,  8,  9, 17,  7, 14, 17,
          5, 11,  7, 19, 14, 13, 16, 21, 12, 16,  9,  4, 14,  4, 18, 18,  8, 18,
         14, 14, 14,  6,  6,  8, 22,  9,  4,  5, 19,  7, 14,  8, 21,  8, 13,  5,
          5, 10, 11, 12, 11,  9,  5,  7, 10, 10,  9, 18, 20, 12, 15, 20, 10,  7,
         21,  6, 18,  8,  8,  9, 15, 13, 18,  9, 13, 19,  7, 10, 19, 13, 17, 21,
          8,  2]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [5]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter

[train_dataset, test_dataset] = split_dataset(dataset, 0.8)
[validate_dataset, test_dataset] = split_dataset(test_dataset, 0.5)


In [6]:
import numpy as np
from sklearn.metrics import classification_report, roc_curve

def find_best_evalution(logits, labels):
    logits_np = np.array(logits)
    labels_np = np.array(labels)
    #sigmoid probability
    print(logits_np.shape)
    probs = 1 / (1 + np.exp(-logits_np[:, :, 1]))#[batch size, sequence length]

    labels_flat = labels_np.flatten()
    probs_flat = probs.flatten()

     # 只取非 padding (-100) 的有效樣本
    active_indices = labels_flat != -100
    final_labels = labels_flat[active_indices]
    final_probs = probs_flat[active_indices]

    # 計算 ROC 並找出最佳 G-mean 閾值
    fpr, tpr, thresholds = roc_curve(final_labels, final_probs)
    gmeans = np.sqrt(tpr * (1 - fpr))
    ix = np.argmax(gmeans)
    best_threshold = thresholds[ix]

    print(f'Best Threshold: {best_threshold:.3f}, G-Mean: {gmeans[ix]:.3f}')
    return best_threshold
    

def custom_classification_report(logits, labels, test_evalution = False):
    """Filt the ignore value and return classification report"""
    labels_np = np.array(labels)
    logits_np = np.array(logits)
    #preds = np.argmax(logits_np, axis=-1).flatten()
    if test_evalution:
        threshold = find_best_evalution(logits, labels)
    else:
        threshold = 0.5
    probs = 1/(1+np.exp(-logits_np[:,:,1]))
    preds = (probs > threshold).astype(int).flatten()
    labels_flat = labels_np.flatten()

    print("Shape of flattened labels:", labels_flat.shape)
    print("Shape of flattened predictions:", preds.shape)

    active_indices = labels_flat != -100
    final_labels = labels_flat[active_indices]
    final_preds = preds[active_indices]
    report_str = classification_report(final_labels, final_preds)
    report_dict = classification_report(final_labels, final_preds, output_dict=True)
    print(report_str)
    return report_dict

In [7]:
from tqdm import tqdm
import yaml

def train_by_dataloader():

    writer = SummaryWriter('runs/my_experiment') # TensorBoard writer

    ## hyperparameters in yaml file
    with open("../config/mamba_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    learning_rate = config["training_arguments"]["learning_rate"]
    per_device_train_batch_size = config["training_arguments"]["per_device_train_batch_size"]
    num_epochs = config["training_arguments"]["num_train_epochs"]
    warmup_steps = config['training_arguments']['warmup_steps']
    log_step = config['training_arguments']['eval_steps']
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    ##

    ## create dataloader
    train_loader = DataLoader(train_dataset, per_device_train_batch_size, shuffle=True)
    validation_loader = DataLoader(validate_dataset, per_device_train_batch_size, shuffle=True)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    ## 

    ## some variable to use
    patience = 3
    best_val_loss = float('inf')
    best_f1_score = -float('inf')
    epochs_no_improve = 0
    best_model_state_dict = None
    global_step = 0
    ## 

    ## move model to GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    ##
  
    for epoch in range(num_epochs):
        """Training Loop Start here"""
        model.train()  
        total_train_loss = 0  
            
        for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):          
            inputs = {k:v.to(device) for k,v in batch.items() if k != 'labels'}# 將批次數據移至設備  
            labels = batch['labels'].to(device)

            optimizer.zero_grad()# 清除梯度
            loss, _ = model(**inputs, labels=labels)  # 前向傳播  
            loss.backward()  # 反向傳播和優化  
            optimizer.step()
            scheduler.step()
            global_step+=1

            if global_step % log_step == 0:
                avg_train_loss = total_train_loss / (step + 1)
                writer.add_scalar('Loss/train', avg_train_loss, global_step)
                print(f"Epoch {epoch+1}, Step {step+1}, Training Loss: {avg_train_loss:.6f}")
                
            total_train_loss += loss.item()
        """Training Loop End here"""

        """Validate Loop Start here"""
        model.eval()
        total_eval_loss = 0
        all_val_logits = []
        all_val_labels = []
        with torch.no_grad():
            for batch in tqdm(validation_loader, desc=f"Validation Epoch {epoch+1}"):
                inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
                labels = batch['labels'].to(device)

                loss, logits = model(**inputs, labels=labels)

                total_eval_loss += loss.item()
                all_val_logits.extend(logits.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
        """Validate Loop End here"""

        """Report and Save the weight of the model"""
        report = custom_classification_report(all_val_logits, all_val_labels, True)

        avg_eval_loss = total_eval_loss / len(validation_loader)
        val_f1_score = report['1']['f1-score']##means Marco-f1-score on 1's sample prediction
        print(f"Validation Loss: {avg_eval_loss:.6f}")

        # Store the model has best performance on the specific indicator like: 
        # evaluation loss, recall rate or f1-score
        ## early stopping part
        if avg_eval_loss < best_val_loss:
            best_val_loss = avg_eval_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs!")
                break

        ## save weight part
        if val_f1_score > best_f1_score:
            best_f1_score = val_f1_score
            best_model_state_dict = model.state_dict() # 保存最佳權重
        """An epoch end here"""

        writer.add_scalar('Loss/val', avg_eval_loss, epoch + 1)
    writer.close()

    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
        print("Loaded best model weights!")
    else:
        print("No improvement found during training.")
    # 保存微調後的模型  
    torch.save(model.state_dict(), "atp_binding_model.pt")

In [8]:
import sys
sys.path.append('..')
from utils import helpers
import time

start_time = time.time()
helpers.log("Start Training")
trainer = train_by_dataloader()
helpers.log("End Training")
end_time = time.time()
helpers.log(f"Total Training Time:{end_time - start_time} seconds")

[2025-05-17 19:37:08]Start Training


Epoch 1:   4%|▎         | 33/919 [00:16<03:44,  3.95it/s] 

Epoch 1, Step 40, Training Loss: 0.166773


Epoch 1:   9%|▊         | 79/919 [00:24<03:35,  3.91it/s]

Epoch 1, Step 80, Training Loss: 0.141021


Epoch 1:  13%|█▎        | 119/919 [00:35<03:38,  3.66it/s]

Epoch 1, Step 120, Training Loss: 0.100598


Epoch 1:  17%|█▋        | 158/919 [00:45<03:24,  3.72it/s]

Epoch 1, Step 160, Training Loss: 0.075785


Epoch 1:  22%|██▏       | 199/919 [00:53<03:10,  3.77it/s]

Epoch 1, Step 200, Training Loss: 0.060760


Epoch 1:  26%|██▌       | 239/919 [01:04<02:56,  3.86it/s]

Epoch 1, Step 240, Training Loss: 0.050752


Epoch 1:  30%|███       | 279/919 [01:14<02:40,  4.00it/s]

Epoch 1, Step 280, Training Loss: 0.043556


Epoch 1:  35%|███▍      | 319/919 [01:21<02:39,  3.77it/s]

Epoch 1, Step 320, Training Loss: 0.038168


Epoch 1:  39%|███▉      | 359/919 [01:32<02:30,  3.71it/s]

Epoch 1, Step 360, Training Loss: 0.033946


Epoch 1:  43%|████▎     | 399/919 [01:43<02:24,  3.61it/s]

Epoch 1, Step 400, Training Loss: 0.030581


Epoch 1:  48%|████▊     | 439/919 [01:51<02:17,  3.50it/s]

Epoch 1, Step 440, Training Loss: 0.027819


Epoch 1:  52%|█████▏    | 480/919 [02:02<01:57,  3.74it/s]

Epoch 1, Step 480, Training Loss: 0.025508


Epoch 1:  56%|█████▋    | 519/919 [02:13<01:50,  3.62it/s]

Epoch 1, Step 520, Training Loss: 0.023556


Epoch 1:  61%|██████    | 559/919 [02:22<01:42,  3.53it/s]

Epoch 1, Step 560, Training Loss: 0.021879


Epoch 1:  65%|██████▌   | 599/919 [02:33<01:34,  3.39it/s]

Epoch 1, Step 600, Training Loss: 0.020440


Epoch 1:  69%|██████▉   | 635/919 [02:44<01:21,  3.48it/s]

Epoch 1, Step 640, Training Loss: 0.019167


Epoch 1:  74%|███████▍  | 679/919 [02:53<01:09,  3.44it/s]

Epoch 1, Step 680, Training Loss: 0.018042


Epoch 1:  78%|███████▊  | 719/919 [03:05<00:59,  3.37it/s]

Epoch 1, Step 720, Training Loss: 0.017049


Epoch 1:  83%|████████▎ | 759/919 [03:13<00:10, 15.18it/s]

Epoch 1, Step 760, Training Loss: 0.016157


Epoch 1:  87%|████████▋ | 799/919 [03:25<00:36,  3.33it/s]

Epoch 1, Step 800, Training Loss: 0.015352


Epoch 1:  91%|█████████▏| 839/919 [03:36<00:23,  3.46it/s]

Epoch 1, Step 840, Training Loss: 0.014629


Epoch 1:  96%|█████████▌| 878/919 [03:45<00:07,  5.70it/s]

Epoch 1, Step 880, Training Loss: 0.013967


Epoch 1: 100%|██████████| 919/919 [03:56<00:00,  3.89it/s]
Validation Epoch 1: 100%|██████████| 115/115 [00:06<00:00, 16.60it/s]


(459, 128, 2)
Best Threshold: 0.165, G-Mean: 0.902
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.96      0.98     53478
           1       0.21      0.85      0.33       689

    accuracy                           0.96     54167
   macro avg       0.60      0.90      0.66     54167
weighted avg       0.99      0.96      0.97     54167

Validation Loss: 0.000110




Epoch 2, Step 1, Training Loss: 0.000000


Epoch 2:   4%|▎         | 33/919 [00:09<03:58,  3.71it/s]

Epoch 2, Step 41, Training Loss: 0.000027


Epoch 2:   9%|▊         | 80/919 [00:19<04:04,  3.43it/s]

Epoch 2, Step 81, Training Loss: 0.000047


Epoch 2:  13%|█▎        | 120/919 [00:31<03:48,  3.50it/s]

Epoch 2, Step 121, Training Loss: 0.000044


Epoch 2:  17%|█▋        | 160/919 [00:40<01:49,  6.90it/s]

Epoch 2, Step 161, Training Loss: 0.000055


Epoch 2:  22%|██▏       | 200/919 [00:51<03:40,  3.26it/s]

Epoch 2, Step 201, Training Loss: 0.000069


Epoch 2:  26%|██▌       | 240/919 [01:03<03:12,  3.53it/s]

Epoch 2, Step 241, Training Loss: 0.000074


Epoch 2:  30%|███       | 279/919 [01:11<02:11,  4.88it/s]

Epoch 2, Step 281, Training Loss: 0.000075


Epoch 2:  35%|███▍      | 320/919 [01:22<02:36,  3.82it/s]

Epoch 2, Step 321, Training Loss: 0.000080


Epoch 2:  39%|███▉      | 360/919 [01:33<02:30,  3.71it/s]

Epoch 2, Step 361, Training Loss: 0.000077


Epoch 2:  44%|████▎     | 400/919 [01:41<02:16,  3.79it/s]

Epoch 2, Step 401, Training Loss: 0.000091


Epoch 2:  48%|████▊     | 440/919 [01:52<02:11,  3.65it/s]

Epoch 2, Step 441, Training Loss: 0.000086


Epoch 2:  52%|█████▏    | 480/919 [02:03<01:59,  3.67it/s]

Epoch 2, Step 481, Training Loss: 0.000082


Epoch 2:  57%|█████▋    | 520/919 [02:11<01:55,  3.46it/s]

Epoch 2, Step 521, Training Loss: 0.000080


Epoch 2:  61%|██████    | 560/919 [02:23<01:38,  3.64it/s]

Epoch 2, Step 561, Training Loss: 0.000079


Epoch 2:  65%|██████▌   | 600/919 [02:33<01:22,  3.85it/s]

Epoch 2, Step 601, Training Loss: 0.000076


Epoch 2:  70%|██████▉   | 640/919 [02:41<01:14,  3.74it/s]

Epoch 2, Step 641, Training Loss: 0.000075


Epoch 2:  74%|███████▍  | 680/919 [02:52<01:09,  3.44it/s]

Epoch 2, Step 681, Training Loss: 0.000076


Epoch 2:  78%|███████▊  | 720/919 [03:03<00:52,  3.76it/s]

Epoch 2, Step 721, Training Loss: 0.000076


Epoch 2:  83%|████████▎ | 760/919 [03:11<00:41,  3.83it/s]

Epoch 2, Step 761, Training Loss: 0.000073


Epoch 2:  87%|████████▋ | 800/919 [03:22<00:32,  3.66it/s]

Epoch 2, Step 801, Training Loss: 0.000071


Epoch 2:  91%|█████████▏| 840/919 [03:33<00:22,  3.45it/s]

Epoch 2, Step 841, Training Loss: 0.000072


Epoch 2:  96%|█████████▌| 880/919 [03:41<00:10,  3.63it/s]

Epoch 2, Step 881, Training Loss: 0.000070


Epoch 2: 100%|██████████| 919/919 [03:52<00:00,  3.95it/s]
Validation Epoch 2: 100%|██████████| 115/115 [00:07<00:00, 15.07it/s]


(459, 128, 2)
Best Threshold: 0.146, G-Mean: 0.920
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.96      0.98     53478
           1       0.24      0.88      0.38       689

    accuracy                           0.96     54167
   macro avg       0.62      0.92      0.68     54167
weighted avg       0.99      0.96      0.97     54167

Validation Loss: 0.000097


Epoch 3:   0%|          | 1/919 [00:00<04:28,  3.42it/s]

Epoch 3, Step 2, Training Loss: 0.000000


Epoch 3:   4%|▍         | 41/919 [00:08<03:54,  3.75it/s]

Epoch 3, Step 42, Training Loss: 0.000039


Epoch 3:   9%|▉         | 81/919 [00:20<04:08,  3.37it/s]

Epoch 3, Step 82, Training Loss: 0.000052


Epoch 3:  13%|█▎        | 121/919 [00:32<03:53,  3.42it/s]

Epoch 3, Step 122, Training Loss: 0.000041


Epoch 3:  18%|█▊        | 161/919 [00:41<04:06,  3.07it/s]

Epoch 3, Step 162, Training Loss: 0.000051


Epoch 3:  22%|██▏       | 201/919 [00:53<03:34,  3.35it/s]

Epoch 3, Step 202, Training Loss: 0.000048


Epoch 3:  26%|██▌       | 240/919 [01:04<03:04,  3.68it/s]

Epoch 3, Step 242, Training Loss: 0.000053


Epoch 3:  31%|███       | 281/919 [01:12<02:54,  3.66it/s]

Epoch 3, Step 282, Training Loss: 0.000049


Epoch 3:  35%|███▍      | 321/919 [01:23<02:42,  3.68it/s]

Epoch 3, Step 322, Training Loss: 0.000047


Epoch 3:  39%|███▉      | 360/919 [01:33<02:32,  3.67it/s]

Epoch 3, Step 362, Training Loss: 0.000047


Epoch 3:  44%|████▎     | 401/919 [01:41<02:16,  3.79it/s]

Epoch 3, Step 402, Training Loss: 0.000049


Epoch 3:  48%|████▊     | 441/919 [01:52<02:09,  3.68it/s]

Epoch 3, Step 442, Training Loss: 0.000048


Epoch 3:  52%|█████▏    | 482/919 [02:03<01:53,  3.85it/s]

Epoch 3, Step 482, Training Loss: 0.000045


Epoch 3:  57%|█████▋    | 521/919 [02:10<01:42,  3.89it/s]

Epoch 3, Step 522, Training Loss: 0.000048


Epoch 3:  61%|██████    | 561/919 [02:21<01:34,  3.80it/s]

Epoch 3, Step 562, Training Loss: 0.000050


Epoch 3:  65%|██████▌   | 601/919 [02:32<01:24,  3.76it/s]

Epoch 3, Step 602, Training Loss: 0.000049


Epoch 3:  70%|██████▉   | 641/919 [02:40<01:13,  3.79it/s]

Epoch 3, Step 642, Training Loss: 0.000046


Epoch 3:  74%|███████▍  | 681/919 [02:50<01:03,  3.74it/s]

Epoch 3, Step 682, Training Loss: 0.000046


Epoch 3:  78%|███████▊  | 721/919 [03:01<00:54,  3.65it/s]

Epoch 3, Step 722, Training Loss: 0.000046


Epoch 3:  83%|████████▎ | 761/919 [03:09<00:41,  3.81it/s]

Epoch 3, Step 762, Training Loss: 0.000048


Epoch 3:  87%|████████▋ | 801/919 [03:20<00:31,  3.76it/s]

Epoch 3, Step 802, Training Loss: 0.000047


Epoch 3:  92%|█████████▏| 841/919 [03:30<00:20,  3.72it/s]

Epoch 3, Step 842, Training Loss: 0.000048


Epoch 3:  96%|█████████▌| 881/919 [03:38<00:09,  3.81it/s]

Epoch 3, Step 882, Training Loss: 0.000048


Epoch 3: 100%|██████████| 919/919 [03:48<00:00,  4.02it/s]
Validation Epoch 3: 100%|██████████| 115/115 [00:07<00:00, 15.85it/s]


(459, 128, 2)
Best Threshold: 0.154, G-Mean: 0.923
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.98      0.99     53478
           1       0.39      0.87      0.53       689

    accuracy                           0.98     54167
   macro avg       0.69      0.92      0.76     54167
weighted avg       0.99      0.98      0.98     54167

Validation Loss: 0.000076


Epoch 4:   0%|          | 2/919 [00:00<04:25,  3.45it/s]

Epoch 4, Step 3, Training Loss: 0.000182


Epoch 4:   5%|▍         | 42/919 [00:08<02:41,  5.44it/s]

Epoch 4, Step 43, Training Loss: 0.000032


Epoch 4:   9%|▉         | 82/919 [00:18<03:45,  3.71it/s]

Epoch 4, Step 83, Training Loss: 0.000026


Epoch 4:  13%|█▎        | 122/919 [00:29<03:39,  3.64it/s]

Epoch 4, Step 123, Training Loss: 0.000021


Epoch 4:  18%|█▊        | 161/919 [00:37<01:59,  6.35it/s]

Epoch 4, Step 163, Training Loss: 0.000027


Epoch 4:  22%|██▏       | 202/919 [00:48<03:17,  3.63it/s]

Epoch 4, Step 203, Training Loss: 0.000042


Epoch 4:  26%|██▋       | 242/919 [00:59<03:14,  3.48it/s]

Epoch 4, Step 243, Training Loss: 0.000042


Epoch 4:  31%|███       | 282/919 [01:08<02:22,  4.46it/s]

Epoch 4, Step 283, Training Loss: 0.000041


Epoch 4:  35%|███▌      | 322/919 [01:19<02:58,  3.35it/s]

Epoch 4, Step 323, Training Loss: 0.000038


Epoch 4:  39%|███▉      | 362/919 [01:30<02:43,  3.41it/s]

Epoch 4, Step 363, Training Loss: 0.000036


Epoch 4:  44%|████▎     | 402/919 [01:39<02:29,  3.45it/s]

Epoch 4, Step 403, Training Loss: 0.000035


Epoch 4:  48%|████▊     | 442/919 [01:50<02:08,  3.70it/s]

Epoch 4, Step 443, Training Loss: 0.000038


Epoch 4:  52%|█████▏    | 482/919 [02:01<02:01,  3.59it/s]

Epoch 4, Step 483, Training Loss: 0.000037


Epoch 4:  57%|█████▋    | 522/919 [02:09<01:49,  3.64it/s]

Epoch 4, Step 523, Training Loss: 0.000036


Epoch 4:  61%|██████    | 562/919 [02:20<01:36,  3.68it/s]

Epoch 4, Step 563, Training Loss: 0.000036


Epoch 4:  66%|██████▌   | 602/919 [02:31<01:26,  3.68it/s]

Epoch 4, Step 603, Training Loss: 0.000036


Epoch 4:  70%|██████▉   | 642/919 [02:39<01:19,  3.50it/s]

Epoch 4, Step 643, Training Loss: 0.000037


Epoch 4:  74%|███████▍  | 682/919 [02:50<01:04,  3.65it/s]

Epoch 4, Step 683, Training Loss: 0.000036


Epoch 4:  79%|███████▊  | 722/919 [03:01<00:52,  3.72it/s]

Epoch 4, Step 723, Training Loss: 0.000035


Epoch 4:  83%|████████▎ | 762/919 [03:08<00:45,  3.47it/s]

Epoch 4, Step 763, Training Loss: 0.000035


Epoch 4:  87%|████████▋ | 802/919 [03:19<00:33,  3.49it/s]

Epoch 4, Step 803, Training Loss: 0.000034


Epoch 4:  92%|█████████▏| 842/919 [03:30<00:21,  3.56it/s]

Epoch 4, Step 843, Training Loss: 0.000034


Epoch 4:  96%|█████████▌| 882/919 [03:38<00:09,  3.75it/s]

Epoch 4, Step 883, Training Loss: 0.000033


Epoch 4: 100%|██████████| 919/919 [03:48<00:00,  4.03it/s]
Validation Epoch 4: 100%|██████████| 115/115 [00:07<00:00, 16.31it/s]


(459, 128, 2)
Best Threshold: 0.117, G-Mean: 0.928
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.97      0.99     53478
           1       0.30      0.88      0.45       689

    accuracy                           0.97     54167
   macro avg       0.65      0.93      0.72     54167
weighted avg       0.99      0.97      0.98     54167

Validation Loss: 0.000062


Epoch 5:   0%|          | 3/919 [00:00<03:56,  3.87it/s]

Epoch 5, Step 4, Training Loss: 0.000002


Epoch 5:   5%|▍         | 43/919 [00:08<02:11,  6.64it/s]

Epoch 5, Step 44, Training Loss: 0.000022


Epoch 5:   9%|▉         | 83/919 [00:19<03:48,  3.66it/s]

Epoch 5, Step 84, Training Loss: 0.000019


Epoch 5:  13%|█▎        | 123/919 [00:30<03:40,  3.61it/s]

Epoch 5, Step 124, Training Loss: 0.000029


Epoch 5:  18%|█▊        | 163/919 [00:38<02:02,  6.15it/s]

Epoch 5, Step 164, Training Loss: 0.000025


Epoch 5:  22%|██▏       | 203/919 [00:49<03:16,  3.65it/s]

Epoch 5, Step 204, Training Loss: 0.000021


Epoch 5:  26%|██▋       | 243/919 [01:00<02:59,  3.76it/s]

Epoch 5, Step 244, Training Loss: 0.000021


Epoch 5:  31%|███       | 283/919 [01:08<02:06,  5.04it/s]

Epoch 5, Step 284, Training Loss: 0.000020


Epoch 5:  35%|███▌      | 323/919 [01:18<02:38,  3.76it/s]

Epoch 5, Step 324, Training Loss: 0.000018


Epoch 5:  39%|███▉      | 363/919 [01:29<02:42,  3.42it/s]

Epoch 5, Step 364, Training Loss: 0.000019


Epoch 5:  44%|████▍     | 403/919 [01:38<01:45,  4.87it/s]

Epoch 5, Step 404, Training Loss: 0.000020


Epoch 5:  48%|████▊     | 443/919 [01:49<02:14,  3.55it/s]

Epoch 5, Step 444, Training Loss: 0.000019


Epoch 5:  53%|█████▎    | 483/919 [02:01<02:03,  3.53it/s]

Epoch 5, Step 484, Training Loss: 0.000020


Epoch 5:  57%|█████▋    | 523/919 [02:09<01:45,  3.75it/s]

Epoch 5, Step 524, Training Loss: 0.000019


Epoch 5:  61%|██████▏   | 563/919 [02:20<01:40,  3.53it/s]

Epoch 5, Step 564, Training Loss: 0.000020


Epoch 5:  66%|██████▌   | 603/919 [02:32<01:27,  3.59it/s]

Epoch 5, Step 604, Training Loss: 0.000019


Epoch 5:  70%|██████▉   | 643/919 [02:40<01:18,  3.52it/s]

Epoch 5, Step 644, Training Loss: 0.000020


Epoch 5:  74%|███████▍  | 683/919 [02:51<01:07,  3.52it/s]

Epoch 5, Step 684, Training Loss: 0.000020


Epoch 5:  79%|███████▊  | 723/919 [03:03<00:55,  3.54it/s]

Epoch 5, Step 724, Training Loss: 0.000020


Epoch 5:  83%|████████▎ | 763/919 [03:11<00:43,  3.61it/s]

Epoch 5, Step 764, Training Loss: 0.000019


Epoch 5:  87%|████████▋ | 803/919 [03:22<00:33,  3.50it/s]

Epoch 5, Step 804, Training Loss: 0.000019


Epoch 5:  92%|█████████▏| 842/919 [03:33<00:21,  3.56it/s]

Epoch 5, Step 844, Training Loss: 0.000019


Epoch 5:  96%|█████████▌| 883/919 [03:42<00:10,  3.55it/s]

Epoch 5, Step 884, Training Loss: 0.000019


Epoch 5: 100%|██████████| 919/919 [03:52<00:00,  3.95it/s]
Validation Epoch 5: 100%|██████████| 115/115 [00:07<00:00, 15.52it/s]


(459, 128, 2)
Best Threshold: 0.157, G-Mean: 0.935
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.98      0.99     53478
           1       0.38      0.89      0.54       689

    accuracy                           0.98     54167
   macro avg       0.69      0.93      0.76     54167
weighted avg       0.99      0.98      0.98     54167

Validation Loss: 0.000049


Epoch 6:   0%|          | 4/919 [00:01<04:17,  3.55it/s]

Epoch 6, Step 5, Training Loss: 0.000049


Epoch 6:   5%|▍         | 44/919 [00:09<03:57,  3.69it/s]

Epoch 6, Step 45, Training Loss: 0.000012


Epoch 6:   9%|▉         | 84/919 [00:20<03:56,  3.54it/s]

Epoch 6, Step 85, Training Loss: 0.000013


Epoch 6:  13%|█▎        | 124/919 [00:32<03:44,  3.54it/s]

Epoch 6, Step 125, Training Loss: 0.000011


Epoch 6:  18%|█▊        | 164/919 [00:40<03:35,  3.51it/s]

Epoch 6, Step 165, Training Loss: 0.000010


Epoch 6:  22%|██▏       | 204/919 [00:51<03:25,  3.48it/s]

Epoch 6, Step 205, Training Loss: 0.000010


Epoch 6:  27%|██▋       | 244/919 [01:03<03:11,  3.53it/s]

Epoch 6, Step 245, Training Loss: 0.000011


Epoch 6:  31%|███       | 284/919 [01:11<02:58,  3.55it/s]

Epoch 6, Step 285, Training Loss: 0.000010


Epoch 6:  35%|███▌      | 324/919 [01:22<02:47,  3.56it/s]

Epoch 6, Step 325, Training Loss: 0.000010


Epoch 6:  40%|███▉      | 364/919 [01:34<02:34,  3.59it/s]

Epoch 6, Step 365, Training Loss: 0.000010


Epoch 6:  44%|████▍     | 404/919 [01:42<02:27,  3.50it/s]

Epoch 6, Step 405, Training Loss: 0.000010


Epoch 6:  48%|████▊     | 444/919 [01:53<02:12,  3.57it/s]

Epoch 6, Step 445, Training Loss: 0.000010


Epoch 6:  53%|█████▎    | 484/919 [02:04<02:01,  3.59it/s]

Epoch 6, Step 485, Training Loss: 0.000009


Epoch 6:  57%|█████▋    | 524/919 [02:13<01:53,  3.49it/s]

Epoch 6, Step 525, Training Loss: 0.000009


Epoch 6:  61%|██████▏   | 564/919 [02:24<01:40,  3.52it/s]

Epoch 6, Step 565, Training Loss: 0.000009


Epoch 6:  65%|██████▌   | 598/919 [02:34<01:30,  3.54it/s]

Epoch 6, Step 605, Training Loss: 0.000010


Epoch 6:  70%|███████   | 644/919 [02:44<01:17,  3.55it/s]

Epoch 6, Step 645, Training Loss: 0.000011


Epoch 6:  74%|███████▍  | 684/919 [02:55<01:05,  3.57it/s]

Epoch 6, Step 685, Training Loss: 0.000012


Epoch 6:  79%|███████▊  | 723/919 [03:03<00:11, 16.77it/s]

Epoch 6, Step 725, Training Loss: 0.000012


Epoch 6:  83%|████████▎ | 764/919 [03:15<00:44,  3.49it/s]

Epoch 6, Step 765, Training Loss: 0.000014


Epoch 6:  87%|████████▋ | 804/919 [03:26<00:31,  3.61it/s]

Epoch 6, Step 805, Training Loss: 0.000014


Epoch 6:  92%|█████████▏| 843/919 [03:34<00:09,  8.05it/s]

Epoch 6, Step 845, Training Loss: 0.000014


Epoch 6:  96%|█████████▌| 884/919 [03:46<00:09,  3.55it/s]

Epoch 6, Step 885, Training Loss: 0.000014


Epoch 6: 100%|██████████| 919/919 [03:56<00:00,  3.89it/s]
Validation Epoch 6: 100%|██████████| 115/115 [00:04<00:00, 25.88it/s]


(459, 128, 2)
Best Threshold: 0.146, G-Mean: 0.936
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.98      0.99     53478
           1       0.41      0.89      0.56       689

    accuracy                           0.98     54167
   macro avg       0.70      0.94      0.77     54167
weighted avg       0.99      0.98      0.99     54167

Validation Loss: 0.000113


Epoch 7:   1%|          | 5/919 [00:01<04:15,  3.58it/s]

Epoch 7, Step 6, Training Loss: 0.000010


Epoch 7:   5%|▍         | 45/919 [00:12<04:07,  3.53it/s]

Epoch 7, Step 46, Training Loss: 0.000014


Epoch 7:   9%|▉         | 85/919 [00:24<03:57,  3.52it/s]

Epoch 7, Step 86, Training Loss: 0.000010


Epoch 7:  13%|█▎        | 123/919 [00:31<00:48, 16.41it/s]

Epoch 7, Step 126, Training Loss: 0.000008


Epoch 7:  18%|█▊        | 165/919 [00:43<03:32,  3.55it/s]

Epoch 7, Step 166, Training Loss: 0.000008


Epoch 7:  22%|██▏       | 205/919 [00:55<03:20,  3.55it/s]

Epoch 7, Step 206, Training Loss: 0.000008


Epoch 7:  27%|██▋       | 245/919 [01:03<01:55,  5.84it/s]

Epoch 7, Step 246, Training Loss: 0.000009


Epoch 7:  31%|███       | 285/919 [01:14<03:01,  3.49it/s]

Epoch 7, Step 286, Training Loss: 0.000009


Epoch 7:  35%|███▌      | 325/919 [01:26<02:50,  3.48it/s]

Epoch 7, Step 326, Training Loss: 0.000009


Epoch 7:  40%|███▉      | 365/919 [01:34<02:10,  4.26it/s]

Epoch 7, Step 366, Training Loss: 0.000008


Epoch 7:  44%|████▍     | 405/919 [01:45<02:27,  3.49it/s]

Epoch 7, Step 406, Training Loss: 0.000008


Epoch 7:  48%|████▊     | 445/919 [01:57<02:16,  3.46it/s]

Epoch 7, Step 446, Training Loss: 0.000008


Epoch 7:  53%|█████▎    | 485/919 [02:05<01:59,  3.63it/s]

Epoch 7, Step 486, Training Loss: 0.000010


Epoch 7:  57%|█████▋    | 525/919 [02:16<01:52,  3.50it/s]

Epoch 7, Step 526, Training Loss: 0.000010


Epoch 7:  61%|██████▏   | 565/919 [02:28<01:40,  3.51it/s]

Epoch 7, Step 566, Training Loss: 0.000010


Epoch 7:  66%|██████▌   | 605/919 [02:36<01:28,  3.55it/s]

Epoch 7, Step 606, Training Loss: 0.000010


Epoch 7:  70%|███████   | 645/919 [02:48<01:18,  3.47it/s]

Epoch 7, Step 646, Training Loss: 0.000010


Epoch 7:  75%|███████▍  | 685/919 [02:59<01:08,  3.44it/s]

Epoch 7, Step 686, Training Loss: 0.000009


Epoch 7:  79%|███████▉  | 725/919 [03:07<00:54,  3.54it/s]

Epoch 7, Step 726, Training Loss: 0.000010


Epoch 7:  83%|████████▎ | 765/919 [03:19<00:43,  3.51it/s]

Epoch 7, Step 766, Training Loss: 0.000010


Epoch 7:  87%|████████▋ | 801/919 [03:29<00:34,  3.45it/s]

Epoch 7, Step 806, Training Loss: 0.000010


Epoch 7:  92%|█████████▏| 845/919 [03:38<00:21,  3.50it/s]

Epoch 7, Step 846, Training Loss: 0.000009


Epoch 7:  96%|█████████▋| 885/919 [03:50<00:09,  3.59it/s]

Epoch 7, Step 886, Training Loss: 0.000009


Epoch 7: 100%|██████████| 919/919 [03:56<00:00,  3.88it/s]
Validation Epoch 7: 100%|██████████| 115/115 [00:07<00:00, 15.44it/s]


(459, 128, 2)
Best Threshold: 0.080, G-Mean: 0.935
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.96      0.98     53478
           1       0.21      0.91      0.34       689

    accuracy                           0.96     54167
   macro avg       0.60      0.93      0.66     54167
weighted avg       0.99      0.96      0.97     54167

Validation Loss: 0.000074


Epoch 8:   1%|          | 6/919 [00:01<04:09,  3.65it/s]

Epoch 8, Step 7, Training Loss: 0.000000


Epoch 8:   5%|▌         | 46/919 [00:12<04:06,  3.54it/s]

Epoch 8, Step 47, Training Loss: 0.000006


Epoch 8:   9%|▉         | 85/919 [00:23<03:54,  3.55it/s]

Epoch 8, Step 87, Training Loss: 0.000005


Epoch 8:  14%|█▎        | 126/919 [00:32<03:43,  3.55it/s]

Epoch 8, Step 127, Training Loss: 0.000005


Epoch 8:  18%|█▊        | 166/919 [00:43<03:31,  3.56it/s]

Epoch 8, Step 167, Training Loss: 0.000006


Epoch 8:  22%|██▏       | 200/919 [00:53<03:26,  3.48it/s]

Epoch 8, Step 207, Training Loss: 0.000006


Epoch 8:  27%|██▋       | 246/919 [01:03<03:12,  3.50it/s]

Epoch 8, Step 247, Training Loss: 0.000010


Epoch 8:  31%|███       | 286/919 [01:14<03:00,  3.50it/s]

Epoch 8, Step 287, Training Loss: 0.000009


Epoch 8:  35%|███▌      | 326/919 [01:23<00:36, 16.28it/s]

Epoch 8, Step 327, Training Loss: 0.000009


Epoch 8:  40%|███▉      | 366/919 [01:34<02:38,  3.48it/s]

Epoch 8, Step 367, Training Loss: 0.000008


Epoch 8:  44%|████▍     | 406/919 [01:46<02:25,  3.52it/s]

Epoch 8, Step 407, Training Loss: 0.000008


Epoch 8:  48%|████▊     | 445/919 [01:54<01:08,  6.87it/s]

Epoch 8, Step 447, Training Loss: 0.000007


Epoch 8:  53%|█████▎    | 486/919 [02:05<02:03,  3.50it/s]

Epoch 8, Step 487, Training Loss: 0.000007


Epoch 8:  57%|█████▋    | 526/919 [02:17<01:52,  3.51it/s]

Epoch 8, Step 527, Training Loss: 0.000007


Epoch 8:  62%|██████▏   | 566/919 [02:25<01:21,  4.32it/s]

Epoch 8, Step 567, Training Loss: 0.000007


Epoch 8:  66%|██████▌   | 606/919 [02:36<01:29,  3.48it/s]

Epoch 8, Step 607, Training Loss: 0.000007


Epoch 8:  70%|███████   | 646/919 [02:48<01:15,  3.63it/s]

Epoch 8, Step 647, Training Loss: 0.000007


Epoch 8:  75%|███████▍  | 686/919 [02:56<01:04,  3.60it/s]

Epoch 8, Step 687, Training Loss: 0.000007


Epoch 8:  79%|███████▉  | 726/919 [03:07<00:55,  3.50it/s]

Epoch 8, Step 727, Training Loss: 0.000007


Epoch 8:  83%|████████▎ | 766/919 [03:19<00:44,  3.42it/s]

Epoch 8, Step 767, Training Loss: 0.000007


Epoch 8:  88%|████████▊ | 806/919 [03:27<00:32,  3.53it/s]

Epoch 8, Step 807, Training Loss: 0.000007


Epoch 8:  92%|█████████▏| 846/919 [03:39<00:21,  3.39it/s]

Epoch 8, Step 847, Training Loss: 0.000007


Epoch 8:  96%|█████████▋| 886/919 [03:50<00:09,  3.45it/s]

Epoch 8, Step 887, Training Loss: 0.000006


Epoch 8: 100%|██████████| 919/919 [03:56<00:00,  3.88it/s]
Validation Epoch 8: 100%|██████████| 115/115 [00:07<00:00, 15.51it/s]


(459, 128, 2)
Best Threshold: 0.090, G-Mean: 0.939
Shape of flattened labels: (58752,)
Shape of flattened predictions: (58752,)
              precision    recall  f1-score   support

           0       1.00      0.97      0.98     53478
           1       0.29      0.91      0.44       689

    accuracy                           0.97     54167
   macro avg       0.64      0.94      0.71     54167
weighted avg       0.99      0.97      0.98     54167

Validation Loss: 0.000069
Early stopping triggered after 8 epochs!
Loaded best model weights!
[2025-05-17 20:09:17]End Training
[2025-05-17 20:09:17]Total Training Time:1928.83962059021 seconds


In [10]:
test_loader = DataLoader(test_dataset, batch_size = 4, shuffle=False)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
total_eval_loss = 0
all_val_logits = []
all_val_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        _, logits = model(**inputs)
        all_val_logits.extend(logits.cpu().numpy())
        all_val_labels.extend(labels.cpu().numpy())

custom_classification_report(all_val_logits, all_val_labels, True)
torch.cuda.empty_cache()

100%|██████████| 115/115 [00:03<00:00, 32.21it/s]

(460, 128, 2)
Best Threshold: 0.114, G-Mean: 0.951
Shape of flattened labels: (58880,)
Shape of flattened predictions: (58880,)
              precision    recall  f1-score   support

           0       1.00      0.99      0.99     54962
           1       0.43      0.92      0.58       650

    accuracy                           0.98     55612
   macro avg       0.71      0.95      0.79     55612
weighted avg       0.99      0.98      0.99     55612




