In [1]:
import data_preprocess

data = data_preprocess.run()

sequences = []
labels = []

MAX_LEN = 128

for item in data:
    while len(item['sequence']) > MAX_LEN:
        sequences.append(item['sequence'][:MAX_LEN])
        labels.append(item['label'][:MAX_LEN])
        item['sequence'] = item['sequence'][MAX_LEN:]
        item['label'] = item['label'][MAX_LEN:]
    sequences.append(item['sequence'])
    labels.append(item['label'])

In [2]:
import pandas as pd

zero_counter = 0
ones_counter = 0

for l in labels:
    zero_counter+=l.count(0)
    ones_counter+=l.count(1)

pd.DataFrame(columns=['sample amount'], data = [zero_counter, ones_counter])

Unnamed: 0,sample amount
0,35045
1,1131


In [3]:
from dataset import FADBindingDataset, split_dataset
from esm_model import EsmForSequenceLabeling
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForSequenceLabeling("facebook/esm2_t6_8M_UR50D", 2)
dataset = FADBindingDataset(tokenizer, sequences, labels, max_length=MAX_LEN)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset[0]

{'input_ids': tensor([ 0, 20, 15, 12, 10, 18,  7, 18, 12,  4,  8,  7,  4, 12,  8,  6,  7, 23,
         23, 12,  8, 15, 17,  7,  8, 10, 10,  7,  5, 17, 10, 20, 11,  5, 21,  8,
         10, 18,  4, 18,  7, 21, 13, 15, 19, 15, 10, 17, 15, 17, 18, 15,  4, 15,
         17, 17, 15,  9,  9, 17, 17, 18, 12, 17,  4, 19, 11,  7, 15, 17, 14,  4,
         15, 23, 15, 12,  7, 13, 15, 12, 17,  4,  7, 10, 14, 17,  8, 14, 17,  9,
          7, 19, 21,  4,  9, 12, 17, 21, 17,  6,  4, 18, 15, 19,  4,  9,  6, 21,
         11, 23,  6, 12, 12, 14, 19, 19, 17,  9,  4, 13, 17, 17, 14, 17, 17, 16,
         12,  2]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [5]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter

[train_dataset, test_dataset] = split_dataset(dataset, 0.8)
[validate_dataset, test_dataset] = split_dataset(test_dataset, 0.5)


In [6]:
import numpy as np
from sklearn.metrics import classification_report, roc_curve

def find_best_evalution(logits, labels):
    logits_np = np.array(logits)
    labels_np = np.array(labels)
    #sigmoid probability
    print(logits_np.shape)
    probs = 1 / (1 + np.exp(-logits_np[:, :, 1]))#[batch size, sequence length]

    labels_flat = labels_np.flatten()
    probs_flat = probs.flatten()

     # 只取非 padding (-100) 的有效樣本
    active_indices = labels_flat != -100
    final_labels = labels_flat[active_indices]
    final_probs = probs_flat[active_indices]

    # 計算 ROC 並找出最佳 G-mean 閾值
    fpr, tpr, thresholds = roc_curve(final_labels, final_probs)
    gmeans = np.sqrt(tpr * (1 - fpr))
    ix = np.argmax(gmeans)
    best_threshold = thresholds[ix]

    print(f'Best Threshold: {best_threshold:.3f}, G-Mean: {gmeans[ix]:.3f}')
    return best_threshold
    

def custom_classification_report(logits, labels, test_evalution = False):
    """Filt the ignore value and return classification report"""
    labels_np = np.array(labels)
    logits_np = np.array(logits)
    #preds = np.argmax(logits_np, axis=-1).flatten()
    if test_evalution:
        threshold = find_best_evalution(logits, labels)
    else:
        threshold = 0.5
    probs = 1/(1+np.exp(-logits_np[:,:,1]))
    preds = (probs > threshold).astype(int).flatten()
    labels_flat = labels_np.flatten()

    print("Shape of flattened labels:", labels_flat.shape)
    print("Shape of flattened predictions:", preds.shape)

    active_indices = labels_flat != -100
    final_labels = labels_flat[active_indices]
    final_preds = preds[active_indices]
    report_str = classification_report(final_labels, final_preds)
    report_dict = classification_report(final_labels, final_preds, output_dict=True)
    print(report_str)
    return report_dict

In [7]:
from tqdm import tqdm
import yaml

def train_by_dataloader():

    writer = SummaryWriter('runs/my_experiment') # TensorBoard writer

    ## hyperparameters in yaml file
    with open("../config/mamba_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    learning_rate = config["training_arguments"]["learning_rate"]
    per_device_train_batch_size = config["training_arguments"]["per_device_train_batch_size"]
    num_epochs = config["training_arguments"]["num_train_epochs"]
    warmup_steps = config['training_arguments']['warmup_steps']
    log_step = config['training_arguments']['eval_steps']
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    ##

    ## create dataloader
    train_loader = DataLoader(train_dataset, per_device_train_batch_size, shuffle=True)
    validation_loader = DataLoader(validate_dataset, per_device_train_batch_size, shuffle=True)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    ## 

    ## some variable to use
    patience = 5
    best_val_loss = float('inf')
    best_f1_score = -float('inf')
    epochs_no_improve = 0
    best_model_state_dict = None
    global_step = 0
    ## 

    ## move model to GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    ##
  
    for epoch in range(num_epochs):
        """Training Loop Start here"""
        model.train()  
        total_train_loss = 0  
            
        for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):          
            inputs = {k:v.to(device) for k,v in batch.items() if k != 'labels'}# 將批次數據移至設備  
            labels = batch['labels'].to(device)

            optimizer.zero_grad()# 清除梯度
            loss, _ = model(**inputs, labels=labels)  # 前向傳播  
            loss.backward()  # 反向傳播和優化  
            optimizer.step()
            scheduler.step()
            global_step+=1

            if global_step % log_step == 0:
                avg_train_loss = total_train_loss / (step + 1)
                writer.add_scalar('Loss/train', avg_train_loss, global_step)
                print(f"Epoch {epoch+1}, Step {step+1}, Training Loss: {avg_train_loss:.6f}")
                
            total_train_loss += loss.item()
        """Training Loop End here"""

        """Validate Loop Start here"""
        model.eval()
        total_eval_loss = 0
        all_val_logits = []
        all_val_labels = []
        with torch.no_grad():
            for batch in tqdm(validation_loader, desc=f"Validation Epoch {epoch+1}"):
                inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
                labels = batch['labels'].to(device)

                loss, logits = model(**inputs, labels=labels)

                total_eval_loss += loss.item()
                all_val_logits.extend(logits.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
        """Validate Loop End here"""

        """Report and Save the weight of the model"""
        report = custom_classification_report(all_val_logits, all_val_labels, True)

        avg_eval_loss = total_eval_loss / len(validation_loader)
        val_f1_score = report['1']['f1-score']##means Marco-f1-score on 1's sample prediction
        print(f"Validation Loss: {avg_eval_loss:.6f}")

        # Store the model has best performance on the specific indicator like: 
        # evaluation loss, recall rate or f1-score
        ## early stopping part
        if avg_eval_loss < best_val_loss:
            best_val_loss = avg_eval_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs!")
                break

        ## save weight part
        if val_f1_score > best_f1_score:
            best_f1_score = val_f1_score
            best_model_state_dict = model.state_dict() # 保存最佳權重
        """An epoch end here"""

        writer.add_scalar('Loss/val', avg_eval_loss, epoch + 1)
    writer.close()

    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
        print("Loaded best model weights!")
    else:
        print("No improvement found during training.")
    # 保存微調後的模型  
    torch.save(model.state_dict(), "esm_fad_binding_model.pt")

In [1]:
import time

st = time.time()
time.sleep(3)
et = time.time()
print(et - st)

3.0031516551971436


In [None]:
from utils import helpers


start_time = time.time()
helpers.log("Start Training")
trainer = train_by_dataloader()
helpers.log("End Training")
end_time = time.time()
helpers.log(f"Total Training Time:{end_time - start_time} second")

Epoch 1:  63%|██████▎   | 40/63 [00:07<00:03,  6.72it/s]

Epoch 1, Step 40, Training Loss: 0.163316


Epoch 1: 100%|██████████| 63/63 [00:10<00:00,  5.89it/s]
Validation Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 51.61it/s]


(32, 128, 2)
Best Threshold: 0.441, G-Mean: 0.541
Shape of flattened labels: (4096,)
Shape of flattened predictions: (4096,)
              precision    recall  f1-score   support

           0       0.95      0.66      0.78      3660
           1       0.07      0.44      0.12       205

    accuracy                           0.65      3865
   macro avg       0.51      0.55      0.45      3865
weighted avg       0.91      0.65      0.74      3865

Validation Loss: 0.142362


Epoch 2:  27%|██▋       | 17/63 [00:02<00:06,  6.81it/s]

Epoch 2, Step 17, Training Loss: 0.125025


Epoch 2:  90%|█████████ | 57/63 [00:08<00:00,  7.13it/s]

Epoch 2, Step 57, Training Loss: 0.107299


Epoch 2: 100%|██████████| 63/63 [00:09<00:00,  6.86it/s]
Validation Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 53.99it/s]


(32, 128, 2)
Best Threshold: 0.387, G-Mean: 0.578
Shape of flattened labels: (4096,)
Shape of flattened predictions: (4096,)
              precision    recall  f1-score   support

           0       0.96      0.64      0.77      3660
           1       0.07      0.52      0.13       205

    accuracy                           0.63      3865
   macro avg       0.52      0.58      0.45      3865
weighted avg       0.91      0.63      0.73      3865

Validation Loss: 0.070801


Epoch 3:  54%|█████▍    | 34/63 [00:04<00:04,  6.97it/s]

Epoch 3, Step 34, Training Loss: 0.043888


Epoch 3: 100%|██████████| 63/63 [00:09<00:00,  6.93it/s]
Validation Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 50.91it/s]


