<a href="https://colab.research.google.com/github/mkoko22/AMP_Generation_using_LDM/blob/main/generation/Decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from tqdm import tqdm
from google.colab import drive
import os

# Hyperparameters passed to the VAE constructor further down in this cell.
MAX_LEN = 50        # number of greedy decoding steps per latent vector
EMB_DIM = 128
HIDDEN_DIM = 512
LATENT_DIM = 64

# Locations on the mounted Google Drive: the latents to decode, the VAE
# checkpoint to load, and the text file the decoded sequences are written to.
LATENT_PATH = "/content/drive/MyDrive/AMP-Generation/data/generated_latent_amp_denorm.pth"
VAE_PATH    = "/content/drive/MyDrive/AMP-Generation/checkpoints/vae_FINAL_epoch20.pth"
SAVE_PATH   = "/content/drive/MyDrive/AMP-Generation/data/generated_amp_sequences.txt"

# The paths above only resolve once Drive is mounted at /content/drive.
drive.mount('/content/drive')

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


class PeptideTokenizer:
    """Character-level tokenizer for peptide sequences.

    The vocabulary is the four special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
    and ``<UNK>`` followed by the 20 one-letter codes in
    ``"ACDEFGHIKLMNPQRSTVWY"``. A token's id is its position in that list.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Ordered vocabulary; list position == token id.
        self.chars = ['<PAD>', '<SOS>', '<EOS>', '<UNK>'] + list("ACDEFGHIKLMNPQRSTVWY")
        # Forward and reverse lookup tables over the same ordering.
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.idx_to_char = {i: c for i, c in enumerate(self.chars)}
        self.vocab_size = len(self.chars)

    def decode(self, indices):
        """Convert a sequence of token ids into a string.

        Decoding stops at the first ``<EOS>`` id. ``<PAD>`` and ``<SOS>`` ids
        are skipped. Every other id (including ``<UNK>``) contributes its
        vocabulary string to the result.

        Args:
            indices: Iterable of token ids. Each element may be a Python int
                or any object convertible with ``int()`` (NumPy integer
                scalars, 0-d torch tensors, ...).

        Returns:
            The decoded string; empty if nothing precedes the first ``<EOS>``.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        # Resolve the special ids once instead of on every loop iteration.
        eos_id = self.char_to_idx['<EOS>']
        skipped_ids = (self.char_to_idx['<PAD>'], self.char_to_idx['<SOS>'])

        res = []
        for idx in indices:
            # Normalise to a plain int: tensor elements hash by identity and
            # would otherwise never match the int keys of idx_to_char.
            idx = int(idx)
            if idx == eos_id:
                break
            if idx in skipped_ids:
                continue
            res.append(self.idx_to_char[idx])
        return "".join(res)

tokenizer = PeptideTokenizer()


