In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm
import os
from google.colab import drive

# 1. Setup Environment
drive.mount('/content/drive')  # Colab-only: exposes Google Drive under /content/drive
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Run-wide settings. The three *_dim sizes must agree with the saved VAE
# weights, otherwise load_state_dict will fail on shape mismatches.
config = dict(
    max_length=50,      # tokens per row, including <SOS> and <EOS>
    embedding_dim=128,
    hidden_dim=512,
    latent_dim=64,
    batch_size=4096,    # DataLoader batch size used during latent extraction
)

# 3. CLASS DEFINITIONS
class PeptideTokenizer:
    """Character-level tokenizer for peptides over the 20 standard amino acids.

    Vocabulary layout: indices 0-3 are the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; indices 4-23 are the one-letter residue codes
    ``ACDEFGHIKLMNPQRSTVWY`` in that order.
    NOTE(review): this layout presumably has to match the tokenizer used when
    the VAE checkpoint was trained — do not reorder ``self.chars``.
    """

    def __init__(self):
        self.chars = ['<PAD>', '<SOS>', '<EOS>', '<UNK>'] + list("ACDEFGHIKLMNPQRSTVWY")
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.idx_to_char = {i: c for i, c in enumerate(self.chars)}
        self.vocab_size = len(self.chars)

    def encode_batch(self, seqs, max_len):
        """Encode strings into a right-padded ``(len(seqs), max_len)`` LongTensor.

        Each sequence is truncated to ``max_len - 2`` characters and wrapped as
        ``<SOS> residues... <EOS>``; the remaining positions hold ``<PAD>``.
        Characters outside the vocabulary (including lowercase letters) are
        mapped to ``<UNK>``.

        Args:
            seqs: sequence of strings to encode.
            max_len: total tokens per row, including ``<SOS>`` and ``<EOS>``.

        Returns:
            ``(tokens, lengths)``: the padded token tensor and a LongTensor with
            the number of non-pad tokens in each row (always >= 2).

        Raises:
            ValueError: if ``max_len < 2`` (no room for ``<SOS>``/``<EOS>``).
        """
        if max_len < 2:
            raise ValueError(f"max_len must be >= 2 to fit <SOS> and <EOS>, got {max_len}")

        pad_idx = self.char_to_idx['<PAD>']
        sos_idx = self.char_to_idx['<SOS>']
        eos_idx = self.char_to_idx['<EOS>']
        unk_idx = self.char_to_idx['<UNK>']

        batch_tensor = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
        lengths = []
        for i, seq in enumerate(tqdm(seqs, desc="Tokenizing")):
            body = [self.char_to_idx.get(aa, unk_idx) for aa in seq[:max_len - 2]]
            full_seq = [sos_idx, *body, eos_idx]
            batch_tensor[i, :len(full_seq)] = torch.tensor(full_seq, dtype=torch.long)
            lengths.append(len(full_seq))
        return batch_tensor, torch.tensor(lengths, dtype=torch.long)

    def decode(self, indices):
        """Decode token indices back into an amino-acid string.

        Decoding stops at the first ``<EOS>``; ``<SOS>`` and ``<PAD>`` are
        skipped; ``<UNK>`` is emitted literally as the text ``'<UNK>'``.

        ``indices`` may be any iterable of integers, including a 1-D tensor.
        Each element is converted with ``int()`` because 0-d tensors hash by
        identity and would otherwise never match an ``idx_to_char`` key.
        """
        eos_idx = self.char_to_idx['<EOS>']
        skip = {self.char_to_idx['<SOS>'], self.char_to_idx['<PAD>']}
        res = []
        for raw in indices:
            idx = int(raw)
            if idx == eos_idx:
                break
            if idx in skip:
                continue
            res.append(self.idx_to_char[idx])
        return "".join(res)

