In [2]:
!git clone https://github.com/PraljakReps/ProtWaveVAE_model.git

Cloning into 'ProtWaveVAE_model'...
remote: Enumerating objects: 47, done.[K
remote: Counting objects: 100% (47/47), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 47 (delta 20), reused 33 (delta 9), pack-reused 0 (from 0)[K
Receiving objects: 100% (47/47), 1.81 MiB | 19.47 MiB/s, done.
Resolving deltas: 100% (20/20), done.


In [11]:
# NOTE: `!cd ProtWaveVAE_model` would run in a throwaway subshell and leave the
# kernel's working directory unchanged, so it has no lasting effect.
# The directory switch is done with os.chdir in the next cell.
# Here we only confirm the clone is present (or that we are already inside it).
import os

assert (
    os.path.basename(os.getcwd()) == "ProtWaveVAE_model"
    or os.path.isdir("ProtWaveVAE_model")
), "ProtWaveVAE_model/ not found - did the git clone succeed?"

In [14]:
import os

# Move the kernel into the cloned repo. This is presumably what lets
# `from ProtWave_VAE import ...` in the next cell resolve.
# Guarded so re-running the cell does not try to enter
# ProtWaveVAE_model/ProtWaveVAE_model and raise FileNotFoundError.
if os.path.basename(os.getcwd()) != "ProtWaveVAE_model":
    os.chdir("ProtWaveVAE_model")
os.getcwd()

'/storage/ice1/6/9/khari8/ProtWaveVAE_model'

In [15]:
import torch
import numpy as np
from ProtWave_VAE import model_components, wavenet_decoder, model_ensemble

# Seed torch's global RNG so that weight initialisation (and any sampling) in
# later cells is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Toy integer-encoded sequence with a leading batch dimension: shape (1, 4).
X = torch.tensor([0, 2, 1, 0]).unsqueeze(0)
# one_hot needs num_classes > the largest index.
# Counting unique values under-counts whenever an index is skipped:
# e.g. [0, 3, 1] has 3 unique values but needs 4 classes.
num_categories = int(X.max()) + 1
x_one_hot = torch.nn.functional.one_hot(X, num_categories).float()
print('Input tensor and size:', x_one_hot, x_one_hot.shape)

Input tensor and size: tensor([[[1., 0., 0.],
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]]]) torch.Size([1, 4, 3])


In [16]:
# --- encoder hyperparameters -----------------------------------------------
# Sizes taken from the one-hot input, whose shape is (batch, length, categories).
batch_size, seq_len = x_one_hot.shape[:2]
class_labels = num_categories  # number of categorical labels
C_in = class_labels            # encoder input channels = number of categories
z_dim = 3                      # latent space size
C_out = 128                    # hidden kernel count of the encoder convolutions
kernel_size = 3                # encoder convolution kernel size
num_fc = 2                     # fully connected layers before embedding latent codes
num_rates = 0                  # encoder convolution depth (0 = max depth)

In [17]:
# Encoder q(z|x).
# Its configuration is gathered in one dict so it can be inspected or logged
# on its own before the module is built.
encoder_kwargs = dict(
    protein_len=seq_len,
    class_labels=class_labels,
    z_dim=z_dim,
    num_rates=num_rates,
    alpha=0.1,  # leaky ReLU hparam
    kernel=kernel_size,
    num_fc=num_fc,
    C_in=C_in,
    C_out=C_out,
)
encoder = model_components.GatedCNN_encoder(**encoder_kwargs)

In [18]:
# --- WaveNet decoder hyperparameters ---------------------------------------
device = 'cpu'        # set to 'cuda' when a GPU is available
whs, hhs = 32, 256    # dilated-convolution kernel count, top-model hidden size
dec_kernel_size = 3   # decoder convolution kernel size
ndr = 5               # number of dilation rates (i.e. WaveNet depth)

In [19]:
# Decoder p(x|z), built from wavenet_decoder.Wave_generator using the
# hyperparameters defined above.
decoder_xr = wavenet_decoder.Wave_generator(
    DEVICE=device,
    protein_len=seq_len,
    class_labels=class_labels,
    kernel_size=dec_kernel_size,
    num_dil_rates=ndr,
    wave_hidden_state=whs,
    head_hidden_state=hhs,
)

In [20]:
# Latent upscaler z -> Z, configured with output_shape = (1, seq_len).
latent_shape = (1, seq_len)
latent_upscaler = wavenet_decoder.CondNet(z_dim=z_dim, output_shape=latent_shape)

In [21]:
# Assemble the full VAE from the encoder, the latent upscaler, and the
# reconstruction decoder built in the previous cells.
ProtWaveVAE_model = model_ensemble.ProtWaveVAE(
    DEVICE=device,
    z_dim=z_dim,
    encoder=encoder,
    cond_mapper=latent_upscaler,
    decoder_recon=decoder_xr,
)

In [22]:
# Single forward pass with eval() + no_grad to sanity-check the wiring.
ProtWaveVAE_model.eval()
with torch.no_grad():
    logits_xrc, z, z_mu, z_var = ProtWaveVAE_model(x=x_one_hot)

# Cheap invariants that should hold for any input:
# - the reconstruction logits have the same (batch, length, categories) shape as the input;
# - there is one z_dim-sized latent vector per batch item.
assert logits_xrc.shape == x_one_hot.shape, logits_xrc.shape
assert z.shape == (batch_size, z_dim), z.shape

print('Predicted logits and size:', logits_xrc, logits_xrc.shape)
print('Inferred latent embeddings and size:', z, z.shape)

Predicted logits and size: tensor([[[-0.0186, -0.1138,  0.0021],
         [ 0.1117,  0.0497, -0.2652],
         [ 0.0785, -0.0719, -0.3013],
         [ 0.2561, -0.0243, -0.5685]]]) torch.Size([1, 4, 3])
Inferred latent embeddings and size: tensor([[ 0.1000,  0.2000, -0.4889]]) torch.Size([1, 3])


In [None]:
# TODO (next steps):
#   1. train the model on the original authors' data
#   2. train the model on new data
