In [1]:
import torch
import torch.nn as nn

import pickle
import os
import sys
import glob
from collections import OrderedDict

import numpy as np

sys.path.append('../../basenji/')

from basenji import params


In [2]:
# Directory under which the manuscript parameter file and the pickled numpy
# weights are looked up by the following cells. The original hard-coded
# location remains the default; set the BASENJI_ROOT environment variable to
# run the notebook against a different location without editing this cell.
root_dir = os.environ.get('BASENJI_ROOT', '/n/local/basenji')

In [3]:
# Path of the manuscript hyper-parameter file under the Basenji root.
params_file = os.path.join(root_dir, 'manuscript/params.txt')

# Parse the file into the `job` mapping that drives model construction below.
# NOTE(review): `require` presumably lists keys that must be present in the
# file -- confirm against basenji.params.read_job_params.
job = params.read_job_params(
    params_file,
    require=['seq_length', 'num_targets'],
)

{'link': 'exp_linear', 'batch_buffer': 4096, 'adam_beta1': 0.97, 'cnn_dilation': [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 1], 'adam_beta2': 0.98, 'cnn_pool': [1, 2, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1], 'loss': 'poisson', 'batch_size': 2, 'num_targets': 4229, 'cnn_dense': [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0], 'seq_length': 131072, 'cnn_filters': [312, 368, 435, 514, 607, 717, 108, 108, 108, 108, 108, 108, 1365], 'cnn_dropout': [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.05], 'cnn_filter_sizes': [22, 1, 6, 6, 6, 3, 3, 3, 3, 3, 3, 3, 1], 'learning_rate': 0.002063}


In [4]:
# Load the pickled weights; later cells index `check_weights` by names such as
# 'cnn0/conv1d/kernel:0'.
# NOTE: pickle.load can execute arbitrary code, so only point this at a file
# you produced yourself or otherwise trust.
weights_path = os.path.join(root_dir, 'manuscript/model_numpy.pkl')
with open(weights_path, 'rb') as f:
    check_weights = pickle.load(f)

In [5]:
class BasenjiConvBlock(nn.Module):
    """One convolutional block: Conv1d -> batch norm -> max pool -> activation -> dropout.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: forwarded unchanged to ``nn.Conv1d``.
        dense_conv: if True, ``forward`` returns the block input concatenated
            with the block output along the channel axis (dim=1).
        use_batchnorm: if True, apply ``nn.BatchNorm1d(out_channels)`` after
            the convolution; otherwise that stage is an identity.
        pooling_type: only ``'max'`` is supported.
        pooling_kernel, pooling_stride: forwarded to ``nn.MaxPool1d``.
        activation_type: ``'relu'``, ``'exp_linear'`` or ``'none'``.
        dropout_prob: channel-wise dropout probability; values <= 0 disable
            dropout.

    Raises:
        ValueError: if ``pooling_type`` or ``activation_type`` is not one of
            the supported values.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 dense_conv=False, padding_mode='zeros', use_batchnorm=True,
                 pooling_type='max', pooling_kernel=1, pooling_stride=1,
                 activation_type='relu', dropout_prob=0.5):
        super(BasenjiConvBlock, self).__init__()
        # Reject bad options before building anything. Raising (instead of
        # exit()/print) keeps the kernel alive and guarantees that
        # `self.activation` can never be left undefined.
        if pooling_type != 'max':
            raise ValueError('Unrecognized pooling type "{}"'.format(pooling_type))
        if activation_type not in ('relu', 'exp_linear', 'none'):
            raise ValueError('Unrecognized activation type "{}"'.format(activation_type))

        # Submodules are registered in a fixed order (conv, batch norm,
        # pooling, activation, dropout) so repr() and state_dict key order
        # stay stable.
        # Set 1D convolution
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding,
                                dilation=dilation, groups=groups, bias=bias,
                                padding_mode=padding_mode)
        # Set Batch Norm layer
        if use_batchnorm:
            self.batch_normalization = nn.BatchNorm1d(out_channels, eps=1e-05,
                                                      momentum=0.1, affine=True,
                                                      track_running_stats=True)
        else:
            self.batch_normalization = Passthrough()
        # Add pooling (pooling_type was validated above)
        self.pooling = nn.MaxPool1d(pooling_kernel, pooling_stride)
        # Add nonlinearity (activation_type was validated above)
        if activation_type == 'relu':
            self.activation = nn.ReLU()
        elif activation_type == 'exp_linear':
            self.activation = ExpLinear()
        else:
            self.activation = Passthrough()
        # Add dropout
        if dropout_prob > 0.0:
            self.dropout = Dropout1d(p=dropout_prob)
        else:
            self.dropout = Passthrough()
        # Decide if dense conv
        self.dense_conv = dense_conv

    def forward(self, input):
        """Run the block on ``input`` and return the resulting tensor."""
        hook = self.conv1d(input)
        hook = self.batch_normalization(hook)
        hook = self.pooling(hook)
        hook = self.activation(hook)
        hook = self.dropout(hook)
        if self.dense_conv:
            # torch.cat needs every non-channel dimension to match, so a dense
            # block only works when conv + pooling preserve the sequence length
            # (i.e. the caller picked padding / pooling accordingly).
            hook = torch.cat([input, hook], dim=1)
        return hook
    
class Dropout1d(nn.Module):
    """Channel-wise dropout for 3-D tensors, implemented on top of ``nn.Dropout2d``.

    The input gets a trailing singleton axis so that ``nn.Dropout2d`` sees a
    4-D tensor and drops whole channels; the extra axis is removed again
    before returning.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.dropout = nn.Dropout2d(p)

    def forward(self, input):
        """Apply dropout to ``input`` and return a tensor of the same shape."""
        expanded = input.unsqueeze(-1)
        dropped = self.dropout(expanded)
        return dropped.squeeze(3)

