In [1]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [13]:
from data import *
from data.base_dataset import *
from data.pdbbind_dataset import *
from model import *

In [14]:
# Index of the GPU to train on. NOTE(review): the saved outputs further down
# report ``cuda:1``, so this value was changed after they were produced.
GPU_INDEX = 2

# Only use the requested GPU if it actually exists; a hard-coded ``cuda:2`` on
# a machine with fewer GPUs would otherwise fail later at ``model.to(device)``.
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device("cuda:{}".format(GPU_INDEX))
else:
    device = torch.device("cpu")

In [21]:
# Shared data directory, two levels above this notebook.
path = Path("..", "..", "data")

In [34]:
input_nc = 24
bs = 16

# Transform pipeline: centre, rotate, one Channel(...) per element symbol, then
# tensor conversion. Channel arguments (input_nc, 1.0, 1.4) are passed through
# unchanged -- their meaning is defined in data.base_dataset.
# NOTE(review): `tfms` is not handed to PdbBindDataset below (the dataset is
# configured only through `opt`), and Rotate(90) differs from opt.rotate = 10;
# confirm whether this pipeline is actually used.
tfms = transforms.Compose(
    [Center(), Rotate(90)]
    + [Channel([element], input_nc, 1.0, 1.4) for element in ('C', 'O', 'N')]
    + [ToTensor()]
)
class Option:
    """Plain attribute bag handed to ``PdbBindDataset.initialize``.

    Only class attributes are defined; what each one controls is decided by
    the dataset code, so the grouping comments below are hedged.
    """

    # Input locations (presumably the index CSV and the structure root).
    csvfile = path / 'refined_set.csv'
    dataroot = '../../../2018/refined-set/'

    # Grid featurisation settings -- TODO confirm semantics in PdbBindDataset.
    channels = 'cno'
    grid_size, grid_spacing = 20, 1
    rvdw = 1.4

    # Filtering / augmentation switches.
    # NOTE(review): rotate = 10 here, but the tfms pipeline uses Rotate(90).
    filter_kd = True
    rotate = 10

opt = Option()

ds = PdbBindDataset()
ds.initialize(opt)

# DataLoader is already imported from torch.utils.data in the imports cell.
dl = DataLoader(
    ds,
    batch_size=bs,
    shuffle=True,   # unseeded, so batch order differs between runs
    num_workers=0,  # load batches in the main process
)

In [37]:
# Pull one shuffled batch from `dl` for the inspection cell that follows.
# NOTE(review): `data` is also the name of the local package imported above,
# and it is rebound again as the loop variable in the training cell.
data = next(iter(dl))

In [45]:
# Collapse grids[0, 0] (presumably sample 0, channel 0 -- TODO confirm layout)
# along its last axis to eyeball where density lands in the grid.
data['grids'][0, 0].cpu().numpy().sum(axis=2)

array([[  9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15,
          9.32587341e-15,   9.32587341e-15,   9.32587341e-15],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
      

In [32]:
# NOTE(review): input_nc=6 here, while the config cell sets `input_nc = 24`
# and Option.channels = 'cno'; confirm how many channels the dataset emits.
model = SqueezeNet(input_nc=6).to(device)
model

SqueezeNet(
  (features): Sequential(
    (0): Conv3d(6, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2))
    (1): ReLU(inplace)
    (2): MaxPool3d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv3d(64, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (squeeze_activation): ReLU(inplace)
      (expand1x1): Conv3d(16, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (expand1x1_activation): ReLU(inplace)
      (expand3x3): Conv3d(16, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (expand3x3_activation): ReLU(inplace)
    )
    (4): Fire(
      (squeeze): Conv3d(128, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (squeeze_activation): ReLU(inplace)
      (expand1x1): Conv3d(16, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (expand1x1_activation): ReLU(inplace)
      (expand3x3): Conv3d(16, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (expand3x3_activation): ReLU(inplace)
    )
    (

In [29]:
import torch.optim as optim

# Squared-error regression objective on the predicted affinity.
criterion = nn.MSELoss()
# SGD with momentum; the learning rate is overridden again in the next cell.
optimizer = optim.SGD(params=model.parameters(), momentum=0.9, lr=0.01)

In [57]:
# Lower the learning rate of every parameter group in place on the existing
# optimizer (setting a constant, so re-running this cell is harmless).
for group in optimizer.param_groups:
    group.update(lr=0.0001)

In [None]:
NUM_EPOCHS = 10
LOG_EVERY = 20  # mini-batches between loss reports

# Force training mode in case an earlier (re-)run cell switched the model to
# eval mode; harmless if it is already in training mode.
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(dl):
        grids = data['grids'].to(device)
        affinities = data['affinity'].to(device)
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(grids)
        # Match the target's shape to the prediction so a (B,) target cannot
        # silently broadcast against a (B, 1) prediction inside MSELoss; a
        # genuine size mismatch now raises instead.
        loss = criterion(outputs, affinities.reshape_as(outputs))
        loss.backward()
        optimizer.step()

        # print the mean loss over the last LOG_EVERY mini-batches
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

[1,    20] loss: 42.319
[1,    40] loss: 45.145
[1,    60] loss: 45.195
[1,    80] loss: 38.745
[1,   100] loss: 42.507
[1,   120] loss: 42.891
[2,    20] loss: 45.096
[2,    40] loss: 40.623
[2,    60] loss: 42.159
[2,    80] loss: 40.413
[2,   100] loss: 43.612
[2,   120] loss: 46.234
[3,    20] loss: 45.709
[3,    40] loss: 38.737
[3,    60] loss: 41.660
[3,    80] loss: 49.065
[3,   100] loss: 39.815
[3,   120] loss: 45.227
[4,    20] loss: 43.844
[4,    40] loss: 43.751
[4,    60] loss: 45.393
[4,    80] loss: 42.424
[4,   100] loss: 43.655
[4,   120] loss: 39.146
[5,    20] loss: 40.133
[5,    40] loss: 43.738
[5,    60] loss: 46.688
[5,    80] loss: 44.534
[5,   100] loss: 42.289
[5,   120] loss: 40.027
[6,    20] loss: 42.493
[6,    40] loss: 45.166
[6,    60] loss: 42.097
[6,    80] loss: 43.195
[6,   100] loss: 41.787
[6,   120] loss: 47.753
[7,    20] loss: 46.517
[7,    40] loss: 41.264
[7,    60] loss: 39.224
[7,    80] loss: 44.188
[7,   100] loss: 45.179
[7,   120] loss:

In [55]:
# Predictions for the last mini-batch processed by the training loop.
outputs

tensor([[10.8296],
        [10.4283],
        [ 9.4874],
        [ 9.2714],
        [ 9.8149]], device='cuda:1', grad_fn=<AddmmBackward>)

In [56]:
# Target affinities for the last mini-batch processed by the training loop.
affinities

tensor([[ 8.5685],
        [ 4.7105],
        [ 2.8134],
        [ 5.8091],
        [16.1181]], device='cuda:1')