In [1]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

In [11]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [4]:
from data import *

In [6]:
# Project data directory, given relative to the notebook's working directory.
path = Path('..', 'data')

In [7]:
# Untransformed dataset (no `transform=`); `sample` holds its first item so
# the raw record format can be inspected interactively.
pdbbind_dataset = PdbBindDataset(
    csvfile=path / 'refined_set.csv',
    rootdir='../../2018/refined-set/',
    filter_kd=True,
)
sample = pdbbind_dataset[0]

In [9]:
# Transform pipeline: Center -> Rotate(90) -> one Channel per element symbol
# (C, O, N, in that order) -> ToTensor. The positional arguments 20, 1.0, 1.4
# are passed unchanged to every Channel; see `data.Channel` for their meaning.
tfms = transforms.Compose(
    [Center(), Rotate(90)]
    + [Channel([element], 20, 1.0, 1.4) for element in ('C', 'O', 'N')]
    + [ToTensor()]
)
# Same source data as `pdbbind_dataset`, but every item goes through `tfms`.
ds = PdbBindDataset(
    csvfile=path / 'refined_set.csv',
    rootdir='../../2018/refined-set/',
    filter_kd=True,
    transform=tfms,
)

In [134]:
class Fire(nn.Module):
    """3-D SqueezeNet "Fire" module.

    A 1x1x1 "squeeze" convolution (+ReLU) reduces the input to
    ``squeeze_planes`` channels. Two "expand" convolutions (+ReLU) are then
    applied to the squeezed tensor in parallel: a 1x1x1 one and a 3x3x3 one
    padded by 1. Their outputs are concatenated along dim 1, the channel
    axis for batched ``(N, C, D, H, W)`` input.

    Both expand branches use stride 1, and the 3x3x3 branch is padded by 1.
    The output therefore keeps the input's spatial size and has
    ``expand1x1_planes + expand3x3_planes`` channels.

    Args:
        inplanes: number of input channels.
        squeeze_planes: channels after the squeeze convolution.
        expand1x1_planes: output channels of the 1x1x1 expand branch.
        expand3x3_planes: output channels of the 3x3x3 expand branch.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super().__init__()
        # Stored for introspection only; forward() does not read it.
        self.inplanes = inplanes
        # Submodules are registered in this exact order. That keeps parameter
        # initialisation (RNG consumption) and state_dict keys stable.
        self.squeeze = nn.Conv3d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv3d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv3d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        # The in-place ReLUs act on each branch's own conv output, so
        # `squeezed` is not modified and can feed both branches.
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat((branch_1x1, branch_3x3), dim=1)

class Flatten(nn.Module):
    """Collapse every dimension after the first into one.

    Maps ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)``; a 1-D input ``(N,)``
    becomes ``(N, 1)``.
    """

    def forward(self, input):
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after transpose/permute or other memory formats). reshape
        # returns a view when possible and copies only when it must.
        return input.reshape(input.size(0), -1)
    
class SqueezeNet(nn.Module):
    """3-D SqueezeNet-style network producing one scalar per sample.

    ``features``:
        - a 7x7x7 stride-2 stem convolution + ReLU and a max-pool;
        - eight ``Fire`` modules separated by two more max-pools, ending
          with 512 channels;
        - a global adaptive average pool down to 1x1x1.

    ``head``: flattens the 512 pooled features and regresses a single value
    via dropout -> linear -> ReLU -> batch-norm -> dropout -> linear.

    Args:
        input_nc: number of channels in the input grid.

    Expects a batched 5-D input ``(N, input_nc, D, H, W)`` and returns
    ``(N, 1)``. In training mode ``BatchNorm1d`` requires N > 1.
    """

    def __init__(self, input_nc):
        super().__init__()
        output_nc = 64
        features = [nn.Conv3d(input_nc, output_nc, kernel_size=7, stride=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool3d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(output_nc, 16, 64, 64),
                    Fire(128, 16, 64, 64),
                    nn.MaxPool3d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(128, 32, 128, 128),
                    Fire(256, 32, 128, 128),
                    nn.MaxPool3d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(256, 48, 192, 192),
                    Fire(384, 48, 192, 192),
                    Fire(384, 64, 256, 256),
                    Fire(512, 64, 256, 256),
                    # Without this pool, Linear(512, ...) below only works when
                    # the feature map happens to shrink to 1x1x1 (e.g. 24^3
                    # inputs); larger grids such as 48^3 give 2x2x2 and crash.
                    # For 1x1x1 maps the pool is an exact identity. It is
                    # parameter-free and appended last, so existing
                    # state_dict keys keep their indices.
                    nn.AdaptiveAvgPool3d(1)]

        head = [Flatten(),
                nn.Dropout(p=0.5),
                nn.Linear(512, 128),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(128),
                nn.Dropout(p=0.5),
                nn.Linear(128, 1)
                ]

        self.features = nn.Sequential(*features)
        self.head = nn.Sequential(*head)

    def forward(self, x):
        """Run the feature extractor, then the regression head: (N, C, D, H, W) -> (N, 1)."""
        x = self.features(x)
        x = self.head(x)
        return x

In [135]:
# 6 input channels. NOTE(review): this must equal the number of channels the
# `tfms` pipeline actually produces; presumably the three Channel transforms
# yield 6 in total. Confirm against `data.Channel`.
model = SqueezeNet(6)

In [136]:
from torchsummary import summary
# torchsummary builds its dummy input on CUDA whenever CUDA is available and
# `device` is left at its default ("cuda"). The model is created on the CPU,
# so pass the model's actual device to avoid a CPU/GPU tensor-type mismatch.
# The channel count is read from the stem convolution rather than repeating
# the hard-coded 6.
summary(model,
        input_size=(model.features[0].in_channels, 24, 24, 24),
        device=next(model.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1          [-1, 64, 9, 9, 9]         131,776
              ReLU-2          [-1, 64, 9, 9, 9]               0
         MaxPool3d-3          [-1, 64, 4, 4, 4]               0
            Conv3d-4          [-1, 16, 4, 4, 4]           1,040
              ReLU-5          [-1, 16, 4, 4, 4]               0
            Conv3d-6          [-1, 64, 4, 4, 4]           1,088
              ReLU-7          [-1, 64, 4, 4, 4]               0
            Conv3d-8          [-1, 64, 4, 4, 4]          27,712
              ReLU-9          [-1, 64, 4, 4, 4]               0
             Fire-10         [-1, 128, 4, 4, 4]               0
           Conv3d-11          [-1, 16, 4, 4, 4]           2,064
             ReLU-12          [-1, 16, 4, 4, 4]               0
           Conv3d-13          [-1, 64, 4, 4, 4]           1,088
             ReLU-14          [-1, 64, 