<a href="https://colab.research.google.com/github/mkoko22/AMP_Generation_using_LDM/blob/main/generation/Vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb -q

import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import wandb
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm.auto import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using Device: {device}")

# Attach Google Drive so the paths under /content/drive used below resolve.
from google.colab import drive
drive.mount('/content/drive')


# Paths and hyperparameters for the run. The whole dict is also handed to
# wandb.init() as the run configuration, so every value here is what gets logged.
config = {
    # Weights & Biases project name.
    "project_name": "Peptide-Generation-VAE",
    # Training / validation CSVs and output locations on the mounted Drive.
    "data_path_train": "/content/drive/MyDrive/AMP-Generation/data/VAE_train.csv",
    "data_path_val":   "/content/drive/MyDrive/AMP-Generation/data/VAE_val.csv",
    "checkpoint_dir":  "/content/drive/MyDrive/AMP-Generation/checkpoints",
    "model_dir":       "/content/drive/MyDrive/AMP-Generation/models",
    # File name (inside checkpoint_dir) of the resumable training checkpoint.
    "checkpoint_file": "vae_checkpoint.pth",

    # Padded token length per sequence, including the <SOS> and <EOS> markers.
    "max_length": 50,
    # 4 special tokens + 20 amino-acid letters. wandb.init() records this value
    # before it is overwritten from the tokenizer, so it has to be correct here.
    "vocab_size": 24,
    "embedding_dim": 128,
    "hidden_dim": 512,
    "latent_dim": 64,
    "batch_size": 4096,
    # Adam learning rate.
    "learning_rate": 8e-4,
    # Upper bound of the epoch loop (a resumed run continues up to this number).
    "epochs": 20,
    # Multiplier applied to the KL term when it is added to the reconstruction loss.
    "kl_weight": 0.002,
}

# Make sure both output directories exist before anything is written to them.
os.makedirs(config["checkpoint_dir"], exist_ok=True)
os.makedirs(config["model_dir"], exist_ok=True)

# Close a run left open by a previous execution of this cell. This is
# best-effort, so ordinary errors are reported and ignored; catching Exception
# (rather than a bare except) still lets KeyboardInterrupt/SystemExit through.
try:
    wandb.finish()
except Exception as err:
    print(f"wandb.finish() failed, continuing with a new run: {err}")
wandb.init(project=config["project_name"], config=config)


class PeptideTokenizer:
    """Character-level tokenizer for peptide sequences.

    The vocabulary is the four special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
    and ``<UNK>`` (indices 0-3) followed by the 20 amino-acid letters
    ``ACDEFGHIKLMNPQRSTVWY``.
    """

    def __init__(self):
        self.chars = ['<PAD>', '<SOS>', '<EOS>', '<UNK>'] + list("ACDEFGHIKLMNPQRSTVWY")
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))
        self.vocab_size = len(self.chars)

    def encode_batch(self, seqs, max_len):
        """Tokenize sequences into one right-padded tensor.

        Each sequence is truncated to ``max_len - 2`` characters, wrapped in
        ``<SOS>`` / ``<EOS>`` and padded with ``<PAD>``. Characters outside the
        vocabulary map to ``<UNK>``.

        Args:
            seqs: Sequence of strings to encode.
            max_len: Width of the output tensor, markers included.

        Returns:
            Tuple ``(tokens, lengths)``: a ``(len(seqs), max_len)`` long tensor
            and a long tensor with each row's unpadded length.

        Raises:
            ValueError: If ``max_len`` cannot hold the two marker tokens.
        """
        if max_len < 2:
            raise ValueError("max_len must be at least 2 to hold <SOS> and <EOS>")

        # Resolve the special indices once instead of once per sequence.
        pad_idx = self.char_to_idx['<PAD>']
        sos_idx = self.char_to_idx['<SOS>']
        eos_idx = self.char_to_idx['<EOS>']
        unk_idx = self.char_to_idx['<UNK>']
        lookup = self.char_to_idx.get

        batch_tensor = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
        lengths = []
        for i, seq in enumerate(tqdm(seqs, desc="Tokenizing")):
            residues = seq[:max_len - 2]
            full_seq = [sos_idx] + [lookup(aa, unk_idx) for aa in residues] + [eos_idx]
            length = len(full_seq)
            batch_tensor[i, :length] = torch.tensor(full_seq, dtype=torch.long)
            lengths.append(length)
        return batch_tensor, torch.tensor(lengths, dtype=torch.long)

    def decode(self, indices):
        """Turn token indices back into a string.

        Decoding stops at the first ``<EOS>``; ``<SOS>`` and ``<PAD>`` are
        skipped. ``<UNK>`` is emitted as the literal text ``<UNK>``.

        Args:
            indices: Iterable of integer-like indices (Python ints, NumPy
                integers or elements of a 1-D torch tensor).

        Returns:
            The decoded string.

        Raises:
            KeyError: If an index is outside the vocabulary.
        """
        eos_idx = self.char_to_idx['<EOS>']
        skipped = (self.char_to_idx['<SOS>'], self.char_to_idx['<PAD>'])
        res = []
        for raw_idx in indices:
            # Tensor scalars hash by identity and would miss the dict lookup,
            # so normalise every element to a plain int first.
            idx = int(raw_idx)
            if idx == eos_idx:
                break
            if idx in skipped:
                continue
            res.append(self.idx_to_char[idx])
        return "".join(res)

