In [None]:
import zipfile
import h5py
import os
import sys
import time
import timeit
import pickle

import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
# Label every atom with its index in molecule depictions rendered inline in the notebook
_draw_options = IPythonConsole.drawOptions
_draw_options.addAtomIndices = True

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchinfo

# Report whether a CUDA device is visible to this kernel
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Optimal Thread Determination
# Time 100 multiplications of two 1024x1024 matrices at each candidate thread count
# (averaged over `num_runs` repeats) and keep the fastest setting.
num_runs = 5
max_threads = os.cpu_count() or 1  # no point testing more threads than the machine has
threads = [1] + list(range(2, max_threads + 1, 2))
runtimes = []

for t in tqdm(threads):
    # The thread count only needs to be set once per candidate, not once per repeat
    torch.set_num_threads(t)
    stats_rt = [
        timeit.timeit(setup = "import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)", stmt="torch.mm(x, y)", number=100)
        for _ in range(num_runs)
    ]
    runtimes.append(np.mean(stats_rt))

optimal_num_threads = threads[int(np.argmin(runtimes))]
# Leave torch configured with the winner rather than the last candidate tried
torch.set_num_threads(optimal_num_threads)
print('OPTIMAL NUM THREADS:', optimal_num_threads)

plt.plot(threads, runtimes)
plt.title('Time for 100 x torch.mm(1024x1024, 1024x1024)')
plt.xlabel('Num Threads')
plt.ylabel('Run Time (s)');

In [None]:
# NOTE(review): absolute, user-specific paths — consider making these configurable.
ziploc = '/home/btpq/bt308495/Thesis/molecular-vae/data/processed.zip'
contentsdest = '/localdisk/bt308495/molecular-vae/data/'
h5_path = os.path.join(contentsdest, 'processed.h5')

### Unzip 'ziploc' into 'contentsdest' only when the extracted file is missing, so a fresh
### Restart & Run All works without manual extraction and without re-extracting every run.
### Assumes the archive holds 'processed.h5' at its top level — confirm.
if not os.path.exists(h5_path):
    with zipfile.ZipFile(ziploc, 'r') as zpf:
        zpf.extractall(contentsdest)

### Load data from unzipped file
with h5py.File(h5_path, 'r') as data:
    data_train = data['data_train'][:]
    data_test = data['data_test'][:]
    charset = data['charset'][:]

In [None]:
# Create an additional validation set (80/5/15 train/validate/test split): the first 25% of
# the held-out data becomes validation and the remaining 75% stays as test (no shuffling).
# The unsplit held-out data is kept as `data_holdout`. Re-running this cell then re-splits the
# original held-out set instead of splitting the already-reduced `data_test` again.
if 'data_holdout' not in globals():
    data_holdout = data_test
data_valid, data_test = train_test_split(data_holdout, test_size=0.75, shuffle=False)

In [None]:
# Shapes of the three splits and of the character set
tuple(arr.shape for arr in (data_train, data_valid, data_test, charset))

-------

In [None]:
def one_hot_to_smile(onehot_vector, character_set):
    """Decode a (MAX SMILE LENGTH, CHARSET LENGTH) one-hot/score matrix into a SMILES byte string.

    Args:
        onehot_vector: 2-D array with one row per string position and one column per
            character; the highest-scoring column in each row is selected.
        character_set: numpy array of byte characters (presumably 1-D) whose size must
            equal ``onehot_vector.shape[1]``.

    Returns:
        The selected characters concatenated as ``bytes``. Any padding characters in the
        character set are kept, not stripped.

    Raises:
        AssertionError: if the column count does not match the character set size.
    """
    n_chars = character_set.size
    assert onehot_vector.shape[1] == n_chars, 'Onehot length doesnt match character_set length'
    # Row-wise argmax -> index of the most likely character at each position
    return b''.join(character_set[idx] for idx in np.argmax(onehot_vector, axis=1))

In [None]:
# Independent copy of the first training example (modifying it leaves data_train untouched).
# NOTE(review): `test` is not used anywhere else in this notebook — candidate for removal.
test = np.array(data_train[0], copy=True)

