In [1]:
import os, sys
# Expose only CUDA devices 2 and 3 to this process. The variable is set here,
# before torch is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

# Put the parent of the current working directory at the front of sys.path
# (only once) -- presumably so the local `cvae` package used below is importable.
root = os.path.abspath(
    os.path.join(os.getcwd(), '..')
)
if root not in sys.path:
    sys.path.insert(0, root)

In [2]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from cvae.models import CVAESimpleEnc
from cvae.datasets import CVAEAllDataset
from cvae.utils import (
    CONDITION_LENGTH, MAX_FASTA_LENGTH, MAX_SEQ_LENGTH, ALPHABET, PAD_TOKEN_ID,
    finetune_collate_fn, vae_loss_fn_with_cond
)

# Use the GPU when CUDA is usable, otherwise fall back to the CPU, then report
# the chosen device together with the number of CUDA devices torch can see.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device, torch.cuda.device_count())

cuda 2


In [3]:
# Batch size shared by every DataLoader built later in the notebook.
batch_size = 2048

# With keep_default_na=False and na_values=[''], only empty fields are parsed
# as missing; pandas' default NaN markers are kept as literal strings.
df = pd.read_csv(
    "../../clean_data/merged_all.csv",
    keep_default_na=False,
    na_values=[''],
)

# Hold out 10% of all rows as the test split, then 10% of the remainder as the
# validation split. Both splits are stratified on the 'length' column and use a
# fixed seed so they are reproducible.
train_val_df, test_df = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    stratify=df['length'],
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.1,
    random_state=42,
    stratify=train_val_df['length'],
)

In [4]:
# Column names of interest -- presumably the conditioning features; TODO confirm.
# NOTE(review): this list is not referenced by any other visible cell.
FEATURES = [
    "length", "is_assembled", "ap",
    "has_beta_sheet_content", "hydrophobic_moment", "net_charge",
]

In [5]:
# Wrap each split in the project dataset. Only the training split passes
# random_mask=True; validation and test rely on the dataset's default.
train_dataset = CVAEAllDataset(train_df, random_mask=True, max_fasta_length=MAX_FASTA_LENGTH)
val_dataset = CVAEAllDataset(val_df, max_fasta_length=MAX_FASTA_LENGTH)
test_dataset = CVAEAllDataset(test_df, max_fasta_length=MAX_FASTA_LENGTH)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook-wide batch size and collate function."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=finetune_collate_fn,
    )


# Shuffle only the training batches; evaluation loaders keep a fixed order.
train_loader_ft = _make_loader(train_dataset, shuffle=True)
val_loader_ft = _make_loader(val_dataset, shuffle=False)
test_loader_ft = _make_loader(test_dataset, shuffle=False)

In [6]:
# Same device selection and report as in the import cell, repeated here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device, torch.cuda.device_count())

# Build the model. load_state_dict below is strict by default, so these
# hyper-parameters have to produce the same parameter names and shapes as the
# checkpoint.
cvae_model = CVAESimpleEnc(
    # shared
    vocab_size=len(ALPHABET),
    max_seq_length=MAX_SEQ_LENGTH,
    cond_dim=CONDITION_LENGTH,
    latent_dim=24,
    nhead=8,
    dropout=0.1,
    # encoder
    encoder_hidden_dim=256,
    num_encoder_layers=2,
    # decoder
    decoder_hidden_dim=256,
    num_decoder_layers=2,
)

# NOTE(review): the variable is called "pretrained" but the file loaded is the
# finetuned checkpoint. weights_only=True restricts what torch.load will unpickle.
pretrained_state_dict = torch.load(
    "../cvae/chkpts/finetuned_cvae.pt",
    map_location=device,
    weights_only=True,
)
# Kept as the last expression so the key-matching result is shown as cell output.
cvae_model.load_state_dict(pretrained_state_dict)

cuda 2


<All keys matched successfully>

In [7]:
# Inference only: put the model in evaluation mode (it was built with dropout=0.1).
cvae_model.eval()

# Replicate across all visible GPUs when more than one is available.
# The isinstance guard keeps this cell idempotent: re-running it must not wrap
# an already-wrapped model in a second DataParallel layer.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    if not isinstance(cvae_model, nn.DataParallel):
        cvae_model = nn.DataParallel(cvae_model)

# Module.to() moves the parameters in place, so the result need not be rebound.
cvae_model.to(device)
print("CVAE initialized with finetuned architecture")

Using 2 GPUs
CVAE initialized with finetuned architecture