# Shared tokenizer instance; the configured vocabulary size is made to follow it.
tokenizer = PeptideTokenizer()
config.update(vocab_size=tokenizer.vocab_size)

def load_and_process_data(path):
    """Read a CSV of peptide sequences and tokenize it into a ``TensorDataset``.

    The ``sequence`` column is used when present, otherwise the first column.
    Missing cells are dropped before the string conversion: converting first
    would turn them into the literal text "nan", which would then be tokenized
    as a three-token sequence of ``<UNK>``. Surrounding whitespace is removed
    and entries that end up empty are dropped as well.

    Args:
        path: Location of the CSV file.

    Returns:
        ``TensorDataset`` of ``(token_ids, lengths)`` produced by
        ``tokenizer.encode_batch`` with ``config["max_length"]``.
    """
    print(f"Reading {path}...")
    df = pd.read_csv(path)
    col = 'sequence' if 'sequence' in df.columns else df.columns[0]
    cleaned = df[col].dropna().astype(str).str.strip()
    seqs = cleaned[cleaned != ""].tolist()
    data_tensor, lengths_tensor = tokenizer.encode_batch(seqs, config["max_length"])
    return TensorDataset(data_tensor, lengths_tensor)

train_data = load_and_process_data(config["data_path_train"])
val_data = load_and_process_data(config["data_path_val"])

# Page-locked host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only runtime just triggers a warning, so request it only when on CUDA.
pin_memory = device.type == "cuda"

# Only the training set is shuffled; validation keeps file order.
train_loader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True, num_workers=0, pin_memory=pin_memory)
val_loader = DataLoader(val_data, batch_size=config["batch_size"], shuffle=False, num_workers=0, pin_memory=pin_memory)


