# In [None]:
import torch  # Core PyTorch library for tensors and GPU acceleration
import torch.nn as nn  # Neural network layers
import torch.optim as optim  # Optimization algorithms (not used here but imported)
import pandas as pd  # For reading CSV data
from torch.utils.data import DataLoader, TensorDataset  # Create dataset + loader combos
from torch.nn.utils.rnn import pack_padded_sequence  # Helps GRUs handle variable-length sequences
from tqdm import tqdm  # Progress bars because suffering should be visible
import os  # File handling because computers need directions
from google.colab import drive  # Mount Google Drive for I/O

# Make the Drive folder holding the checkpoint, stats and CSVs visible to this runtime.
drive.mount('/content/drive')

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters; the model-shape entries must agree with the trained checkpoint.
config = dict(
    max_length=50,      # padded token length, including <SOS> and <EOS>
    embedding_dim=128,  # width of each amino-acid embedding vector
    hidden_dim=512,     # GRU hidden size
    latent_dim=64,      # size of the VAE latent vector
    batch_size=4096,    # sequences per encoder pass (inference only, no gradients kept)
)

class PeptideTokenizer:
    """Character-level tokenizer for peptides over the 20 standard amino acids.

    Token IDs: 0 '<PAD>', 1 '<SOS>', 2 '<EOS>', 3 '<UNK>', then the letters
    "ACDEFGHIKLMNPQRSTVWY" in that order (A=4 ... Y=23). The VAE's embedding
    rows are looked up by these IDs, so the ordering should not be changed
    independently of the loaded checkpoint.
    """

    def __init__(self):
        self.chars = ['<PAD>', '<SOS>', '<EOS>', '<UNK>'] + list("ACDEFGHIKLMNPQRSTVWY")  # Token inventory
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}  # Map char -> ID
        self.idx_to_char = {i: c for i, c in enumerate(self.chars)}  # Map ID -> char
        self.vocab_size = len(self.chars)  # Number of tokens

    def encode_batch(self, seqs, max_len):
        """Tokenize strings into a right-padded ``(len(seqs), max_len)`` LongTensor.

        Each sequence is truncated to ``max_len - 2`` characters, wrapped in
        <SOS> ... <EOS>, and padded with <PAD>. Characters not in the
        vocabulary (including lowercase letters) map to <UNK>.

        Args:
            seqs: sequence of strings.
            max_len: total padded length including <SOS> and <EOS>; must be >= 2.

        Returns:
            ``(tokens, lengths)``: the padded LongTensor and a 1-D LongTensor
            with each row's true length (special tokens included).

        Raises:
            ValueError: if ``max_len < 2``. There is no room for <SOS>/<EOS>,
                and ``seq[:max_len - 2]`` would otherwise slice with a negative
                bound and fail later with a shape-mismatch error.
        """
        if max_len < 2:
            raise ValueError(f"max_len must be >= 2 to fit <SOS> and <EOS>, got {max_len}")
        pad_id = self.char_to_idx['<PAD>']
        sos_id = self.char_to_idx['<SOS>']
        eos_id = self.char_to_idx['<EOS>']
        unk_id = self.char_to_idx['<UNK>']

        batch_tensor = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)  # Pre-fill with PADs
        lengths = []  # True length of each row
        for i, seq in enumerate(tqdm(seqs, desc="Tokenizing")):
            body = [self.char_to_idx.get(aa, unk_id) for aa in seq[:max_len - 2]]  # Reserve 2 slots for SOS/EOS
            full_seq = [sos_id, *body, eos_id]
            batch_tensor[i, :len(full_seq)] = torch.tensor(full_seq, dtype=torch.long)
            lengths.append(len(full_seq))
        return batch_tensor, torch.tensor(lengths, dtype=torch.long)

    def decode(self, indices):
        """Convert token IDs back into a peptide string.

        Decoding stops at the first <EOS>, and <SOS>/<PAD> are skipped. <UNK>
        is emitted as the literal text '<UNK>'. ``indices`` may be a list of
        ints or any iterable of integer-like scalars, such as a 1-D torch
        tensor or NumPy array.
        """
        eos_id = self.char_to_idx['<EOS>']
        skip_ids = {self.char_to_idx['<SOS>'], self.char_to_idx['<PAD>']}
        res = []
        for idx in indices:
            # 0-d tensors hash by identity, so normalise to a plain int before the dict lookup.
            idx = int(idx)
            if idx == eos_id:
                break
            if idx in skip_ids:
                continue
            res.append(self.idx_to_char[idx])
        return "".join(res)