(32, 128, 2)
Best Threshold: 0.230, G-Mean: 0.481
Shape of flattened labels: (4096,)
Shape of flattened predictions: (4096,)
              precision    recall  f1-score   support

           0       0.94      0.39      0.55      3660
           1       0.05      0.59      0.09       205

    accuracy                           0.40      3865
   macro avg       0.50      0.49      0.32      3865
weighted avg       0.90      0.40      0.53      3865

Validation Loss: 0.017190


Epoch 4:  17%|█▋        | 11/63 [00:01<00:07,  6.90it/s]

Epoch 4, Step 11, Training Loss: 0.009927


Epoch 4:  33%|███▎      | 21/63 [00:03<00:06,  6.80it/s]


KeyboardInterrupt: 

In [None]:
validation_loader = DataLoader(validate_dataset, batch_size = 4, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = 4, shuffle=False)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
total_eval_loss = 0
all_val_logits = []
all_val_labels = []
with torch.no_grad():
    for batch in tqdm(validation_loader):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        _, logits = model(**inputs)
        all_val_logits.extend(logits.cpu().numpy())
        all_val_labels.extend(labels.cpu().numpy())

custom_classification_report(all_val_logits, all_val_labels, True)
all_val_logits = []
all_val_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        _, logits = model(**inputs)
        all_val_logits.extend(logits.cpu().numpy())
        all_val_labels.extend(labels.cpu().numpy())

custom_classification_report(all_val_logits, all_val_labels, True)
torch.cuda.empty_cache()

100%|██████████| 8/8 [00:00<00:00, 30.28it/s]


(32, 128, 2)
Best Threshold: 0.143, G-Mean: 0.763
Shape of flattened labels: (4096,)
Shape of flattened predictions: (4096,)
              precision    recall  f1-score   support

           0       0.99      0.83      0.90      3635
           1       0.12      0.69      0.21       124

    accuracy                           0.82      3759
   macro avg       0.55      0.76      0.55      3759
weighted avg       0.96      0.82      0.88      3759



100%|██████████| 8/8 [00:00<00:00, 128.01it/s]


(32, 128, 2)
Best Threshold: 0.214, G-Mean: 0.788
Shape of flattened labels: (4096,)
Shape of flattened predictions: (4096,)
              precision    recall  f1-score   support

           0       0.99      0.91      0.95      3713
           1       0.10      0.67      0.17        57

    accuracy                           0.90      3770
   macro avg       0.55      0.79      0.56      3770
weighted avg       0.98      0.90      0.94      3770