class VAE(nn.Module):
    """Sequence VAE: bidirectional GRU encoder, single-direction GRU decoder.

    The encoder summarises a padded token batch into the mean and
    log-variance of a diagonal Gaussian; the decoder is initialised from a
    latent sample and run with teacher forcing over the input tokens.

    Args:
        vocab_size: Number of token ids (size of the embedding and output layer).
        emb_dim: Token embedding width.
        hidden_dim: GRU hidden size (per direction for the encoder).
        latent_dim: Size of the latent vector.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        # The encoder is bidirectional, so its summary is 2 * hidden_dim wide.
        self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.decoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _encode_embedded(self, embedded, lengths):
        """Run the encoder on already-embedded tokens and return ``(mu, logvar)``."""
        # Packing makes the GRU stop at each sequence's true length instead of
        # reading padding; pack_padded_sequence wants the lengths on the CPU.
        packed_input = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = self.encoder_gru(packed_input)
        # Last two entries of h_n: final states of the two directions.
        hidden_cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return self.fc_mu(hidden_cat), self.fc_logvar(hidden_cat)

    def encode(self, x, lengths):
        """Map token ids to the posterior parameters.

        Args:
            x: Long tensor of token ids, ``(batch, seq_len)``.
            lengths: Unpadded length of every row in ``x``.

        Returns:
            Tuple ``(mu, logvar)``, each ``(batch, latent_dim)``.
        """
        return self._encode_embedded(self.embedding(x), lengths)

    def decode(self, z, dec_input):
        """Produce token logits from a latent vector with teacher forcing.

        Args:
            z: Latent vectors, ``(batch, latent_dim)``.
            dec_input: Embedded decoder inputs, ``(batch, steps, emb_dim)``.

        Returns:
            Logits of shape ``(batch, steps, vocab_size)``.
        """
        # The projected latent becomes the decoder GRU's initial hidden state.
        decoder_hidden = self.decoder_input(z).unsqueeze(0)
        outputs, _ = self.decoder_gru(dec_input, decoder_hidden)
        return self.fc_out(outputs)

    def forward(self, x, lengths):
        """Encode, sample and decode a batch.

        Args:
            x: Long tensor of token ids, ``(batch, seq_len)``.
            lengths: Unpadded length of every row in ``x``.

        Returns:
            Tuple ``(logits, mu, logvar)`` where ``logits`` is
            ``(batch, seq_len - 1, vocab_size)``: step ``t`` is conditioned on
            tokens up to ``t``, so it lines up with ``x[:, 1:]`` as target.
        """
        embedded = self.embedding(x)
        mu, logvar = self._encode_embedded(embedded, lengths)
        z = self.reparameterize(mu, logvar)
        # Teacher forcing: feed every token except the last one.
        logits = self.decode(z, embedded[:, :-1, :])
        return logits, mu, logvar

model = VAE(config["vocab_size"], config["embedding_dim"], config["hidden_dim"], config["latent_dim"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
# Loss scaling is only needed for CUDA mixed precision; on a CPU runtime the
# scaler is created disabled so its calls pass through without a warning.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

start_epoch = 0
checkpoint_full_path = os.path.join(config["checkpoint_dir"], config["checkpoint_file"])

# Resume from the rolling checkpoint when one exists. Two layouts are accepted:
# a dict with model/optimizer state and the last finished epoch, or a bare
# model state_dict.
if os.path.exists(checkpoint_full_path):
    print(f"Found checkpoint at {checkpoint_full_path}. Loading...")
    checkpoint = torch.load(checkpoint_full_path, map_location=device)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The stored epoch is the zero-based index of the last finished epoch.
        start_epoch = checkpoint['epoch'] + 1
        print(f"Successfully resumed from Epoch {start_epoch}")

    else:
        # Weights only: optimizer state and epoch counter start from scratch.
        model.load_state_dict(checkpoint)
        print("Loaded weights from old format (no epoch info).")
        print("Assuming start_epoch = 0")
else:
    print("No checkpoint found. Starting fresh.")


def vae_loss_function(recon_x, x, mu, logvar):
    """Compute the two summed terms of the VAE objective.

    Args:
        recon_x: Decoder logits; flattened to ``(-1, config["vocab_size"])``.
        x: Input token ids; everything after the first column is the target.
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Tuple ``(recon_loss, kld_loss)``: cross-entropy summed over all
        non-``<PAD>`` target positions, and the KL divergence to a standard
        normal summed over the batch and latent dimensions.
    """
    # Drop the first token: the logits at step t are scored against token t + 1.
    flat_targets = x[:, 1:].reshape(-1)
    flat_logits = recon_x.reshape(-1, config["vocab_size"])
    recon_loss = nn.functional.cross_entropy(
        flat_logits,
        flat_targets,
        ignore_index=tokenizer.char_to_idx['<PAD>'],
        reduction='sum',
    )

    kld_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld_loss = -0.5 * torch.sum(kld_terms)
    return recon_loss, kld_loss

def check_reconstruction(model, num_samples=2):
    """Print target vs. predicted sequences for a few validation examples.

    Takes the first batch of ``val_loader``, runs the model on its first
    ``num_samples`` rows and prints each decoded target next to the decoded
    argmax prediction. The prediction is teacher-forced (the model sees the
    true tokens), so this is a reconstruction check, not free generation.

    Args:
        model: The VAE to inspect.
        num_samples: How many rows to show; fewer are shown when the batch is
            smaller.
    """
    # Remember the mode so the caller gets the model back the way it was.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        x, seq_lengths = next(iter(val_loader))
        x = x[:num_samples].to(device)
        seq_lengths = seq_lengths[:num_samples]
        logits, _, _ = model(x, seq_lengths)
        preds = torch.argmax(logits, dim=2)
        print("\nüîç SANITY CHECK:")
        # Iterate over the rows actually present instead of assuming a count.
        for i in range(x.size(0)):
            # The target skips the leading <SOS> so it lines up with the predictions.
            real = tokenizer.decode(x[i, 1:].cpu().numpy())
            pred = tokenizer.decode(preds[i].cpu().numpy())
            print(f"Target: {real}\nPred:   {pred}\n" + "-"*30)
    model.train(was_training)

print(f"Starting Training from Epoch {start_epoch+1}...")

# Mixed precision is only used on CUDA; on a CPU runtime autocast is switched
# off explicitly instead of being requested and then disabled with a warning.
use_amp = device.type == "cuda"

for epoch in range(start_epoch, config["epochs"]):
    # ---- training pass ----
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Ep {epoch+1}", leave=False)

    for batch, lengths in progress_bar:
        # lengths stays on the CPU; the model packs sequences with CPU lengths.
        batch = batch.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits, mu, logvar = model(batch, lengths)
            recon_loss, kld_loss = vae_loss_function(logits, batch, mu, logvar)
            loss = recon_loss + (config["kl_weight"] * kld_loss)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # Read the loss back once; each .item() forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({"Loss": batch_loss / batch.size(0)})

    # Losses are summed over samples, so divide by the sample count.
    avg_loss = total_loss / len(train_loader.dataset)

    # ---- validation pass (full precision, no gradients) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch, lengths in val_loader:
            batch = batch.to(device, non_blocking=True)
            logits, mu, logvar = model(batch, lengths)
            r, k = vae_loss_function(logits, batch, mu, logvar)
            val_loss += (r + config["kl_weight"] * k).item()
    avg_val_loss = val_loss / len(val_loader.dataset)

    wandb.log({"train_loss": avg_loss, "val_loss": avg_val_loss})
    print(f"Ep {epoch+1} | Train: {avg_loss:.4f} | Val: {avg_val_loss:.4f}")

    # ---- rolling checkpoint used for resuming ----
    checkpoint_data = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss
    }
    # Write to a temporary file and swap it in, so an interruption during the
    # save cannot leave a truncated file in place of the last good checkpoint.
    tmp_checkpoint_path = checkpoint_full_path + ".tmp"
    torch.save(checkpoint_data, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_full_path)

    # Every fifth epoch: print sample reconstructions and keep a weights snapshot.
    if (epoch+1) % 5 == 0:
        check_reconstruction(model)
        milestone_path = os.path.join(config["model_dir"], f"vae_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), milestone_path)

final_path = os.path.join(config["model_dir"], "vae_final_corrected.pth")
torch.save(model.state_dict(), final_path)
wandb.finish()
print(f"Training Complete! Model saved to {final_path}")

🚀 Using Device: cuda


Reading /content/drive/MyDrive/AMP-Generation/data/VAE_train.csv...


Tokenizing:   0%|          | 0/2736731 [00:00<?, ?it/s]

Reading /content/drive/MyDrive/AMP-Generation/data/VAE_val.csv...


Tokenizing:   0%|          | 0/143986 [00:00<?, ?it/s]

🔄 Found checkpoint at /content/drive/MyDrive/AMP-Generation/checkpoints/vae_checkpoint.pth. Loading...
Successfully resumed from Epoch 17
Starting Training from Epoch 18...


Ep 18:   0%|          | 0/669 [00:00<?, ?it/s]

Ep 18 | Train: 9.9619 | Val: 8.4445


Ep 19:   0%|          | 0/669 [00:00<?, ?it/s]

Ep 19 | Train: 9.5232 | Val: 11.1394


Ep 20:   0%|          | 0/669 [00:00<?, ?it/s]

Ep 20 | Train: 8.9510 | Val: 10.0266

🔍 SANITY CHECK:
Target: MYLSGRGMDYASSWDMIEVVVLTQDKVAGSWPTEAYMDREYLK
Pred:   MYLGGRGMDYASSWDMIEVVVLTQDKVGGSWPTAAYMDREYLK
------------------------------
Target: MDKALKEFEGTVTDVEYDEDEGALITVNVFKGIVDKLYGSK
Pred:   MDKALKEFEGTVTDVYYDEDEGALITVNVFKGIVDLLYGSK
------------------------------


0,1
train_loss,█▅▁
val_loss,▁█▅

0,1
train_loss,8.95096
val_loss,10.02661


Training Complete! Model saved to /content/drive/MyDrive/AMP-Generation/models/vae_final_corrected.pth


In [None]:
# NOTE(review): the training cell writes "vae_final_corrected.pth" and
# "vae_epoch_<n>.pth"; confirm that this file name exists in the models folder.
MODEL_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_FINAL_epoch20.pth"
SCALER_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_correction_stats.pt"

# Architecture arguments: vocab_size, emb_dim, hidden_dim, latent_dim.
model = VAE(24, 128, 512, 64).to(device)
# map_location lets weights saved from a GPU run load on whatever device is
# active now (including a CPU-only runtime).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

print("Scanning Latent Space to build Correction File...")
all_mus = []

# Run only the encoder half of the model over the training set and collect
# the posterior mean of every sequence.
with torch.no_grad():
    for batch, lengths in train_loader:
        batch = batch.to(device)
        embedded = model.embedding(batch)
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = model.encoder_gru(packed)
        hidden_cat = torch.cat((hidden[-2], hidden[-1]), dim=1)
        mu = model.fc_mu(hidden_cat)

        all_mus.append(mu)

# Per-dimension mean and standard deviation of the collected means.
all_mus = torch.cat(all_mus, dim=0)
global_mean = all_mus.mean(dim=0)
global_std = all_mus.std(dim=0)

print(f"  Detected Drift -> Mean: {global_mean.mean():.5f}")
print(f"  Detected Size  -> Std:  {global_std.mean():.5f}")


# "shift" is the per-dimension mean and "scale" the reciprocal of the
# per-dimension std; the epsilon guards against division by zero.
correction_data = {
    "shift": global_mean,
    "scale": 1.0 / (global_std + 1e-8)
}

torch.save(correction_data, SCALER_PATH)
print("-" * 30)
print(f"CORRECTION FILE SAVED: {SCALER_PATH}")

Scanning Latent Space to build Correction File...
   Detected Drift -> Mean: -0.01319
   Detected Size  -> Std:  1.28643 (Too small, needs scaling)
------------------------------
CORRECTION FILE SAVED: /content/drive/MyDrive/AMP-Generation/models/vae_correction_stats.pt