class VAE(nn.Module):
    """GRU sequence VAE, trimmed so that ``forward`` only runs the encoder.

    Every submodule of the full model is still declared, presumably so that a
    full training checkpoint loads under the default strict ``load_state_dict``.
    Only ``embedding``, ``encoder_gru`` and ``fc_mu`` are used by ``forward``.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder path (the part exercised by forward).
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Heads over the concatenated forward/backward final states.
        self.fc_mu = nn.Linear(2 * hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(2 * hidden_dim, latent_dim)  # declared but never called here
        # Decoder path: unused in this trimmed model.
        self.decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.decoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, lengths):
        """Encode padded token IDs to the latent mean ``mu`` (no sampling).

        Args:
            x: LongTensor of token IDs, batch-first ``(batch, seq_len)``.
            lengths: true length of each row. It is copied to the CPU, which
                ``pack_padded_sequence`` requires for a lengths tensor.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        packed = pack_padded_sequence(
            self.embedding(x), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, h_n = self.encoder_gru(packed)
        # h_n is (num_directions * num_layers, batch, hidden). Its last two rows
        # are the top layer's final forward and backward states.
        final_state = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc_mu(final_state)

# Trained VAE weights, latent standardization stats, and the two input CSVs on Drive.
MODEL_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_FINAL_epoch20.pth"
STATS_PATH = "/content/drive/MyDrive/AMP-Generation/models/vae_correction_stats.pth"
POS_DATA   = "/content/drive/MyDrive/AMP-Generation/data/pos_data.csv"  # positive samples
NEG_DATA   = "/content/drive/MyDrive/AMP-Generation/data/neg_data.csv"  # negative samples

tokenizer = PeptideTokenizer()

# Rebuild the network with the configured sizes, then place it on the chosen device.
model = VAE(
    vocab_size=tokenizer.vocab_size,
    emb_dim=config["embedding_dim"],
    hidden_dim=config["hidden_dim"],
    latent_dim=config["latent_dim"],
).to(device)

# NOTE(review): torch.load unpickles the file, so only point it at checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# Accept either a training-style dict wrapping 'model_state_dict' or a bare state dict.
model.load_state_dict(
    checkpoint['model_state_dict']
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    else checkpoint
)
model.eval()  # inference mode for the encoding passes below

# Per-dimension shift/scale applied to the encoder's latent means.
stats = torch.load(STATS_PATH, map_location=device)
mu_mean, mu_std = stats['shift'].to(device), stats['scale'].to(device)

def extract_and_normalize(csv_path, output_filename,
                          output_dir="/content/drive/MyDrive/AMP-Generation/data/"):
    """Encode each CSV sequence to a standardized latent mean and save the result.

    Reads the ``sequence`` column (or the first column if there is none) of
    ``csv_path``. It then tokenizes with the module-level ``tokenizer``,
    encodes in batches of ``config["batch_size"]`` with the module-level
    ``model``, and standardizes each mean as ``(mu - mu_mean) / mu_std``.
    The stacked result is written with ``torch.save``.

    Row ``i`` of the saved tensor corresponds to row ``i`` of the CSV: the
    loader does not shuffle, and missing values are not dropped.

    Args:
        csv_path: path of the input CSV.
        output_filename: file name for the saved tensor.
        output_dir: directory to save into (created if missing). The default
            is the project's Drive data folder used previously.

    Returns:
        The saved CPU tensor of shape ``(num_rows, config["latent_dim"])``.
    """
    print(f"Converting {os.path.basename(csv_path)} to Latent Space...")
    df = pd.read_csv(csv_path)
    col = 'sequence' if 'sequence' in df.columns else df.columns[0]

    n_missing = int(df[col].isna().sum())
    if n_missing:
        # astype(str) turns NaN into the text "nan". Lowercase letters are not in
        # the vocabulary, so it encodes as <UNK> tokens instead of failing loudly.
        print(f"Warning: {n_missing} missing sequence(s) in column '{col}' will be encoded as 'nan'.")
    seqs = df[col].astype(str).tolist()

    data_tensor, lengths_tensor = tokenizer.encode_batch(seqs, config["max_length"])
    loader = DataLoader(TensorDataset(data_tensor, lengths_tensor),
                        batch_size=config["batch_size"], shuffle=False)

    latents_list = []
    with torch.no_grad():
        for batch, lengths in tqdm(loader):
            # lengths stays on the CPU; the model's forward moves it there anyway.
            mu = model(batch.to(device), lengths)
            latents_list.append(((mu - mu_mean) / mu_std).cpu())

    # torch.cat rejects an empty list, so an empty CSV gets an explicit (0, latent_dim) tensor.
    if latents_list:
        final_tensor = torch.cat(latents_list, dim=0)
    else:
        final_tensor = torch.empty((0, config["latent_dim"]))

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, output_filename)
    torch.save(final_tensor, save_path)
    print(f"Success! Saved to {save_path}")
    return final_tensor

# Encode both conditioning sets in order: positive samples first, then negatives.
for _csv_path, _out_name in (
    (POS_DATA, "latent_cond_pos.pth"),
    (NEG_DATA, "latent_cond_neg.pth"),
):
    extract_and_normalize(_csv_path, _out_name)
