## Load pretrained weights into a Keras model, then export them to a PyTorch model

In [None]:
import sys

# Put the genesis splicing analysis directory at the front of the module
# search path (presumably where the `definitions` / `genesis` packages live).
sys.path.insert(0, '/gpfs/commons/home/tchen/al_project/genesis/analysis/splicing')

In [None]:
from definitions.generator.splirent_deconv_conv_generator_concat import load_generator_network
from genesis.generator import build_generator, st_sampled_softmax, st_hardmax_softmax
from pathlib import Path
from keras.models import load_model
from keras import backend as K
import numpy as np

In [None]:
# Directory holding the pretrained Keras generator checkpoints (note the
# trailing slash: the file name is appended by plain concatenation).
model_save_dir = '/gpfs/commons/groups/knowles_lab/ting/DEN_splicing_pretrained_models/'
# Checkpoint file to convert.
model_name = 'genesis_splicing_cnn_target_isoform_00_pwm_and_multisample_hek_only_random_regions_50_epochs_harderentropy_generator.h5'

full_path = f'{model_save_dir}{model_name}'

In [None]:
# Load the saved Keras generator for inspection only (compile=False). The
# custom objects are passed so names referenced inside the saved model resolve.
generator_model = load_model(
    filepath=str(full_path),
    custom_objects={
        'K': K,
        'st_sampled_softmax': st_sampled_softmax,
        'st_hardmax_softmax': st_hardmax_softmax,
    },
    compile=False,
)

In [None]:
# Display the model output tensor at index 5.
generator_model.outputs[5]

In [None]:
# Print the layer-by-layer summary of the loaded Keras model.
generator_model.summary()

Get order of weights

In [None]:
# Show the config entry of the layer at index 12 in the model config.
generator_model.get_config()['layers'][12]

In [None]:
# Number of weight arrays returned by the Keras model.
len(generator_model.get_weights())

Got the weights!

Let's save the weights into a numpy object file

In [None]:
save_path = '/gpfs/commons/groups/knowles_lab/ting/DEN_splicing_generator_weights/'
save_name = 'target_isoform_00.npy'

# The Keras weights are a list of arrays with different shapes. Fill a 1-D
# object array element by element: np.array(list, dtype=object) tries to
# broadcast and raises ValueError when ragged arrays share a leading dimension.
keras_weights = generator_model.get_weights()
weights_to_save = np.empty(len(keras_weights), dtype=object)
for weight_idx, weight_array in enumerate(keras_weights):
    weights_to_save[weight_idx] = weight_array

# Object arrays can only be stored via pickle, hence allow_pickle=True.
np.save(save_path+save_name, weights_to_save, allow_pickle=True)

In [None]:
# Number of weight arrays returned by the Keras model (sanity check after saving).
len(generator_model.get_weights())

Now we'll try loading the weights into our PyTorch model

In [1]:
import sys

# Put this project's `src` directory at the front of the module search path.
sys.path.insert(0, '/gpfs/commons/home/tchen/al_project/active-learning-cnns-gps/src')

import numpy as np
from models.den import Generator
import torch
from data.old_dataset import create_sequence_templates
from torch import nn
from models.base_cnn import OracleCNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The directory must end with '/' because the file name is appended by plain
# string concatenation; without it the path becomes
# '..._generator_weightstarget_isoform_00.npy', which is not where the
# weights were saved.
save_path = '/gpfs/commons/groups/knowles_lab/ting/DEN_splicing_generator_weights/'
save_name = 'target_isoform_00.npy'

# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load files
# that were produced by this notebook.
target_isoform_net_weights = np.load(save_path+save_name, allow_pickle=True)

# list of numpy arrays
weights = target_isoform_net_weights.tolist()

In [3]:
# Run the conversion on CPU.
device = torch.device('cpu')

# Generator hyperparameters.
n_samples = 10
n_classes = 1
seq_length = 109
latent_dim = 100
batch_size = 32

# The embedding template and mask come from the last two arrays of the
# pretrained weights instead of being rebuilt with create_sequence_templates().
pretrained_embedding_template = torch.tensor(weights[-2], device=device)
pretrained_embedding_mask = torch.tensor(weights[-1], device=device)