class ExpLinear(nn.Module):
    """Activation computing ``max(x, 0) + exp(clamp(x, -50, 0))``.

    For positive inputs this evaluates to ``x + 1``; for non-positive inputs
    it evaluates to ``exp(x)``, with ``x`` floored at -50 before the
    exponential.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Apply the activation element-wise to ``input``."""
        linear_part = torch.clamp(input, min=0.)
        exp_part = torch.exp(torch.clamp(input, min=-50., max=0.))
        return linear_part + exp_part

class Passthrough(nn.Module):
    """Identity module: ``forward`` hands back its argument untouched.

    Used as a stand-in wherever an optional stage (batch norm, activation,
    dropout) is switched off. No ``__init__`` is defined because the inherited
    ``nn.Module`` constructor is all that is needed.
    """

    def forward(self, input):
        """Return ``input`` as-is (the same object, no copy)."""
        return input


In [6]:
class SeqNN(nn.Module):
    """Sequential stack of ``BasenjiConvBlock`` layers built from a job dict.

    Args:
        job: mapping providing the per-layer lists ``cnn_filter_sizes``,
            ``cnn_filters``, ``cnn_dilation``, ``cnn_dropout``, ``cnn_pool``
            and ``cnn_dense`` (one entry per convolutional layer), plus the
            integer ``num_targets`` used as the channel count of the final
            1x1 block.

    The blocks are registered as ``cnn0``, ``cnn1``, ... followed by
    ``final`` inside ``self.network``.
    """
    def __init__(self, job):
        super(SeqNN, self).__init__()
        network = []
        # Channel count of the tensor entering the next block. The first block
        # takes 4 input channels (presumably one-hot DNA -- confirm).
        input_feat = 4
        for i in range(len(job['cnn_filter_sizes'])):
            curr_feat = job['cnn_filters'][i]
            dilation  = job['cnn_dilation'][i]
            kern_size = job['cnn_filter_sizes'][i]
            drop_prob = job['cnn_dropout'][i]
            pooling   = job['cnn_pool'][i]
            is_dense  = job['cnn_dense'][i] == 1
            # NOTE(review): padding=0 means the convolution shortens the
            # sequence whenever kern_size > 1, while dense blocks concatenate
            # their input with their output and therefore need equal lengths.
            # Confirm the intended padding scheme before running forward().
            network.append(
                ('cnn{}'.format(i), BasenjiConvBlock(input_feat, curr_feat, kern_size,
                                                      stride=1, padding=0, dilation=dilation,
                                                      groups=1, bias=False, dense_conv=is_dense,
                                                      pooling_kernel=pooling, pooling_stride=pooling,
                                                      activation_type='relu', dropout_prob=drop_prob
                                                     ) )
            )
            # A dense block emits cat([input, output]) along channels, so its
            # output width is input_feat + curr_feat; a regular block emits
            # just curr_feat. Tracking the real output width (rather than a
            # running "density" total) stays correct when a non-dense layer
            # follows dense ones, when layer 0 is dense, and for the final
            # block below.
            if is_dense:
                input_feat = input_feat + curr_feat
            else:
                input_feat = curr_feat

        # Final 1x1 block mapping to the targets: no batch norm, no dropout,
        # exp_linear activation. It consumes whatever the last block emits.
        network.append(
            ('final', BasenjiConvBlock(input_feat, job['num_targets'], 1, stride=1,
                                       padding=0, dilation=1, groups=1, bias=True, dense_conv=False,
                                       use_batchnorm=False, pooling_kernel=1, pooling_stride=1,
                                       activation_type='exp_linear', dropout_prob=0.0
                                      ))
        )
        self.network = nn.Sequential( OrderedDict(network) )

    def forward(self, input):
        """Run ``input`` through every block and return the final block's output."""
        # The result must be returned; otherwise calling the model yields None.
        return self.network(input)



