In [1]:
import os
import sys
import numpy as np
import pandas as pd
from functools import partial
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns


mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['font.size'] = 20
mpl.rcParams['lines.linewidth'] = 2


from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


pwddir = os.path.dirname(os.path.abspath("."))
print("Current working directory: ", pwddir)
sys.path.append(os.path.join(pwddir, "src", "CD-GPT"))

from config import get_config
from model import CDGPTSequencePrediction, CDGPT
from tokenizer import SentencePieceTokenizer

SEED = 0
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

Current working directory:  e:\2025Spring_CS776\project


In [2]:
cfg = get_config()
cfg.tokenizer.path = os.path.join(pwddir, "src", "CD-GPT", "checkpoints", "tokenizer.model")
cfg.model.num_classes = 1
tokenizer = SentencePieceTokenizer(cfg.tokenizer.path)
cfg.tokenizer.pad_id = tokenizer.pad
# print("Vocabulary:", tokenizer.vocab)
print("Number of tokens:", tokenizer.vocab_size)



Number of tokens: 64000


In [3]:
tokenizer.piece_to_id("ACG")

321

In [4]:
tokenizer.id_to_piece(63999)

'▁'

In [5]:
tokenizer.encode("ACGT")

tensor([63999,   321, 63980])

In [6]:
tokenizer.encode("ACTG")

tensor([63999,   507])

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


## Load Dataset

In [8]:
class PBMDataset(Dataset):
    """
    df: TF_Id, ArrayType, Sequence, Signal_Mean, ...
    We parse only sequences for X and signals for Y.
    """
    def __init__(self, df, seq_col=None, tgt_col=None, \
                transform=None, truncate=False):
        self.samples = []
        self.transform = transform
        self.seq_col = seq_col
        self.tgt_col = tgt_col
        self.df = df
        self.truncate = truncate
    def __len__(self):
        return self.df.shape[0]


    def __getitem__(self, idx):
        seq = self.df[self.seq_col].iloc[idx]

        if self.truncate:
            seq = seq[:35]
        val = self.df[self.tgt_col].iloc[idx]
        if self.transform:
            xarr = self.transform(seq)
        
        return xarr, float(val)

In [9]:
def get_dataset(df, seq_col, tgt_col, batch_size=32, ratio=0.1, \
        transform=None, truncate=False):
    train_df = df.sample(frac=1-ratio)
    val_df = df.drop(train_df.index)
    
    train_set = PBMDataset(train_df, seq_col=seq_col, 
                         tgt_col=tgt_col, transform=transform, truncate=truncate)
    val_set = PBMDataset(val_df, seq_col=seq_col,
                         tgt_col=tgt_col, transform=transform, truncate=truncate)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

## Load Model

In [10]:
model_path = os.path.join(pwddir, "src", "CD-GPT", "checkpoints", "CD-GPT-1b.pth")
assert os.path.exists(model_path)
state = torch.load(model_path, map_location="cpu")

output_head = "sequence"
# assert output_head in ('sequence', 'token', 'residuepair')
# cdgpt = CDGPTSequencePrediction(cfg)
cdgpt = CDGPT(cfg)

  state = torch.load(model_path, map_location="cpu")


number of parameters: 1059.38M


In [11]:
cdgpt.load_state_dict(state["model"], strict=False)
print(f"load checkpoint form: {model_path}")
cdgpt = cdgpt.half()

load checkpoint form: e:\2025Spring_CS776\project\src\CD-GPT\checkpoints\CD-GPT-1b.pth


In [12]:
from config.utils import configurable
class CDGPTRegression(CDGPT):

    @classmethod
    def from_config(cls, cfg):
        pad_id = cfg.tokenizer.pad_id
        num_classes = cfg.model.num_classes
        return {
            "num_classes": num_classes,
            "pad_id": pad_id,
            **super().from_config(cfg)
        }

    @configurable
    def __init__(self,
                 num_classes: int,
                 vocab_size: int,
                 max_len: int = 2048,
                 embedding_dim=2304,
                 num_layers: int = 12,
                 num_heads: int = 24,
                 bias=False,
                 eps=1e-5,
                 pad_id=2,
                 dropout=0.0):
        super().__init__(vocab_size, max_len, embedding_dim, num_layers, num_heads, bias, eps, include_head=False)
        self.num_classes = num_classes
        self.pad_id = pad_id
        self.dropout = dropout
        
        for name, param in self.named_parameters():
            param.requires_grad = False

        self.mlp_head = nn.Linear(self.embedding_dim, self.num_classes)
        

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask = None,
                pos_ids = None):
        hiddens = super().forward(input_ids, attention_mask, pos_ids)
        if self.pad_id is None:
            sequence_lengths = -1  # last token for classification or regression
        else:
            sequence_lengths = torch.ne(input_ids, self.pad_id).sum(-1) - 1
        batch_size = hiddens.shape[0]
        hiddens = hiddens[torch.arange(batch_size, device=hiddens.device), sequence_lengths]
        out = self.mlp_head(hiddens)
        return out

## Utils

In [23]:
import tqdm
def train_one_epoch(model, train_loader, optimizer, device):
    model.train()
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    for batch_x, batch_y in tqdm.tqdm(train_loader):
        batch_x = batch_x.to(device).long()
        batch_y = batch_y.to(device).half()
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = loss_fn(outputs, batch_y.reshape(outputs.shape))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch_x.size(0)
    return total_loss / len(train_loader.dataset)


def evaluate(model, val_loader, device):
    model.eval()
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device).long()
            batch_y = batch_y.to(device).half()
            outputs = model(batch_x)
            loss = loss_fn(outputs, batch_y.reshape(outputs.shape))
            total_loss += loss.item() * batch_x.size(0)
    return total_loss / len(val_loader.dataset)