In [None]:
# View the pretrained template as (sequence position, 4); use the configured
# seq_length rather than repeating the literal 109.
pretrained_embedding_template.reshape(seq_length, 4)

In [4]:
# Build the PyTorch generator around the pretrained template and mask; its
# layer weights are still freshly initialised at this point.
torch_generator = Generator(
    embedding_template=pretrained_embedding_template,
    embedding_mask=pretrained_embedding_mask,
    device=device,
    latent_dim=latent_dim,
    batch_size=batch_size,
    seq_length=seq_length,
    n_classes=n_classes,
    n_samples=n_samples,
)

In [None]:
# Display the generator's layer container (the sequence the weight transfer iterates over).
torch_generator.generator_network.generator_network

In [5]:
def _copy_keras_array(destination, keras_array, label):
    """Copy one Keras weight array into an existing PyTorch tensor in place.

    Args:
        destination: PyTorch parameter or buffer to overwrite.
        keras_array: NumPy array already permuted to PyTorch's axis order.
        label: Text identifying the layer/array, used in the error message.

    Raises:
        ValueError: If the array's shape differs from the destination's shape.
    """
    source = torch.from_numpy(keras_array)
    if source.shape != destination.shape:
        raise ValueError(
            f'{label}: Keras array has shape {tuple(source.shape)} but the '
            f'PyTorch tensor expects {tuple(destination.shape)}'
        )
    # copy_ keeps the destination's dtype/device; no_grad because parameters
    # that require grad cannot be modified in place otherwise.
    with torch.no_grad():
        destination.copy_(source)


def transfer_keras_weights(torch_layers, keras_weights):
    """Load Keras weight arrays into PyTorch layers, in order.

    Walks `torch_layers` and, for every Linear / ConvTranspose2d / Conv2d /
    BatchNorm2d layer, consumes the next arrays of `keras_weights`. Other
    layer types are skipped and consume nothing.

    Args:
        torch_layers: Iterable of PyTorch modules in forward order.
        keras_weights: List of NumPy arrays in Keras `get_weights()` order.

    Returns:
        The number of Keras arrays consumed.

    Raises:
        ValueError: If an array does not match the shape of its target tensor.
        IndexError: If `keras_weights` runs out before the layers do.
    """
    keras_idx = 0
    for layer_idx, layer in enumerate(torch_layers):
        # Each entry is (target tensor, axis permutation applied to the Keras array).
        if isinstance(layer, nn.Linear):
            # Keras stores the dense kernel transposed relative to nn.Linear.
            targets = [(layer.weight, (1, 0)), (layer.bias, None)]
        elif isinstance(layer, (nn.ConvTranspose2d, nn.Conv2d)):
            # Move the two trailing channel axes to the front, reversed.
            # TODO: double check if we need to reverse biases
            targets = [(layer.weight, (3, 2, 0, 1)), (layer.bias, None)]
        elif isinstance(layer, nn.BatchNorm2d):
            # order from keras model should be gamma, beta, moving mean, moving variance
            # NOTE(review): Keras BatchNormalization defaults to epsilon=1e-3
            # while these torch layers print eps=1e-05 -- confirm they match.
            targets = [
                (layer.weight, None),
                (layer.bias, None),
                (layer.running_mean, None),
                (layer.running_var, None),
            ]
        else:
            continue

        for destination, axes in targets:
            keras_array = keras_weights[keras_idx]
            if axes is not None:
                keras_array = np.transpose(keras_array, axes=axes)
            _copy_keras_array(
                destination,
                keras_array,
                f'layer {layer_idx} ({type(layer).__name__}), Keras weight {keras_idx}',
            )
            keras_idx += 1
    return keras_idx


running_keras_weight_idx = transfer_keras_weights(
    torch_generator.generator_network.generator_network, weights
)

In [None]:
# Print the name and shape of every parameter of the PyTorch generator.
for name, param in torch_generator.named_parameters():
    print(name, param.shape)

Now that we've loaded the model, let's test it

In [None]:
# load in predictor and look at predictions
oracle_save_path = '/gpfs/commons/home/tchen/al_project/active-learning-save/saved_metrics/models/base_cnn_oracle.pt'