In [None]:
# Decode one training example and render it as a molecule
sample_index = 30
sample_smiles = one_hot_to_smile(data_train[sample_index], charset)
Chem.MolFromSmiles(sample_smiles)

In [None]:
# Decode every training example back into its (padded) SMILES byte string
smis = list(map(lambda onehot: one_hot_to_smile(onehot, charset), data_train))

----

# VAE Model

In [None]:
class VAE(nn.Module):
    """Variational autoencoder for one-hot encoded SMILES strings.

    Encoder: three 1-D convolutions and dense layers that output the mean and log-variance
    of a diagonal Gaussian over the latent space. Decoder: a dense layer, a 3-layer GRU
    and a per-timestep dense layer with a softmax over the character set.

    Inputs and reconstructions are shaped ``[batch_size, input_size, charset_len]``
    (``[batch_size, 120, 33]`` with the defaults).

    Args:
        charset_len: Number of distinct characters (size of the last input axis).
        input_size: Maximum SMILES length (size of the second input axis).
        latent_dim: Dimensionality of the latent vector ``z``.

    Raises:
        ValueError: if ``charset_len`` is too small for the unpadded encoder convolutions.

    NOTE(review): ``nn.Conv1d`` treats axis 1 as channels. ``conv_1`` therefore uses the
    ``input_size`` (sequence) axis as its channels and slides along the charset axis.
    Convolving along the sequence instead would change the encoder parameter shapes and
    invalidate saved state dicts, so the layout is kept as-is.
    """

    def __init__(self, charset_len=33, input_size=120, latent_dim=292):
        super().__init__()
        self.CHARSET_LEN = charset_len
        self.INPUT_SIZE = input_size
        self.LATENT_DIM = latent_dim

        # Length of the convolved (charset) axis after three unpadded, stride-1 convolutions
        # with kernel sizes 9, 9 and 11 (7 for the default charset_len=33)
        conv_out_len = charset_len - (9 - 1) - (9 - 1) - (11 - 1)
        if conv_out_len < 1:
            raise ValueError(
                f'charset_len must be at least 27 for the encoder convolutions, got {charset_len}')

        # NOTE: layers are registered in the same order as before, so state_dict keys and
        # seeded weight initialisation are unchanged.
        ### ENCODING
        # Convolutional Layers
        self.conv_1 = nn.Conv1d(self.INPUT_SIZE, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)

        # Fully connected layer over the flattened conv output (10 channels * conv_out_len; 70 by default)
        self.linear_0 = nn.Linear(10 * conv_out_len, 435)

        # Mean and log-variance latent layers
        self.mean_linear_1 = nn.Linear(435, self.LATENT_DIM)
        self.var_linear_2 = nn.Linear(435, self.LATENT_DIM)

        ### DECODING
        # Dense layer, then 3 stacked GRU layers of hidden size 501. With batch_first=True the
        # GRU input is [batch_size, seq_len (INPUT_SIZE), LATENT_DIM].
        self.linear_3 = nn.Linear(self.LATENT_DIM, self.LATENT_DIM)
        self.stacked_gru = nn.GRU(self.LATENT_DIM, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, self.CHARSET_LEN)

        ### ACTIVATION and OUTPUT
        self.relu = nn.ReLU()
        # Not used by decode(), which calls F.softmax explicitly. The dim is given because
        # nn.Softmax() with an implicit dim is deprecated.
        self.softmax = nn.Softmax(dim=-1)

    def encode(self, x):
        """Map ``x`` ([batch_size, INPUT_SIZE, CHARSET_LEN]) to the latent (mean, log-variance),
        each shaped [batch_size, LATENT_DIM]."""
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))

        # Flatten [batch_size, 10, conv_out_len] ([batch_size, 10, 7] by default) into
        # [batch_size, 10 * conv_out_len] for the fully connected layer
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))

        return self.mean_linear_1(x), self.var_linear_2(x)

    def reparameterize(self, mu_z, logvar_z):
        """Sample ``z = mu_z + eps * std`` with ``std = exp(0.5 * logvar_z)``.

        Writing the sample as a deterministic function of (mu, logvar) plus independent noise
        lets gradients flow back into the encoder (the reparameterization trick). Sampling
        happens in both train and eval mode.
        """
        # NOTE(review): the noise has std 0.01 rather than the textbook unit variance, so the
        # samples stay very close to mu_z. Kept as-is because trained weights depend on it;
        # confirm it is intentional.
        gamma = 1e-2
        epsilon = gamma * torch.randn_like(logvar_z)
        std = torch.exp(0.5 * logvar_z)
        return mu_z + epsilon * std

    def decode(self, z):
        """Map latent vectors z ([batch_size, LATENT_DIM]) to per-position character
        probabilities shaped [batch_size, INPUT_SIZE, CHARSET_LEN]."""
        z = F.selu(self.linear_3(z))

        # The GRU consumes a sequence, so it receives the same latent vector at each of the
        # INPUT_SIZE time steps: [batch_size, LATENT_DIM] -> [batch_size, INPUT_SIZE, LATENT_DIM].
        # The latent information is therefore available to the network at every step.
        z = z.unsqueeze(1).repeat(1, self.INPUT_SIZE, 1)
        output, _ = self.stacked_gru(z)  # output: [batch_size, INPUT_SIZE, 501]

        # Fold time into the batch axis ([batch_size * INPUT_SIZE, 501]) so the dense layer is
        # applied to every time step independently, with weights shared across time steps.
        # Then take a softmax over the charset axis and restore the sequence axis.
        # contiguous() guards view() in case the GRU output is not stored in index order.
        out_independent = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_independent), dim=1)
        return y0.view(output.size(0), -1, y0.size(-1))

    def forward(self, x):
        """Encode, sample and decode ``x``; returns ``(reconstruction, mu_z, logvar_z)``."""
        mu_z, logvar_z = self.encode(x)
        z = self.reparameterize(mu_z, logvar_z)
        xhat = self.decode(z)
        return xhat, mu_z, logvar_z