class VAE(nn.Module):
    """GRU sequence VAE whose ``forward`` is reduced to the encoder mean.

    Every layer of the full model is still constructed, including the unused
    ``fc_logvar`` and decoder layers. The loader calls ``load_state_dict`` with
    its default strict matching, so presumably those keys exist in the saved
    checkpoint — keep attribute names and construction order unchanged.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        # Bidirectional encoder: forward and backward final states are concatenated.
        enc_out_dim = 2 * hidden_dim

        # --- encoder ---
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_logvar = nn.Linear(enc_out_dim, latent_dim)

        # --- decoder (not used by forward(); kept for state_dict compatibility) ---
        self.decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.decoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, lengths):
        """Return the posterior mean ``mu`` for a padded batch of token ids.

        Args:
            x: token ids, shape (batch, seq_len) since the GRU is batch_first.
            lengths: unpadded length of each row. It is moved to CPU because
                ``pack_padded_sequence`` expects CPU lengths.

        Returns:
            Tensor of shape (batch, latent_dim) from ``fc_mu``.
        """
        packed = pack_padded_sequence(
            self.embedding(x), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, h_n = self.encoder_gru(packed)
        # h_n[-2] / h_n[-1]: final forward / backward states of the last layer.
        summary = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        return self.fc_mu(summary)

# 4. PATHS
MODEL_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_FINAL_epoch20.pth"
STATS_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_correction_stats.pth"
POS_DATA   = "/content/drive/MyDrive/AMP-Generation/data/pos_data.csv"
NEG_DATA   = "/content/drive/MyDrive/AMP-Generation/data/neg_data.csv"

# 5. INITIALIZE & LOAD
tokenizer = PeptideTokenizer()
model = VAE(
    vocab_size=tokenizer.vocab_size,
    emb_dim=config["embedding_dim"],
    hidden_dim=config["hidden_dim"],
    latent_dim=config["latent_dim"],
).to(device)

# Model weights: accept either a training checkpoint that wraps the weights
# under 'model_state_dict', or a bare state_dict.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location=device)
is_wrapped = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
model.load_state_dict(checkpoint['model_state_dict'] if is_wrapped else checkpoint)
model.eval()  # extraction only; switch off training-mode behaviour

# Latent standardization stats, applied later as (mu - mu_mean) / mu_std.
stats = torch.load(STATS_PATH, map_location=device)
mu_mean, mu_std = (stats[key].to(device) for key in ('shift', 'scale'))

# 6. EXTRACTION FUNCTION
def extract_and_normalize(csv_path, output_filename,
                          output_dir="/content/drive/MyDrive/AMP-Generation/data/"):
    """Encode every sequence in a CSV to a standardized VAE latent and save it.

    Reads the ``sequence`` column, falling back to the first column if there is
    none. Sequences are tokenized with the module-level ``tokenizer`` and
    encoded by the module-level ``model`` in batches of ``config["batch_size"]``.
    The posterior means are standardized as ``(mu - mu_mean) / mu_std``, and the
    resulting ``(n_sequences, latent_dim)`` CPU tensor is written with
    ``torch.save``.

    Missing cells (NaN) and blank strings are skipped: ``astype(str)`` alone
    would turn NaN into the literal ``"nan"`` and encode it as a bogus peptide.
    Surrounding whitespace is stripped before tokenizing.

    Args:
        csv_path: path to the input CSV.
        output_filename: file name for the saved latent tensor.
        output_dir: directory to save into, created if missing. Defaults to the
            project's Drive data folder.

    Returns:
        The path the tensor was saved to.
    """
    print(f"Converting {os.path.basename(csv_path)} to Latent Space...")
    df = pd.read_csv(csv_path)
    col = 'sequence' if 'sequence' in df.columns else df.columns[0]

    raw = df[col].dropna().astype(str).str.strip()
    seqs = raw[raw != ""].tolist()
    n_dropped = len(df) - len(seqs)
    if n_dropped:
        print(f"Skipped {n_dropped} empty/missing sequence(s) in column '{col}'.")

    data_tensor, lengths_tensor = tokenizer.encode_batch(seqs, config["max_length"])
    loader = DataLoader(TensorDataset(data_tensor, lengths_tensor),
                        batch_size=config["batch_size"], shuffle=False)

    latents_list = []
    with torch.no_grad():
        for batch, lengths in tqdm(loader):
            mu = model(batch.to(device), lengths)  # VAE.forward returns only mu
            latents_list.append(((mu - mu_mean) / mu_std).cpu())

    if latents_list:
        final_tensor = torch.cat(latents_list, dim=0)
    else:
        # torch.cat rejects an empty list; save an empty (0, latent_dim) tensor instead.
        final_tensor = torch.empty(0, config["latent_dim"])

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, output_filename)
    torch.save(final_tensor, save_path)
    print(f"Success! Saved to {save_path}")
    return save_path

# 7. EXECUTE
# Encode the positive set first, then the negative set.
for source_csv, latent_file in (
    (POS_DATA, "latent_cond_pos.pth"),
    (NEG_DATA, "latent_cond_neg.pth"),
):
    extract_and_normalize(source_csv, latent_file)

Mounted at /content/drive
Converting pos_data.csv to Latent Space...


Tokenizing: 100%|██████████| 22175/22175 [00:00<00:00, 48966.07it/s]
100%|██████████| 6/6 [00:01<00:00,  5.57it/s]


Success! Saved to /content/drive/MyDrive/AMP-Generation/data/latent_cond_pos.pth
Converting neg_data.csv to Latent Space...


Tokenizing: 100%|██████████| 22441/22441 [00:00<00:00, 54749.64it/s]
100%|██████████| 6/6 [00:00<00:00,  9.60it/s]

Success! Saved to /content/drive/MyDrive/AMP-Generation/data/latent_cond_neg.pth