In [8]:
def evaluate(model, dataloader, pad_idx, device, kl_weight, cond_weight):
    """Run ``model`` over ``dataloader`` without gradients and average the losses.

    Args:
        model: Object providing ``eval()`` and callable as
            ``model(tokens, conds, mask)``, returning the 8-tuple
            ``(logits, mu, logvar, prior_mu, prior_logvar, bc_logit, cc_pred,
            mask_logit)``.
        dataloader: Iterable yielding ``(tokens, tgt_tokens, conds, mask)``
            tensor batches; the first dimension of ``tokens`` is the batch size.
        pad_idx: Padding token id. Forwarded to ``vae_loss_fn_with_cond`` and
            used to count the non-padding target tokens.
        device: Device every batch tensor is moved to before the forward pass.
        kl_weight: KL weight forwarded to ``vae_loss_fn_with_cond``.
        cond_weight: Accepted for backward compatibility but not used; the
            loss is always called with ``lambda_bin=1.0`` and ``lambda_cont=1.0``.

    Returns:
        dict with the keys
            ``loss``, ``recon``, ``kl``: batch-size-weighted means over all
                samples that were actually processed;
            ``mse``: always ``0.0`` (never accumulated; kept so existing
                readers of the dict do not break);
            ``tokens``: number of target tokens different from ``pad_idx``;
            ``ppl``: ``exp(recon)`` -- a perplexity only if ``recon`` is a mean
                per-token negative log-likelihood (TODO confirm against
                ``vae_loss_fn_with_cond``).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total = {
        "loss":    0.0,
        "recon":   0.0,
        "kl":      0.0,
        "mse":     0.0,
        "tokens":  0,
    }
    # Count the samples actually seen instead of trusting len(dataloader.dataset):
    # the two differ with drop_last=True or a subsampling sampler.
    n_samples = 0

    with torch.no_grad():
        for tokens, tgt_tokens, conds, mask in dataloader:
            B = tokens.size(0)
            tokens     = tokens.to(device)
            tgt_tokens = tgt_tokens.to(device)
            conds      = conds.to(device)
            mask       = mask.to(device)

            logits, mu, logvar, prior_mu, prior_logvar, bc_logit, cc_pred, mask_logit = \
                model(tokens, conds, mask)

            # Flatten logits to (positions, vocab) and targets to (positions,).
            loss, recon, kl, _ = vae_loss_fn_with_cond(
                logits=logits.view(-1, logits.size(-1)),
                tgt=tgt_tokens.view(-1),
                mu=mu, logvar=logvar, prior_mu=prior_mu, prior_logvar=prior_logvar,
                bc_logit=bc_logit, cc_pred=cc_pred, mask_logit=mask_logit,
                cond=conds, mask=mask,
                # Use the caller's pad id so the loss and the token count below agree.
                pad_idx=pad_idx, kl_weight=kl_weight, lambda_bin=1.0, lambda_cont=1.0
            )

            # Weight each batch by its size so a short final batch is not over-counted.
            total["loss"]  += loss.item() * B
            total["recon"] += recon.item() * B
            total["kl"]    += kl.item() * B
            total["tokens"] += (tgt_tokens != pad_idx).sum().item()
            n_samples += B

    if n_samples == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")

    for key in ("loss", "recon", "kl"):
        total[key] /= n_samples

    # The former `recon * tokens / tokens` was a no-op that divided by zero when
    # no non-padding token was seen; exp(recon) is the same value without that trap.
    total["ppl"] = float(torch.exp(torch.tensor(total["recon"])))

    return total


# Score the validation split with the loaded checkpoint.
# NOTE(review): evaluate() accepts cond_weight but never reads it, so the
# value passed here has no effect on the reported numbers.
metrics = evaluate(
    cvae_model,
    val_loader_ft,
    pad_idx=PAD_TOKEN_ID,
    device=device,
    kl_weight=0.02,
    cond_weight=50.0,
)

print("==== Validation metrics ====")
print(f"Loss: {metrics['loss']:.4f} | Recon: {metrics['recon']:.4f} | KL: {metrics['kl']:.4f}")
print(f"Perplexity: {metrics['ppl']:.2f}")

  output = torch._nested_tensor_from_mask(


==== Validation metrics ====
Loss: 1.3964 | Recon: 1.0401 | KL: 17.8119
Perplexity: 2.83


In [9]:
# Score the held-out test split with the same settings as the validation run.
# NOTE(review): evaluate() accepts cond_weight but never reads it, so the
# value passed here has no effect on the reported numbers.
metrics = evaluate(
    cvae_model,
    test_loader_ft,
    pad_idx=PAD_TOKEN_ID,
    device=device,
    kl_weight=0.02,
    cond_weight=50.0,
)

print("==== Test metrics ====")
print(f"Loss: {metrics['loss']:.4f} | Recon: {metrics['recon']:.4f} | KL: {metrics['kl']:.4f}")
print(f"Perplexity: {metrics['ppl']:.2f}")

==== Test metrics ====
Loss: 1.3926 | Recon: 1.0351 | KL: 17.8712
Perplexity: 2.82