In [None]:
# Layer-by-layer summary for a single [1, INPUT_SIZE, CHARSET_LEN] input
batch_size = 1
summary_model = VAE()
torchinfo.summary(summary_model, input_size=(batch_size, summary_model.INPUT_SIZE, summary_model.CHARSET_LEN))

In [None]:
# test_z = torch.from_numpy(np.array((np.random.randn(292), np.random.randn(292)))).to(torch.float32)
# test_model = VAE()
# test_model.decode(test_z)

In [None]:
# data_train_tensor[0].size(1)
# data_train_tensor[:2].shape, data_train_tensor[:2].view(data_train_tensor[:2].shape[0], -1).shape

In [None]:
# Wrap the training split as a tensor (memory shared with the numpy array) and a Dataset;
# the dataset reuses that tensor rather than converting the array a second time
data_train_tensor = torch.from_numpy(data_train)
data_train_tensor_loader = torch.utils.data.TensorDataset(data_train_tensor)

In [None]:
# Smoke test: run a forward pass (calling the model runs forward()) on two training molecules.
# no_grad because nothing is backpropagated here. Show output shapes instead of dumping the
# full tensors; with the default sizes expect [(2, 120, 33), (2, 292), (2, 292)].
with torch.no_grad():
    smoke_outputs = VAE()(data_train_tensor[:2])
[tuple(t.shape) for t in smoke_outputs]

### Training

