In [None]:
# default_exp nets

# nets
> The neural network architectures.

Our implementation of the networks will follow this architecture:

![](https://github.com/shenwanxiang/bidd-molmap/blob/master/paper/images/net.png?raw=1)

In [None]:
#export
import torch
from torch import nn
import torch.nn.functional as F

## Feature extraction

When building complex networks it's better to build and test the smaller components first, then combine them together. This way we can also reuse the individual parts easily.

### Convolutional block

This block takes a descriptor or fingerprint map as input, applies a convolution followed by a ReLU, and returns the output of a max pooling layer.

- Descriptor: `13*37*37` -> `48*37*37` -> `48*19*19`
- Fingerprint: `3*37*36` -> `48*37*36` -> `48*19*18`

In [None]:
#export
class Convnet(nn.Module):
    """Convolutional feature extraction block.

    Pipeline: `Conv2d(C_in -> C_out, conv_size x conv_size, stride 1,
    padding='same')` -> `ReLU` -> `MaxPool2d(3, stride=2, padding=1)`.
    The pooling step roughly halves each spatial dimension
    (e.g. 37 -> 19, 36 -> 18).

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        conv_size: kernel size of the convolution.
    """
    def __init__(self, C_in=13, C_out=48, conv_size=13):
        super().__init__()

        conv = nn.Conv2d(C_in, C_out, kernel_size=conv_size, stride=1, padding='same')
        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Kept as one Sequential named `conv` so state_dict keys stay `conv.0.*`.
        self.conv = nn.Sequential(conv, nn.ReLU(), pool)

    def forward(self, x):
        """Return `maxpool(relu(conv(x)))`."""
        return self.conv(x)

Let's test it on random inputs shaped like the descriptor and fingerprint maps.

In [None]:
convnet = Convnet()

i = torch.rand((10, 13, 37, 37))
o = convnet(i)
# Descriptor map: 13*37*37 -> 48*19*19 (checked so nbdev test runs catch regressions).
assert o.shape == (10, 48, 19, 19)
o.shape

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([10, 48, 19, 19])

In [None]:
convnet = Convnet(3, 48)

i = torch.rand((10, 3, 37, 36))
o = convnet(i)
# Fingerprint map: 3*37*36 -> 48*19*18.
assert o.shape == (10, 48, 19, 18)
o.shape

torch.Size([10, 48, 19, 18])

### Inception block

After the convolutional block, the resulting feature maps will further pass through some inception blocks. 

The inception blocks implemented here are naïve Google Inception blocks: each one passes the input through several parallel convolutional layers (kernel sizes 5, 3 and 1) and concatenates their outputs along the channel axis. The full block used in the network is actually two of these smaller inception blocks bridged by a max pooling layer. First, the small inception block:

- Descriptor: `48*19*19` -> 3 outputs of `32*19*19` -> `96*19*19` -> max pool -> `96*10*10` -> 3 outputs of `64*10*10` -> `192*10*10`
- Fingerprint: `48*19*18` -> 3 outputs of `32*19*18` -> `96*19*18` -> max pool -> `96*10*9` -> 3 outputs of `64*10*9` -> `192*10*9`



In [None]:
#export
class Inception(nn.Module):
    """Naive Google Inception block.

    Runs three parallel `Conv2d -> ReLU` branches with kernel sizes 5, 3 and 1
    over the same input and concatenates their outputs along dim 1 (channels),
    so the block emits `3 * C_out` channels.

    Args:
        C_in: number of input channels.
        C_out: number of output channels of *each* branch.
        stride: convolution stride shared by all branches. With `stride=1` the
            spatial size is preserved. With a larger stride every branch yields
            `floor((H - 1) / stride) + 1` rows (likewise for the width), so the
            branch outputs still line up for concatenation.
    """
    def __init__(self, C_in=48, C_out=32, stride=1):
        super().__init__()

        # Explicit symmetric padding of kernel_size // 2 is identical to
        # padding='same' for these odd kernels at stride 1. Unlike 'same',
        # PyTorch also accepts it for strided convolutions, which previously
        # made any stride != 1 raise at construction time.
        def branch(kernel_size):
            return nn.Sequential(
                nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
                          padding=kernel_size // 2),
                nn.ReLU(),
            )

        # Same attribute names and construction order as before, so
        # state_dict keys and seeded weight initialisation are unchanged.
        self.conv1 = branch(5)
        self.conv2 = branch(3)
        self.conv3 = branch(1)

    def forward(self, x):
        """Concatenate the 5x5, 3x3 and 1x1 branch outputs along dim 1."""
        return torch.cat((self.conv1(x), self.conv2(x), self.conv3(x)), dim=1)

In [None]:
inception = Inception()

i = torch.rand((10, 48, 19, 19))
o = inception(i)
# Three branches of 32 channels each -> 96 channels, spatial size preserved.
assert o.shape == (10, 96, 19, 19)
o.shape

torch.Size([10, 96, 19, 19])

In [None]:
inception = Inception(96, 64)

i = torch.rand((10, 96, 10, 10))
o = inception(i)
# Three branches of 64 channels each -> 192 channels.
assert o.shape == (10, 192, 10, 10)
o.shape

torch.Size([10, 192, 10, 10])

In [None]:
inception = Inception()

i = torch.rand((10, 48, 19, 18))
o = inception(i)
# Non-square maps keep their shape too.
assert o.shape == (10, 96, 19, 18)
o.shape

torch.Size([10, 96, 19, 18])

In [None]:
inception = Inception(96, 64)

i = torch.rand((10, 96, 10, 9))
o = inception(i)
assert o.shape == (10, 192, 10, 9)
o.shape

torch.Size([10, 192, 10, 9])


And the double inception block:


In [None]:
#export
class DoubleInception(nn.Module):
    """Two Inception blocks bridged by a max pooling layer.

    Pipeline: `Inception(C_in1, C_out1, stride1)` ->
    `MaxPool2d(3, stride=2, padding=1)` -> `Inception(C_in2, C_out2, stride2)`.
    `C_in2` presumably has to equal `3 * C_out1`, since `Inception`
    concatenates three branches of `C_out1` channels (the defaults 32 -> 96
    satisfy this).

    Args:
        C_in1, C_out1, stride1: arguments of the first Inception block.
        C_in2, C_out2, stride2: arguments of the second Inception block.
    """
    def __init__(self, C_in1=48, C_out1=32, stride1=1, C_in2=96, C_out2=64, stride2=1):
        super().__init__()

        # Registration order (inception1, maxpool, inception2) fixes the
        # state_dict layout and the order in which weights are initialised.
        self.inception1 = Inception(C_in1, C_out1, stride1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception2 = Inception(C_in2, C_out2, stride2)

    def forward(self, x):
        """Return `inception2(maxpool(inception1(x)))`."""
        pooled = self.maxpool(self.inception1(x))
        return self.inception2(pooled)

In [None]:
double_inception = DoubleInception()

i = torch.rand((10, 48, 19, 19))
o = double_inception(i)
# 48*19*19 -> 96*19*19 -> pool -> 96*10*10 -> 192*10*10
assert o.shape == (10, 192, 10, 10)
o.shape

torch.Size([10, 192, 10, 10])

In [None]:
double_inception = DoubleInception()

i = torch.rand((10, 48, 19, 18))
o = double_inception(i)
# 48*19*18 -> 96*19*18 -> pool -> 96*10*9 -> 192*10*9
assert o.shape == (10, 192, 10, 9)
o.shape

torch.Size([10, 192, 10, 9])

### Global max pooling

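PyTorch has no layer named "global max pooling", but it is easy to implement: take the maximum over the two spatial dimensions with `amax`. (`nn.AdaptiveMaxPool2d(1)` followed by flattening gives the same result.)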

In [None]:
i = torch.rand((10, 192, 10, 10))
# Global max pooling: reduce over the last two (spatial) dimensions.
o = i.amax(dim=(-1, -2))
# Same result as PyTorch's adaptive max pooling to a 1x1 map, flattened.
assert torch.equal(o, F.adaptive_max_pool2d(i, 1).flatten(1))
o.shape

torch.Size([10, 192])

### Fully connected block

At the end of the network the data passes through several fully connected layers. 

If the MolMap network is single path:

- `192` -> `128` -> `32`

And if double path:

- `384` -> `256` -> `128` -> `32`

In [None]:
#export
class SinglePathFullyConnected(nn.Module):
    """Fully connected head for single path MolMap nets.

    `Linear(C1, C2) -> ReLU -> Linear(C2, C3)`. No activation follows the
    last Linear layer.
    NOTE(review): if callers append a Linear output layer directly after
    this head, the two Linear layers compose linearly. Confirm this is intended.

    Args:
        C1: input features (192 by default).
        C2: hidden layer width.
        C3: output features.
    """
    def __init__(self, C1=192, C2=128, C3=32):
        super().__init__()

        widths = (C1, C2, C3)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Drop the trailing ReLU; Linear layers end up at indices 0 and 2.
        self.fc = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the fully connected stack to `x`."""
        out = self.fc(x)
        return out

In [None]:
single_path_fully_connected = SinglePathFullyConnected()

i = torch.rand((10, 192))
o = single_path_fully_connected(i)
# 192 -> 128 -> 32
assert o.shape == (10, 32)
o.shape

torch.Size([10, 32])

In [None]:
#export
class DoublePathFullyConnected(nn.Module):
    """Fully connected head for double path MolMap nets.

    `Linear(C1, C2) -> ReLU -> Linear(C2, C3) -> ReLU -> Linear(C3, C4)`.
    No activation follows the last Linear layer.

    Args:
        C1: input features (384 = 192 + 192 by default).
        C2, C3: hidden layer widths.
        C4: output features.
    """
    def __init__(self, C1=384, C2=256, C3=128, C4=32):
        super().__init__()

        widths = (C1, C2, C3, C4)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Drop the trailing ReLU; Linear layers end up at indices 0, 2 and 4.
        self.fc = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the fully connected stack to `x`."""
        out = self.fc(x)
        return out

In [None]:
double_path_fully_connected = DoublePathFullyConnected()

i = torch.rand((10, 384))
o = double_path_fully_connected(i)
# 384 -> 256 -> 128 -> 32
assert o.shape == (10, 32)
o.shape

torch.Size([10, 32])

## Single Path Molecular Mapping network

Descriptor map or fingerprint map only. The two feature maps use identical network structures and differ only in data shape. The number of input channels must be specified when initialising the model, but the spatial size of the map can vary: global max pooling reduces any `H*W` map to a fixed-length vector before the fully connected layers.

- descriptor: `13*37*37` -> `32`
- fingerprint: `3*37*36` -> `32`

The output layer is not included.

In [None]:
#export
class SinglePathMolMapNet(nn.Module):
    """Single Path Molecular Mapping Network.

    `Convnet -> DoubleInception -> global max pool -> SinglePathFullyConnected`
    for one feature map (descriptor or fingerprint). The output layer is not
    included. For a 4-D `(N, C, H, W)` input, `forward` returns an
    `(N, FC[1])` tensor.

    Args:
        conv_in: number of channels of the input map.
        conv_size: kernel size of the first convolution.
        FC: widths of the two fully connected layers. The default is a tuple
            rather than a list to avoid a shared mutable default argument;
            any indexable sequence of at least two ints works.
    """
    def __init__(self, conv_in=13, conv_size=13, FC=(128, 32)):
        super().__init__()

        # output channels in the double inception
        C_out1, C_out2 = 32, 64

        self.conv = Convnet(C_in=conv_in, C_out=48, conv_size=conv_size)
        # Each Inception concatenates 3 branches, hence the `* 3` channel counts.
        self.double_inception = DoubleInception(C_in1=48, C_out1=C_out1, C_in2=C_out1*3, C_out2=C_out2)
        self.fully_connected = SinglePathFullyConnected(C1=C_out2*3, C2=FC[0], C3=FC[1])

    def forward(self, x):
        """Extract features from `x` and return the fully connected output."""
        x = self.conv(x)
        x = self.double_inception(x)
        # Global max pooling over the last two (spatial) dimensions makes the
        # head independent of the input's height and width.
        x = x.amax(dim=(-1, -2))
        return self.fully_connected(x)

In [None]:
single_path = SinglePathMolMapNet()

i = torch.rand((10, 13, 37, 37))
o = single_path(i)
# Descriptor map: 13*37*37 -> 32
assert o.shape == (10, 32)
o.shape

torch.Size([10, 32])

In [None]:
single_path = SinglePathMolMapNet(conv_in=3)

i = torch.rand((10, 3, 37, 36))
o = single_path(i)
# Fingerprint map: 3*37*36 -> 32
assert o.shape == (10, 32)
o.shape

torch.Size([10, 32])

## Double Path Molecular Mapping network

Both the descriptor map and the fingerprint map pass through a convolutional block and then a double inception block. Their results are then globally max-pooled and concatenated before finally passing through the fully connected layers.

After convolutional and double inception block:

- descriptor: `13*37*37` -> `192*10*10`
- fingerprint: `3*37*36` -> `192*10*9`

After global max pooling:

- descriptor: `192*10*10` -> `192`
- fingerprint: `192*10*9` -> `192`

After Concatenation and fully connected blocks:

- `192 + 192` -> `384` -> `256` -> `128` -> `32`

The output layer is not included.

In [None]:
#export
class DoublePathMolMapNet(nn.Module):
    """Double Path Molecular Mapping Network.

    Each input map gets its own `Convnet -> DoubleInception -> global max pool`
    feature extractor. The two pooled vectors are concatenated along dim 1
    and passed through `DoublePathFullyConnected`. The output layer is not
    included.

    Args:
        conv_in1: channels of the first input map (13 by default, the
            descriptor map in this notebook's examples).
        conv_in2: channels of the second input map (3 by default, the
            fingerprint map in this notebook's examples).
        conv_size: kernel size of the first convolution in both paths.
        FC: widths of the three fully connected layers. A tuple by default
            to avoid a shared mutable default argument.
        share_inception: if True, both paths reuse one DoubleInception. This
            is the original weight-sharing layout, and its state_dict matches
            checkpoints saved with it. By default (False) each path has its
            own DoubleInception, just as each path already has its own Convnet.
    """
    def __init__(self, conv_in1=13, conv_in2=3, conv_size=13, FC=(256, 128, 32), share_inception=False):
        super().__init__()

        # output channels in the double inception
        C_out1, C_out2 = 32, 64

        def make_double_inception():
            # Each Inception concatenates 3 branches, hence the `* 3`.
            return DoubleInception(C_in1=48, C_out1=C_out1, C_in2=C_out1*3, C_out2=C_out2)

        self.conv1 = Convnet(C_in=conv_in1, C_out=48, conv_size=conv_size)
        self.conv2 = Convnet(C_in=conv_in2, C_out=48, conv_size=conv_size)
        # Path 1 keeps the attribute name `double_inception`, so its
        # state_dict keys are unchanged.
        self.double_inception = make_double_inception()
        # When shared, store None (not a registered submodule) so the
        # state_dict has exactly the original keys.
        self.double_inception2 = None if share_inception else make_double_inception()
        # Each path contributes 3 * C_out2 pooled features, hence C_out2 * 6.
        self.fully_connected = DoublePathFullyConnected(C1=C_out2*6, C2=FC[0], C3=FC[1], C4=FC[2])

    def _extract(self, x, conv, double_inception):
        """Conv block -> double inception -> global max pool over the last two dims."""
        return double_inception(conv(x)).amax(dim=(-1, -2))

    def forward(self, x1, x2):
        """Encode both maps, concatenate the pooled features, apply the FC head."""
        inception2 = self.double_inception if self.double_inception2 is None else self.double_inception2
        x1 = self._extract(x1, self.conv1, self.double_inception)
        x2 = self._extract(x2, self.conv2, inception2)
        return self.fully_connected(torch.cat((x1, x2), dim=1))

In [None]:
double_path = DoublePathMolMapNet()

i1 = torch.rand((10, 13, 37, 37))
i2 = torch.rand((10, 3, 37, 36))
o = double_path(i1, i2)
# (192 + 192) -> 256 -> 128 -> 32
assert o.shape == (10, 32)
o.shape

torch.Size([10, 32])

## Resnet block

Currently not used

In [None]:
#export
class Resnet(nn.Module):
    """Residual block: two same-padded convolutions with an identity shortcut.

    Computes `relu(BN(conv(relu(BN(conv(x))))) + x)`. Both convolutions keep
    `C` channels and use stride 1 with `padding='same'`, so the output has the
    same shape as the input, which the identity shortcut requires.

    Args:
        C: number of input and output channels.
        conv_size: kernel size of both convolutions.
    """
    def __init__(self, C, conv_size):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=conv_size, stride=1, padding='same'),
            nn.BatchNorm2d(C),
            nn.ReLU()
        )

        # No activation here: the ReLU is applied after adding the shortcut.
        self.conv2 = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=conv_size, stride=1, padding='same'),
            nn.BatchNorm2d(C)
        )

    def forward(self, x):
        """Return `relu(conv2(conv1(x)) + x)`."""
        residual = self.conv2(self.conv1(x))
        return F.relu(residual + x)

In [None]:
resnet = Resnet(48, 5)

i = torch.rand((10, 48, 19, 18))
o = resnet(i)
# A residual block must preserve the input shape.
assert o.shape == i.shape
o.shape

torch.Size([10, 48, 19, 18])