In [7]:
# Instantiate the PyTorch network from the parsed job parameters. Weights are
# still at their default initialisation at this point.
Basenji = SeqNN(job=job)

In [8]:
# Display the assembled nn.Sequential so the per-layer channel counts, kernel
# sizes and pooling settings can be checked by eye.
Basenji.network

Sequential(
  (cnn0): BasenjiConvBlock(
    (conv1d): Conv1d(4, 312, kernel_size=(22,), stride=(1,), bias=False)
    (batch_normalization): BatchNorm1d(312, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pooling): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
    (activation): ReLU()
    (dropout): Dropout1d(
      (dropout): Dropout2d(p=0.05)
    )
  )
  (cnn1): BasenjiConvBlock(
    (conv1d): Conv1d(312, 368, kernel_size=(1,), stride=(1,), bias=False)
    (batch_normalization): BatchNorm1d(368, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pooling): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (activation): ReLU()
    (dropout): Dropout1d(
      (dropout): Dropout2d(p=0.05)
    )
  )
  (cnn2): BasenjiConvBlock(
    (conv1d): Conv1d(368, 435, kernel_size=(6,), stride=(1,), bias=False)
    (batch_normalization): BatchNorm1d(435, eps=1e-05, momentum=0.1, affine=True, track_running

In [9]:
# Value written into each BatchNorm1d `num_batches_tracked` buffer.
# NOTE(review): looks like an arbitrary non-zero placeholder; with a fixed
# momentum the counter should not influence the running statistics -- confirm.
BN_BATCHES_TRACKED = 10000


def _copy_array_into(target, array, label):
    """Overwrite `target` in place with `array`, refusing mismatched shapes.

    Args:
        target: existing parameter or buffer tensor of the model.
        array: numpy array or tensor holding the new values.
        label: name used in the error message (the checkpoint key).

    Raises:
        ValueError: if the shapes of `array` and `target` differ. Assigning
            through `.data` would accept any shape silently and only fail
            later, inside forward().
    """
    source = torch.as_tensor(array, dtype=target.dtype)
    if tuple(source.shape) != tuple(target.shape):
        raise ValueError('Shape mismatch for "{}": checkpoint {} vs model {}'.format(
            label, tuple(source.shape), tuple(target.shape)))
    # copy_ writes into the existing storage, so the parameter keeps its
    # device and contiguous layout even when `source` is a permuted view.
    with torch.no_grad():
        target.copy_(source)


# One checkpoint entry group per convolutional block; derive the count from
# the job parameters (the same list SeqNN iterates over) instead of
# hard-coding 13.
for i in range(len(job['cnn_filter_sizes'])):
    prefix = 'cnn{}'.format(i)
    block = getattr(Basenji.network, prefix)

    # nn.Conv1d stores weights as (out_channels, in_channels, width); the
    # permute(2, 1, 0) implies the checkpoint stores (width, in, out).
    kernel_key = prefix + '/conv1d/kernel:0'
    kernel = torch.as_tensor(check_weights[kernel_key]).permute(2, 1, 0)
    _copy_array_into(block.conv1d.weight, kernel, kernel_key)

    batch_norm = block.batch_normalization
    bn_targets = (
        ('gamma', batch_norm.weight),
        ('beta', batch_norm.bias),
        ('moving_mean', batch_norm.running_mean),
        ('moving_variance', batch_norm.running_var),
    )
    for suffix, target in bn_targets:
        key = '{}/batch_normalization/{}:0'.format(prefix, suffix)
        _copy_array_into(target, check_weights[key], key)
    batch_norm.num_batches_tracked.fill_(BN_BATCHES_TRACKED)

# The final layer is stored as a dense kernel; permute(1, 0) + unsqueeze(2)
# turns its 2-D matrix into the (out, in, 1) weight of a width-1 Conv1d.
final_conv = getattr(Basenji.network, 'final').conv1d
dense_kernel = torch.as_tensor(check_weights['final/dense/kernel:0']).permute(1, 0).unsqueeze(2)
_copy_array_into(final_conv.weight, dense_kernel, 'final/dense/kernel:0')
_copy_array_into(final_conv.bias, check_weights['final/dense/bias:0'], 'final/dense/bias:0')