In [1]:
import molgrid
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.nn import init
import os
import matplotlib.pyplot as plt
import argparse
from models import *
import torch.optim as optim

In [2]:
# Run configuration: batch size and dataset locations.
batch_size = 5  # examples per batch; also the leading dimension of the input tensors built later
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
datadir = '/scratch/shubham/crossdock_data'
fname = datadir+"/custom_cd.types"  # .types file handed to ExampleProvider.populate()

# Seed every random number generator used in this notebook (libmolgrid, PyTorch, NumPy)
# so that repeated runs draw the same batches and noise.
# NOTE(review): `torch` and `np` are not imported explicitly above (those imports are
# commented out); they are presumably re-exported by `from models import *` -- confirm.
molgrid.set_random_seed(0)
torch.manual_seed(0)
np.random.seed(0)

In [3]:
'''
    Argument Parsing
'''
# Only an empty parser is created here: every option below is commented out and
# parse_args() is never called, so `opt` does not exist in this notebook.
# The hyper-parameters actually used (learning rate, betas, latent size, ...) are
# hard-coded in the cells that follow.
parser = argparse.ArgumentParser()
# parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
# parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
# parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
# parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
# parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
# parser.add_argument("--img_height", type=int, default=128, help="size of image height")
# parser.add_argument("--img_width", type=int, default=128, help="size of image width")
# parser.add_argument("--channels", type=int, default=3, help="number of image channels")
# parser.add_argument("--latent_dim", type=int, default=8, help="number of latent codes")
# parser.add_argument("--sample_interval", type=int, default=400, help="interval between saving generator samples")
# parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")
# parser.add_argument("--lambda_pixel", type=float, default=10, help="pixelwise loss weight")
# parser.add_argument("--lambda_latent", type=float, default=0.5, help="latent loss weight")
# parser.add_argument("--lambda_kl", type=float, default=0.01, help="kullback-leibler loss weight")
# opt = parser.parse_args()

In [4]:
# Use the libmolgrid ExampleProvider to obtain shuffled, balanced batches from the
# .types file; structures are looked up under datadir/structs and are not cached.
e = molgrid.ExampleProvider(data_root=datadir+"/structs",cache_structs=False, balanced=True,shuffle=True)
# e.cache_structs=False
e.populate(fname)

# Pull a single example to probe the data: take its first coordinate set and use
# that set's center as the grid center for the sanity-check grid below.
# NOTE: this consumes one example from the provider before the training loop runs.
ex = e.next()
c = ex.coord_sets[0]
center = tuple(c.center())

# initialize libmolgrid GridMaker (default settings)
gmaker = molgrid.GridMaker()

