In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
torch.set_default_device('cuda')


class GRUVAE(nn.Module):
    """Sequence VAE with a GRU encoder and an autoregressive GRU decoder.

    The encoder summarises a sequence of embeddings into a single latent
    vector z; the decoder unrolls z back into a sequence, feeding each
    predicted step back in as the next input.

    Args:
        input_size: dimension of each embedding at a timestep.
        hidden_size: hidden dimension of both GRUs.
        latent_size: dimension of the latent space z.
        num_layers: number of stacked GRU layers (encoder and decoder).
    """

    def __init__(
        self,
        input_size: int,   # dimension of each embedding at a timestep
        hidden_size: int,  # hidden dimension of GRU
        latent_size: int,  # dimension of latent space z
        num_layers: int = 1,
    ):
        super().__init__()

        self.num_layers = num_layers
        self.latent_size = latent_size

        # NOTE: attribute names and construction order are kept as-is so that
        # previously saved state dicts keep loading.
        self.encoder_gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # Heads mapping the encoder summary to the posterior parameters.
        self.hidden2mu = nn.Linear(hidden_size, latent_size)
        self.hidden2logvar = nn.Linear(hidden_size, latent_size)

        # Decoder: z -> initial hidden state -> GRU -> per-step embedding.
        self.latent2hidden = nn.Linear(latent_size, hidden_size)
        self.decoder_gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.h2output = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """
        Encodes a batch of sequences into posterior parameters.
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            mu, logvar: each (batch_size, latent_size)
        """
        # We only need the final hidden state from the GRU
        _, h_n = self.encoder_gru(x)

        # h_n: (num_layers, batch_size, hidden_size); keep only the top layer
        h_n_top = h_n[-1]  # shape: (batch_size, hidden_size)

        mu = self.hidden2mu(h_n_top)
        logvar = self.hidden2logvar(h_n_top)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draws z = mu + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, seq_len):
        """
        Decodes a latent vector z into a sequence of length seq_len.
        Args:
            z: (batch_size, latent_size)
            seq_len: int, the length of the output sequence
        Returns:
            outputs: (batch_size, seq_len, input_size)
        """
        batch_size = z.size(0)
        step_size = self.h2output.out_features

        # torch.cat() rejects an empty list, so answer the degenerate case directly.
        if seq_len <= 0:
            return z.new_zeros(batch_size, 0, step_size)

        # Transform latent vector to initial hidden state for GRU,
        # replicated for every GRU layer.
        hidden = self.latent2hidden(z)         # (batch_size, hidden_size)
        hidden = hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)

        # Start with a zero vector as the "input" for the first timestep,
        # created on the same device and dtype as z.
        input_step = torch.zeros(
            batch_size, 1, step_size, device=z.device, dtype=z.dtype
        )

        outputs = []
        for _ in range(seq_len):
            # Pass one step at a time; out: (batch_size, 1, hidden_size)
            out, hidden = self.decoder_gru(input_step, hidden)
            # Project back to embedding dimension
            step_output = self.h2output(out)   # (batch_size, 1, input_size)
            outputs.append(step_output)

            # The next input is the current output (autoregressive decoding)
            input_step = step_output

        # Concatenate along seq_len dimension
        return torch.cat(outputs, dim=1)       # (batch_size, seq_len, input_size)

    def forward(self, x):
        """
        Encodes x, samples z and reconstructs a sequence of the same length.
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            recon_x: (batch_size, seq_len, input_size)
            mu, logvar: each (batch_size, latent_size)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        seq_len = x.size(1)
        recon_x = self.decode(z, seq_len)
        return recon_x, mu, logvar

    def sample(self, batch_size=1, seq_len=10):
        """
        Generates sequences by decoding z ~ N(0, I), without tracking gradients.
        Args:
            batch_size: number of sequences to generate
            seq_len: length of each generated sequence
        Returns:
            samples: (batch_size, seq_len, input_size)
        """
        # Draw z where the model's parameters live instead of hard-coding
        # .cuda(), so sampling also works on CPU or a non-default device.
        reference = self.latent2hidden.weight
        z = torch.randn(
            batch_size,
            self.latent_size,
            device=reference.device,
            dtype=reference.dtype,
        )

        # Decode to generate sequences
        with torch.no_grad():
            samples = self.decode(z, seq_len)
        return samples

In [2]:
# Location of the pretrained GRU-VAE weights.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
PT_MODEL_PATH = '/home/bo/HierQAC/pt_model.pth'

# input_size=8 is the per-step vector size; the sampled steps are later fed
# to the HierVAE models as latent codes (their latent_size is also 8).
model = GRUVAE(input_size=8, hidden_size=100, latent_size=20, num_layers=1).cuda()

# torch.load unpickles the file; weights_only=True restricts it to tensors and
# plain containers, which is all a state dict needs, so a tampered checkpoint
# cannot execute arbitrary code on load.
state_dict = torch.load(PT_MODEL_PATH, weights_only=True)
model.load_state_dict(state_dict)

# Set to evaluation mode
model.eval()

GRUVAE(
  (encoder_gru): GRU(8, 100, batch_first=True)
  (hidden2mu): Linear(in_features=100, out_features=20, bias=True)
  (hidden2logvar): Linear(in_features=100, out_features=20, bias=True)
  (latent2hidden): Linear(in_features=20, out_features=100, bias=True)
  (decoder_gru): GRU(8, 100, batch_first=True)
  (h2output): Linear(in_features=100, out_features=8, bias=True)
)

In [3]:
# Load the pretrained core and tail HierVAE models used to decode latents back to molecules
import sys
sys.path.append("./hgraph2graph")

from hgraph2graph.hgraph import *
import argparse
import torch

from hgraph import HierVAE, common_atom_vocab
from hgraph.hgnn import make_cuda
from hgraph2graph.preprocess import tensorize
from hgraph import PairVocab
from util_split_core_tail import iterative_cut
from rdkit import Chem
from tqdm import tqdm

def load_pair_vocab(vocab_path):
    """Read a vocabulary file (one whitespace-separated entry per line) into a PairVocab.

    The PairVocab is built with cuda=False, exactly as the models below expect.
    """
    with open(vocab_path) as f:
        entries = [line.strip("\r\n ").split() for line in f]
    return PairVocab(entries, cuda=False)


def make_hiervae_args(vocab):
    """Build the HierVAE hyper-parameter namespace shared by the core and tail models.

    Only `vocab` differs between the two models; every other setting is identical.
    """
    return argparse.Namespace(
        seed=7,
        rnn_type='LSTM',
        hidden_size=100,
        embed_size=100,
        latent_size=8,
        depthT=15,
        depthG=15,
        diterT=1,
        diterG=3,
        dropout=0.0,
        vocab=vocab,
        atom_vocab=common_atom_vocab
    )


#####################
# Load Core Encoder #
#####################

core_vocab_path = "hgraph2graph/vocab-cores.txt"
core_vocab = load_pair_vocab(core_vocab_path)
args_core = make_hiervae_args(core_vocab)

core_model = HierVAE(args_core).cuda()
# The checkpoint is indexed with [0]; that element is used as the model state dict.
core_model.load_state_dict(torch.load("hgraph2graph/ckpt/cores/model.ckpt.1000")[0])

#####################
# Load Tail Encoder #
#####################

tail_vocab_path = "hgraph2graph/vocab-tails.txt"
tail_vocab = load_pair_vocab(tail_vocab_path)
args_tail = make_hiervae_args(tail_vocab)

tail_model = HierVAE(args_tail).cuda()
tail_model.load_state_dict(torch.load("hgraph2graph/ckpt/tails/model.ckpt.800")[0])





<All keys matched successfully>

In [4]:
# Draw a single latent sequence from the prior: batch of 1, five steps,
# each step being one vector of the GRU-VAE's input_size.
samples = model.sample(seq_len=5, batch_size=1)
samples.shape

torch.Size([1, 5, 8])

In [6]:
# Step 0 of the sampled sequence is the core latent (batch dimension kept);
# the remaining steps of the first batch element are the tail latents.
cores_z = samples[:, 0]
tails_z = samples[0, 1:]

# Decode the latents with the corresponding HierVAE model.
core_mol = core_model.csample(cores_z, greedy=True)
tail_mols = tail_model.csample(tails_z, greedy=True)

0.3352620601654053


In [7]:
import time

from util_reassemble_core_tail import attach_tails_to_core

print("core:", core_mol)
print("tails:", tail_mols)

# Time the reassembly step. time_0 / time_1 were previously printed without
# ever being defined, which raises NameError on a fresh kernel.
time_0 = time.perf_counter()
whole_mol = attach_tails_to_core(core_mol[0], tail_mols)
time_1 = time.perf_counter()

print(time_1 - time_0)
print("assembled:", whole_mol)

core: ['c1ccc(-c2cc[nH+]cc2)cc1']
tails: ['CCCCCCCCCOC(=O)C[NH3+]', 'CCCCCC[NH3+]', 'CCCCC[NH3+]', 'CCCCC[NH3+]']
0.0034019947052001953
assembled: CCCCCCCCCOC(=O)C[n+]1ccc(-c2ccccc2)cc1


In [33]:
from tqdm import tqdm

N_ATTEMPTS = 2000  # number of sample -> decode -> assemble attempts

generated = []  # successfully assembled molecules
failures = []   # (attempt index, repr of the exception) for assemblies that raised

for attempt in tqdm(range(N_ATTEMPTS)):
    # One latent sequence per attempt: step 0 is the core, the rest are tails.
    samples = model.sample(batch_size=1, seq_len=5)
    cores_z = samples[:, 0, :]
    tails_z = samples[0, 1:, :]
    core_mol = core_model.csample(cores_z, greedy=True)
    tail_mols = tail_model.csample(tails_z, greedy=True)
    try:
        whole_mol = attach_tails_to_core(core_mol[0], tail_mols)
    except Exception as err:
        # Catch Exception rather than using a bare `except:` so Ctrl-C /
        # kernel interrupt still stops this long loop, and keep a record of
        # what failed instead of dropping it silently.
        failures.append((attempt, repr(err)))
        continue
    generated.append(whole_mol)

100%|██████████| 2000/2000 [11:10<00:00,  2.98it/s]


In [8]:
# Batched generation: sample and decode several molecules per model call
# instead of one at a time.

# Every sampled sequence holds one core latent followed by the tail latents.
n_samples = 8  # Just an example
seq_len = 5    # 1 core + 4 tails
n_tails = seq_len - 1  # 4 in this example

# A single draw for the whole batch -> (n_samples, seq_len, latent_dim)
samples = model.sample(batch_size=n_samples, seq_len=seq_len)
latent_dim = samples.shape[-1]

# Step 0 is the core latent; steps 1.. are the tail latents.
cores_z = samples[:, 0]     # (n_samples, latent_dim)
tails_z = samples[:, 1:]    # (n_samples, n_tails, latent_dim)

# All cores are decoded in one call; the result is indexed per sample below.
core_mols = core_model.csample(cores_z, greedy=True)

# Collapse the (sample, tail) axes so that all tails are decoded in one call too.
tails_z_reshaped = tails_z.reshape(n_samples * n_tails, latent_dim)
tail_mols_flat = tail_model.csample(tails_z_reshaped, greedy=True)

# Cut the flat tail list back into one chunk of n_tails entries per sample.
tail_mols_grouped = []
for i in range(n_samples):
    chunk_start = i * n_tails
    tail_mols_grouped.append(tail_mols_flat[chunk_start:chunk_start + n_tails])

# Assemble each core with its own tails; a failure is reported and skipped.
for i in range(n_samples):
    try:
        core_mol = core_mols[i]
        tails_for_this_core = tail_mols_grouped[i]
        whole_mol = attach_tails_to_core(core_mol, tails_for_this_core)
        print(whole_mol)
    except Exception as e:
        print(f"Error generating molecule {i}: {e}")


CCCCCC[n+]1ccc(CC(c2ccccc2)c2cccc(C[N+](C)(C)C)c2)cc1
CCCCCCCC[n+]1cccc(CCNCCCCC[N+](C)(C)C)c1
CCCCCCCCCCC[n+]1ccc(-c2cc[n+](CC)cc2)cc1
CCN(C)c1ccc(C(c2ccccc2)c2ccc(N(C)C)cc2)cc1
CCCCCC[N+](C)(C)c1ccc(C(c2ccccc2)c2cccc(C[N+](C)(C)C)c2)cc1
CCCCCCCCCCCc1ccccc1
CCCCCCCCCC[n+]1cccc(-c2ccccc2)c1
CCCCCCCCCCCC[N+](C)(C)C


In [35]:
from pathlib import Path

# Write the molecules collected in `generated`, one per line.
output_path = Path("outputs/generated-20241230.txt")

# Create the output directory first: open(..., "w") does not create missing
# parent directories and would fail with FileNotFoundError on a fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)

with open(output_path, "w") as file:
    for line in generated:
        file.write(line + "\n")