oracle = OracleCNN().to(device)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved on a GPU still loads on a CPU-only machine.
oracle.load_state_dict(torch.load(oracle_save_path, map_location=device))
# Freeze the oracle: it is only used for prediction.
oracle.requires_grad_(False)
oracle.eval()

In [6]:
# Move the generator to `device`; the cell output is the module's repr.
torch_generator.to(device)

Generator(
  (generator_network): GeneratorNetwork(
    (generator_network): ModuleList(
      (0): Linear(in_features=101, out_features=3456, bias=True)
      (1): ConvTranspose2d(384, 256, kernel_size=(7, 1), stride=(2, 1))
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ConvTranspose2d(256, 192, kernel_size=(8, 1), stride=(2, 1))
      (4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ConvTranspose2d(192, 128, kernel_size=(7, 1), stride=(2, 1))
      (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Conv2d(128, 128, kernel_size=(8, 1), stride=(1, 1), padding=same)
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): Conv2d(128, 64, kernel_size=(8, 1), stride=(1, 1), padding=same)
      (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): Conv2

In [7]:
# Switch to evaluation mode, then call the generator with no arguments.
# NOTE(review): no random seed is set in this notebook, so if the generator
# samples randomly the outputs will differ between runs -- confirm.
torch_generator.eval()
samples = torch_generator()

  self.padding, self.dilation, self.groups)


In [8]:
# First batch element of the last generator output, one row per sequence
# position; use the configured sizes instead of the literals 32 / 109.
samples[-1].reshape(batch_size, seq_length, 4)[0]

tensor([[10., -4., -4., -4.],
        [-4., -4., 10., -4.],
        [-4., -4., 10., -4.],
        [-4., -4., -4., 10.],
        [-4., -4., 10., -4.],
        [-4., 10., -4., -4.],
        [-4., -4., -4., 10.],
        [-4., -4., -4., 10.],
        [-4., -4., 10., -4.],
        [-4., -4., 10., -4.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 

In [9]:
# Display the first generator output (saved below as the generated sequences).
samples[0]

tensor([[[[[1.],
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          ...,

          [[1.],
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          [[1.],
           [0.],
           [0.],
           [0.]]],


         [[[1.],
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          ...,

          [[1.],
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [1.],
           [0.]],

          [[1.],
           [0.],
           [0.],
           [0.]]],


         [[[1.],
           [0.],
           [0.],
           [0.]],

          [[0.],
           [0.],
           [1.],
   

In [None]:
# First element of the second-to-last generator output, one row per sequence
# position; use the configured seq_length instead of the literal 109.
samples[-2][0].reshape(seq_length, 4)

In [None]:
# Shape of the first generator output.
samples[0].shape

In [10]:
# save generated samples
save_path = '/gpfs/commons/groups/knowles_lab/ting/'
save_name = 'pytorch_generated_sequences_target_isoform_00.npy'

np.save(save_path+save_name, samples[0].detach().cpu().numpy(), allow_pickle=True)

In [11]:
# also save pwm
save_path = '/gpfs/commons/groups/knowles_lab/ting/'
save_name = 'pytorch_optimized_pwm_target_isoform_00.npy'

np.save(save_path+save_name, samples[2].detach().cpu().numpy(), allow_pickle=True)

Double check length of data

In [None]:
# Read the tab-separated 5SS dataset file and print every line's fields.
PATH_TO_DIRECTORY = '/gpfs/commons/home/tchen/al_project/active-learning-cnns-gps'
dataset_path = PATH_TO_DIRECTORY + '/old_data/5SS_compressed.txt'
seq_len = 101
# presumably the number of rows in the dataset file -- TODO confirm
n = 265137
# Preallocated buffers; nothing in the visible part of this cell fills them.
inputs = np.zeros((n, seq_len, 4))
prob_s1 = np.zeros(n)

with open(dataset_path) as f:
    # `ind` is initialised but not used in the visible part of this cell.
    ind = 0
    for line in f:
        mod_line = line.split('\t')
        # NOTE(review): this prints once per line of the file, which will
        # flood the cell output for a large dataset.
        print(mod_line)