# e.num_types()//2 is the number of channels used for voxel representation of docked ligand
print("Number of channels: ", e.num_types()//2)
# dims is the per-example grid shape (channels, x, y, z) for that channel count
dims = gmaker.grid_dimensions(e.num_types()//2)

# Sanity check: voxelize the probed coordinate set into a fresh grid and dump the
# first channel to tmp.dx for inspection.
# NOTE(review): the trailing 0.5 is presumably the grid resolution -- confirm against write_dx.
mgridout = molgrid.MGrid4f(*dims)
gmaker.forward(center, c, mgridout.cpu())
molgrid.write_dx("tmp.dx", mgridout[0].cpu(), center, 0.5)

print("4D Tensor Shape: ", dims)
# Shape of a full batch: batch dimension prepended to the per-example grid shape.
tensor_shape = (batch_size,)+dims
print(tensor_shape)
# molgrid.write_dx("temp",gmaker,center,1)

Number of channels:  14
4D Tensor Shape:  (14, 48, 48, 48)
(5, 14, 48, 48, 48)


In [5]:
# Initialize Generator, Encoder, VAE and LR Discriminators.
# The GPU variants are kept below for reference but are disabled; the active code
# further down constructs all four networks on the CPU.
# generator = Generator(8, dims).to('cuda')
# encoder = Encoder(vaeLike=True).to('cuda')
# D_VAE = MultiDiscriminator(dims).to('cuda')
# D_LR = MultiDiscriminator(dims).to('cuda')

# NOTE(review): the 8 passed to Generator is presumably the latent dimension; it must
# agree with the width of the noise drawn in reparameterization() -- confirm in models.py.
generator = Generator(8, dims)
encoder = Encoder(vaeLike=True)
D_VAE = MultiDiscriminator(dims)
D_LR = MultiDiscriminator(dims)

# Initialise weights
# Only the generator and the two discriminators go through weights_init; the encoder
# keeps whatever initialisation its constructor gives it.
generator.apply(weights_init)
D_VAE.apply(weights_init)
D_LR.apply(weights_init)

# construct optimizers for the 4 networks (Adam, identical lr and betas for each)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_E = optim.Adam(encoder.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_D_VAE = optim.Adam(D_VAE.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_D_LR = optim.Adam(D_LR.parameters(), lr=0.0002, betas=[0.5, 0.999])

# construct input tensors
# Two pre-allocated CPU float32 buffers of shape tensor_shape; the training loop
# fills them in place with voxelized batches via gmaker.forward.
input_tensor1 = torch.zeros(tensor_shape, dtype=torch.float32)
input_tensor2 = torch.zeros(tensor_shape, dtype=torch.float32)
# p3d = (40,40,40,40,40,40)
# print(input_tensor.shape)
# Buffer for per-example labels; only referenced from commented-out code at present.
float_labels = torch.zeros(batch_size, dtype=torch.float32)

In [6]:
from torch.autograd import Variable  # no longer used below; kept in case other cells rely on the name
Tensor = torch.Tensor  # module-level alias kept for any other cell that uses it


def reparameterization(mu, logvar):
    """Sample a latent code z = mu + eps * std with eps ~ N(0, 1).

    This is the VAE reparameterization trick: the randomness is isolated in
    ``eps`` so that gradients can flow back through ``mu`` and ``logvar``.

    Args:
        mu: Tensor of latent means.
        logvar: Tensor of latent log-variances, broadcastable with ``mu``.

    Returns:
        Tensor with the same shape, dtype and device as ``mu``.

    The noise takes its shape, dtype and device from ``mu`` instead of assuming
    a latent width of 8 on the CPU, so any latent size and GPU tensors work.
    """
    std = torch.exp(logvar / 2)
    # Draw eps with NumPy so np.random.seed(...) keeps controlling the samples;
    # for a 2-D `mu` with 8 columns this yields the same values as before.
    noise = np.random.normal(0, 1, tuple(mu.shape))
    sampled_z = torch.as_tensor(noise, dtype=mu.dtype, device=mu.device)
    z = sampled_z * std + mu
    return z

In [7]:
# Walk the generator's parameters once, recording each tensor's element count
# together with whether it is trainable, then derive both totals from that list.
_param_stats = [(p.numel(), p.requires_grad) for p in generator.parameters()]
total_params = sum(count for count, _ in _param_stats)
train_params = sum(count for count, trainable in _param_stats if trainable)
print("Total Parameters: ", total_params)
print("Trainable Parameters: ", train_params)

Total Parameters:  2096110
Trainable Parameters:  2096110


In [8]:
# Loss functions
# Mean absolute error (L1); the training loop uses it as the reconstruction loss
# between the generator output and input_tensor2.
mae_loss = torch.nn.L1Loss()

In [9]:
# Work-in-progress training loop: currently runs only 5 iterations (not 500) and
# performs forward passes only -- no backward() or optimizer step() is called, so
# no network parameters are updated yet.
losses = []  # never appended to by the active code; only the commented-out code below fills it
for iteration in range(5):
    # load data: two consecutive batches from the provider
    batch1 = e.next_batch(batch_size)
    batch2 = e.next_batch(batch_size)
    # libmolgrid can interoperate directly with Torch tensors, using views over the same memory.
    # internally, the libmolgrid GridMaker can use libmolgrid Transforms to apply random rotations and translations for data augmentation
    # the user may also use libmolgrid Transforms directly in python
    # Voxelize both batches in place into the pre-allocated tensors, without random rotation.
    # NOTE(review): the positional 0 is presumably the random-translation amount -- confirm.
    gmaker.forward(batch1, input_tensor1, 0, random_rotation=False)
    gmaker.forward(batch2, input_tensor2, 0, random_rotation=False)
    
    # Training the encoder and generator
    optimizer_E.zero_grad()
    optimizer_G.zero_grad()
    
    # Encode batch 2 into a latent distribution and sample a code from it.
    mu, logvar = encoder(input_tensor2)
    encoded_z = reparameterization(mu, logvar)
#     print("Latent space: ",encoded_z.shape)
    # Generate from batch 1 conditioned on the code sampled from batch 2, then
    # compare the result against batch 2.
    fake_ligands = generator(input_tensor1, encoded_z)
    print(input_tensor2.shape, fake_ligands.shape)
    loss_pixel = mae_loss(fake_ligands, input_tensor2)
    print(loss_pixel)
    # KL divergence of N(mu, exp(logvar)) from N(0, I), summed over every element
    # (batch dimension included, i.e. not averaged per example).
    loss_kl = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - logvar - 1)
    print(loss_kl)
#     loss_VAE_GAN = D_VAE.compute_loss(fake_ligands, 1)
#     print(loss_VAE_GAN)
#     print(mu,logvar)
    
    # Disabled leftovers below reference `model`, `optimizer`, `batch` and `F`,
    # none of which are defined in the visible cells.
#     batch.extract_label(0, float_labels)
#     labels = float_labels.long().to('cuda')

#     optimizer.zero_grad()
#     output = model(input_tensor)
#     loss = F.cross_entropy(output,labels)
#     loss.backward()
#     optimizer.step()
#     losses.append(float(loss))

# plt.plot(losses)


d1 shape:  torch.Size([5, 32, 24, 24, 24])
d2 shape:  torch.Size([5, 64, 12, 12, 12])
d3 shape:  torch.Size([5, 64, 6, 6, 6])
d4 shape:  torch.Size([5, 128, 3, 3, 3])
u1 shape:  torch.Size([5, 128, 6, 6, 6])
u2 shape:  torch.Size([5, 128, 12, 12, 12])
u3 shape:  torch.Size([5, 96, 24, 24, 24])
torch.Size([5, 14, 48, 48, 48]) torch.Size([5, 14, 48, 48, 48])
tensor(0.0376, grad_fn=<L1LossBackward>)
tensor(0.0309, grad_fn=<MulBackward0>)
d1 shape:  torch.Size([5, 32, 24, 24, 24])
d2 shape:  torch.Size([5, 64, 12, 12, 12])
d3 shape:  torch.Size([5, 64, 6, 6, 6])
d4 shape:  torch.Size([5, 128, 3, 3, 3])
u1 shape:  torch.Size([5, 128, 6, 6, 6])
u2 shape:  torch.Size([5, 128, 12, 12, 12])
u3 shape:  torch.Size([5, 96, 24, 24, 24])
torch.Size([5, 14, 48, 48, 48]) torch.Size([5, 14, 48, 48, 48])
tensor(0.0375, grad_fn=<L1LossBackward>)
tensor(0.0309, grad_fn=<MulBackward0>)
d1 shape:  torch.Size([5, 32, 24, 24, 24])
d2 shape:  torch.Size([5, 64, 12, 12, 12])
d3 shape:  torch.Size([5, 64, 6, 6, 