In [24]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

## Training

In [25]:
tf_name = "Sp1"
df = pd.read_csv(f"./data/{tf_name}.csv")
df

Unnamed: 0,ArrayType,Sequence,Signal_Mean,log2_Signal_Mean,Bias_Removed
0,HK,AAAAAACAACAGGAGGGCATCATGGAGCTGTCCAGCCTGTGTGAAA...,2582.4406,11.335078,0.336051
1,HK,AAAAAACAGCCGGATCACAATTTTGCCGAGAGCGACCTGTGTGAAA...,4164.3662,12.024228,0.859151
2,HK,AAAAAACGTCCGGTACACCCCGTTCGGCGGCCCAGCCTGTGTGAAA...,3850.3552,11.911150,1.326913
3,HK,AAAAAACTCTAGACCTTTAGCCCATCGTTGGCCAACCTGTGTGAAA...,6228.9379,12.605002,1.591492
4,HK,AAAAAAGAACAACCGGATAACACCCTTACAGCACACCTGTGTGAAA...,5027.6406,12.295953,0.856391
...,...,...,...,...,...
80851,ME,TTTTTTGAGGCCCAATCGTTTCGGCCGTGATGCTACCTGTGTGAAA...,19893.1972,14.280060,2.200901
80852,ME,TTTTTTGTGTACAGTGCTTGAAGACTCGAGGCCGTCCTGTGTGAAA...,15324.3828,13.903635,1.264322
80853,ME,TTTTTTTATCCCCAGCTGTTGGGATTAGGTTTGGGCCTGTGTGAAA...,15385.7968,13.909405,1.597506
80854,ME,TTTTTTTGAGCCGTAATCACAGCTGTGCACAGAGCCCTGTGTGAAA...,6923.0263,12.757395,0.490344


In [None]:
n_epochs = 10
batch_size = 128
split_ratio = 0.1 # 10% for validation
lr = 1e-2

max_length = 36
cdgpt_encoder = partial(tokenizer.encode, max_length=max_length, pad=True, to_tensor=True)
fileanme = f"model_{tf_name}_lr{lr}_epochs{n_epochs}_batch_size{batch_size}_seed{SEED}"
param_file = f"cdgpt_models/{fileanme}.pt"

train_loader, val_loader = get_dataset(df, seq_col="Sequence", tgt_col="Bias_Removed", 
                                       batch_size=batch_size, ratio=split_ratio, 
                                       transform=cdgpt_encoder, truncate=True)

torch.cuda.empty_cache()
model = CDGPTRegression(cfg)
model.load_state_dict(state["model"], strict=False)
print(f"load checkpoint form: {model_path}")
model = model.half()
model = model.to(device)

print_trainable_parameters(model)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)



number of parameters: 911.93M
load checkpoint form: e:\2025Spring_CS776\project\src\CD-GPT\checkpoints\CD-GPT-1b.pth
trainable params: 2305 || all params: 911927809 || trainable%: 0.0002527612358403252


In [None]:
if not os.path.exists(param_file):
    train_losses = []
    val_losses = []
    for epoch in range(n_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    torch.save({
        "model_state_dict":model.state_dict(),
        "train_losses":train_losses,
        "val_losses":val_losses,
    }, param_file)
else:
    checkpoint = torch.load(param_file)
    model.load_state_dict(checkpoint["model_state_dict"])
    train_losses = checkpoint["train_losses"]
    val_losses = checkpoint["val_losses"]

100%|██████████| 569/569 [03:28<00:00,  2.73it/s]


Epoch 1/10, Train Loss: nan, Validation Loss: nan


100%|██████████| 569/569 [03:27<00:00,  2.74it/s]


Epoch 2/10, Train Loss: nan, Validation Loss: nan


100%|██████████| 569/569 [03:28<00:00,  2.73it/s]


In [None]:
f, ax = plt.subplots(1, 1, figsize=(5,5))

ax.plot(train_losses, c="blue", label="Training")
ax.plot(val_losses, c="red", label="Test")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

## Evaluation Metrics

In [None]:
def run_evaluation_metrics(model, val_loader, device):
    truths = []
    predictions = []
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device).long()
            batch_y = batch_y.to(device).half()
            outputs = model(batch_x)
            truths.extend(batch_y.cpu().numpy().tolist())
            predictions.extend(outputs.reshape(-1).cpu().detach().numpy().tolist())

    print(f"MSE: {mean_squared_error(truths, predictions):.2f}")
    print(f"Coefficient of determination: {r2_score(truths, predictions):.2f}")
    print(f"Pearson correlation: {np.corrcoef(truths, predictions)[0, 1]:.2f}")
    mean_val = np.mean(truths)
    std_val = np.std(truths)
    labels = truths > (mean_val+4*std_val)
    print(f"ROC AUC: {roc_auc_score(labels, predictions):.2f}")
    return truths, predictions

In [None]:
truths, predictions = run_evaluation_metrics(model, val_loader, device)

In [None]:
tmp_df = pd.DataFrame({"y": truths, r"$\hat{y}$": predictions})

# scatter plot
g = sns.JointGrid(x="y", y=r"$\hat{y}$", data=tmp_df)
g = g.plot_joint(plt.scatter, c="green", alpha=0.5);

# line: y_pred = y
y_line = np.linspace(np.min(truths), np.max(truths), 200)
g.ax_joint.plot(y_line, y_line, color="blue", linestyle="--");

# histograms
g = g.plot_marginals(sns.histplot, data=tmp_df, color="green", kde=False);

g.ax_joint.set_xlim(np.min(y_line), np.max(y_line))
g.ax_joint.set_ylim(np.min(y_line), np.max(y_line))

plt.show()