In [None]:
def variational_loss(x, reconstructed_x_mean, mu_z, logvar_z):
    """Negative ELBO for the SMILES VAE, summed (not averaged) over the whole batch.

    Args:
        x: One-hot target tensor.
        reconstructed_x_mean: Decoder output probabilities, same shape as ``x``.
        mu_z: Latent means from the encoder.
        logvar_z: Latent log-variances from the encoder, same shape as ``mu_z``.

    Returns:
        ``(total, reconstruction, kl)`` with ``total = reconstruction + kl``; all three are
        0-d tensors.
    """
    # Element-wise (per position, per character) BCE, summed so it is on the same scale as
    # the summed KL term
    reconstruction = F.binary_cross_entropy(reconstructed_x_mean, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims and batch
    kl = -0.5 * torch.sum(1. + logvar_z - mu_z.pow(2) - logvar_z.exp())
    total = reconstruction + kl
    return total, reconstruction, kl

In [None]:
# Tensor views of the numpy splits (torch.from_numpy shares memory, no copy)
data_train_tensor, data_valid_tensor, data_test_tensor = map(torch.from_numpy, (data_train, data_valid, data_test))

# Shuffled mini-batches of 250 molecules for training
data_train_tensor_loader = torch.utils.data.TensorDataset(data_train_tensor)
train_loader = torch.utils.data.DataLoader(data_train_tensor_loader, batch_size=250, shuffle=True)

In [None]:
epochs = 2
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Seed right before building the model so its initial weights are reproducible
torch.manual_seed(42)
model = VAE().to(device)
optimizer = optim.Adam(model.parameters())

In [None]:
# NOTE(review): hard-coded rather than taken from `optimal_num_threads` (found by the
# benchmark cell near the top) — confirm which is intended, e.g. when sharing a machine.
n_torch_threads = 4
torch.set_num_threads(n_torch_threads)

In [None]:
def train_epoch(epoch):
    """Train the global `model` for one epoch, then evaluate it on the validation set.

    Relies on notebook globals: `model`, `optimizer`, `train_loader`, `device`,
    `data_valid_tensor`, `charset`, `variational_loss` and `one_hot_to_smile`.

    Args:
        epoch: Epoch number, used only in the progress banner.

    Returns:
        ``((mean_train_loss, mean_train_bce, mean_train_kld), mean_validation_loss)``: the
        summed losses divided by the number of samples, i.e. per-molecule means.
    """
    print(f'################ epoch {epoch} ################')
    start = time.perf_counter()
    model.train()  # Tell model we're in train mode, as opposed to eval mode
    training_loss, training_bce_loss, training_kld_loss = 0, 0, 0

    for X in tqdm(train_loader):
        # Reset gradients for this batch and move the data to the training device
        optimizer.zero_grad()
        X = X[0].to(device)

        # Forward pass through the model
        Xhat, mu_z, logvar_z = model(X)

        # Determine loss, perform backward pass, and update weights
        loss, bceloss, kldloss = variational_loss(X, Xhat, mu_z, logvar_z)
        loss.backward()
        optimizer.step()

        training_loss += loss.item()
        training_bce_loss += bceloss.item()
        training_kld_loss += kldloss.item()

    # Validation. no_grad: the loss is never backpropagated, so building an autograd graph
    # only wastes memory. The data is processed in training-sized batches so the whole
    # validation set never goes through the GRU in one forward call. The summed loss equals
    # the full-batch sum.
    model.eval()
    validation_loss = 0.
    with torch.no_grad():
        for X_valid in torch.split(data_valid_tensor, train_loader.batch_size):
            X_valid = X_valid.to(device)
            Xhat_valid, mu_valid, logvar_valid = model(X_valid)
            batch_validation_loss, _, _ = variational_loss(X_valid, Xhat_valid, mu_valid, logvar_valid)
            validation_loss += batch_validation_loss.item()
    model.train()  # restore training mode, as the function previously left the model

    # Summary of training epoch (summed losses -> per-sample means)
    n_train = len(train_loader.dataset)
    mean_training_loss = training_loss / n_train
    mean_training_bce_loss = training_bce_loss / n_train
    mean_training_kld_loss = training_kld_loss / n_train
    mean_validation_loss = validation_loss / data_valid_tensor.shape[0]

    # First molecule of the last training batch and its reconstruction, decoded for inspection
    test_points = X[0].cpu(), Xhat[0].cpu().detach()
    test_smiles = [one_hot_to_smile(t.numpy(), charset) for t in test_points]

    print(f'Epoch took: {(time.perf_counter() - start) / 60.} mins')
    print('Mean Training Loss:', mean_training_loss)
    print('Mean Validation Loss:', mean_validation_loss)
    print('---------------------')
    print('Sampled Input, Ouput Smiles:')
    for smile in test_smiles:
        print(smile)

    return (mean_training_loss, mean_training_bce_loss, mean_training_kld_loss), mean_validation_loss

In [None]:
# Per-epoch history: tls holds (total, BCE, KLD) training tuples, vls the validation losses
tls = []
vls = []
for epoch in range(1, epochs + 1):
    training_losses, validation_loss = train_epoch(epoch)
    tls.append(training_losses)
    vls.append(validation_loss)

In [None]:
# To persist the losses of the run above instead, uncomment:
# with open('VAE_losses.pckl', 'wb') as f:
#     pickle.dump([tls, vls], f)

# NOTE(review): this replaces the `tls`/`vls` just produced by the training loop with the
# losses previously pickled to VAE_losses.pckl. pickle.load can execute arbitrary code —
# only load files you created yourself.
losses_file = 'VAE_losses.pckl'
with open(losses_file, 'rb') as f:
    tls, vls = pickle.load(f)

In [None]:
# Save model params (state_dict only; restore with model.load_state_dict(torch.load(PATH)))
PATH = 'test_model.pt'
model_state = model.state_dict()
torch.save(model_state, PATH)

# Load model
# model = VAE()
# model.load_state_dict(torch.load(PATH))
# # model.eval()

In [None]:
# train_epoch(0)

-------

# Analyzing Trained Model

In [None]:
# Load model
PATH = '/home/btpq/bt308495/Thesis/VAE_model_parmas.pt'
model = VAE()
# map_location='cpu': training puts the model on 'cuda' when available, so the saved tensors
# may live on a GPU. This section evaluates on CPU tensors (and may run on a CPU-only host),
# so load the weights onto the CPU regardless of where they were saved.
# torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(PATH, map_location='cpu'))
model.eval()

# pickle.load can execute arbitrary code — only load files you created yourself.
with open('/home/btpq/bt308495/Thesis/VAE_losses.pckl', 'rb') as f:
    tls, vls = pickle.load(f)

In [None]:
tls, vls = np.array(tls), np.array(vls)

# Columns of tls follow train_epoch's (total, BCE, KLD) tuple order
train_loss, bce_loss, kld_loss = (tls[:, col] for col in range(3))

In [None]:
# Total loss per epoch for training and validation; epochs are numbered from 1 as in the
# training loop. legend() is required for the `label`s to be displayed.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(train_loss) + 1), train_loss, color='blue', label='Train')
ax.plot(np.arange(1, len(vls) + 1), vls, color='green', label='Validate')
ax.set_title('Training vs. validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.legend();

In [None]:
# Training loss split into its reconstruction (BCE) and KL-divergence terms; epochs are
# numbered from 1 as in the training loop. legend() is required for the `label`s to be displayed.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(bce_loss) + 1), bce_loss, color='red', label='BCE')
ax.plot(np.arange(1, len(kld_loss) + 1), kld_loss, color='purple', label='KLD')
ax.set_title('Training loss components')
ax.set_xlabel('Epoch')
ax.set_ylabel('Broken Down Loss')
ax.legend();

In [None]:
# See how model performs on test smiles.
# no_grad: evaluation needs no autograd graph. Batching keeps the GRU activations for the
# whole test set from being materialised at once. test_loss stays the summed loss (0-d tensor).
eval_batch_size = 250
test_outputs = []
test_loss = torch.zeros(())
with torch.no_grad():
    for X_batch in torch.split(data_test_tensor, eval_batch_size):
        batch_outputs = model(X_batch)
        batch_loss, _, _ = variational_loss(X_batch, *batch_outputs)
        test_loss += batch_loss
        test_outputs.append(batch_outputs)
# Reassemble (reconstruction, mu, logvar) for the full test set
Xhat_test, mu_test, logvar_test = (torch.cat(parts) for parts in zip(*test_outputs))

In [None]:
# The training and validation numbers are per-molecule means (summed loss / number of samples),
# but test_loss is summed over the whole test set. Normalise it the same way so the three
# numbers are comparable.
mean_test_loss = float(test_loss) / data_test_tensor.shape[0]
print('Training loss (proxy):', train_loss[-1])
print('Validation loss (proxy):', vls[-1])
print('Test loss:', mean_test_loss)