class VAE(nn.Module):
    """GRU-based sequence VAE; only the greedy decoding path is implemented here.

    Args:
        vocab_size: Number of token ids (size of the embedding table and of
            the output projection).
        emb_dim: Token embedding width.
        hidden_dim: GRU hidden width (the encoder is bidirectional, so its
            projection layers take ``hidden_dim * 2`` features).
        latent_dim: Width of the latent vector ``z``.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        # Single embedding table; the decoder uses it in decode_latent.
        self.embedding = nn.Embedding(vocab_size, emb_dim)

        # encoder (layers are declared so a checkpoint's state dict loads;
        # this class defines no encoding method)
        self.encoder_gru = nn.GRU(
            emb_dim, hidden_dim,
            batch_first=True, bidirectional=True
        )
        self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)

        # decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.decoder_gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    @torch.no_grad()
    def decode_latent(self, z, max_len=50, sos_idx=None):
        """Greedily decode latent vectors into token ids.

        ``z`` is projected to the initial GRU hidden state; decoding starts
        from the start-of-sequence token and feeds the arg-max token of each
        step back in as the next input. Decoding always runs for ``max_len``
        steps; it does not stop at an end-of-sequence token.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.
            max_len: Number of decoding steps.
            sos_idx: Id of the start token. When ``None`` it is read from the
                module-level ``tokenizer`` (``tokenizer.char_to_idx['<SOS>']``).

        Returns:
            ``torch.long`` tensor of shape ``(batch, max_len)`` on ``z``'s
            device. For ``max_len <= 0`` an empty ``(batch, 0)`` tensor.
        """
        if sos_idx is None:
            sos_idx = tokenizer.char_to_idx['<SOS>']

        batch_size = z.size(0)
        if max_len <= 0:
            # torch.stack cannot stack an empty list; return the empty result
            # directly instead of raising.
            return torch.empty((batch_size, 0), dtype=torch.long, device=z.device)

        # GRU expects the hidden state as (num_layers, batch, hidden_dim).
        hidden = self.decoder_input(z).unsqueeze(0)

        step_input = torch.full(
            (batch_size, 1),
            sos_idx,
            device=z.device,
            dtype=torch.long
        )

        generated = []

        for _ in range(max_len):
            emb = self.embedding(step_input)
            out, hidden = self.decoder_gru(emb, hidden)
            logits = self.fc_out(out.squeeze(1))
            next_token = torch.argmax(logits, dim=-1)
            generated.append(next_token)
            # Only the latest token is needed as the next input; the history
            # lives in `hidden`, so nothing is accumulated here.
            step_input = next_token.unsqueeze(1)

        return torch.stack(generated, dim=1)


# Build the network with the hyperparameters defined above and place it on
# the selected device (Module.to moves parameters in place and returns self).
model = VAE(tokenizer.vocab_size, EMB_DIM, HIDDEN_DIM, LATENT_DIM)
model = model.to(device)

# The checkpoint file may hold either a bare state dict or a dict that wraps
# it under "model_state_dict"; unwrap the latter, keep the former as-is.
state = torch.load(VAE_PATH, map_location=device)
if isinstance(state, dict):
    state = state.get("model_state_dict", state)

model.load_state_dict(state)
# Inference only from here on: switch the module to evaluation mode.
model.eval()

print("VAE loaded successfully.")

# Load latents
# NOTE(review): torch.load unpickles the file; only load latents from a
# trusted source. The code below assumes the file holds a single 2-D tensor
# of shape (num_latents, LATENT_DIM) — confirm against the script that wrote it.
z = torch.load(LATENT_PATH, map_location=device)
print(f"Loaded {z.shape[0]} latent vectors")

# Decode in fixed-size batches so the whole latent set never has to go
# through the decoder at once.
decoded_sequences = []
BATCH_SIZE = 1024

for i in tqdm(range(0, z.size(0), BATCH_SIZE), desc="Decoding"):
    batch_z = z[i:i + BATCH_SIZE].to(device)
    # .tolist() yields nested lists of plain Python ints for the tokenizer.
    tokens = model.decode_latent(batch_z, MAX_LEN).cpu().tolist()

    peptides = (tokenizer.decode(seq) for seq in tokens)
    # Sequences that decode to an empty string (e.g. <EOS> emitted first)
    # are dropped rather than written out as blank lines.
    decoded_sequences.extend(peptide for peptide in peptides if peptide)

# Save: one sequence per line. The encoding is pinned so the output does not
# depend on the platform's default locale.
with open(SAVE_PATH, "w", encoding="utf-8") as f:
    f.writelines(f"{s}\n" for s in decoded_sequences)

print(f"Decoded {len(decoded_sequences)} sequences")
print(f"Saved to: {SAVE_PATH}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
VAE loaded successfully.
Loaded 5120 latent vectors


Decoding: 100%|██████████| 5/5 [00:00<00:00,  9.47it/s]


Decoded 5120 sequences
Saved to: /content/drive/MyDrive/AMP-Generation/data/generated_amp_sequences.txt
