In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pdb

In [34]:
def conv3x3(ni, no, ks=3, s=2, p=1):
    '''
    Build a 2-D convolution layer (3x3, stride 2, padding 1 by default).

    ni: number of input channels
    no: number of output channels
    ks: kernel size (previously ignored in favour of a hard-coded 3)
    s:  stride
    p:  zero padding on each side

    Returns an nn.Conv2d. With the default ks/s/p the spatial size of an
    even-sized input is reduced by a factor of 2.
    '''
    # Callers that change ks should pick a matching p themselves; p is not
    # derived from ks.
    return nn.Conv2d(ni, no, kernel_size=ks, stride=s, padding=p)

def conv1x1(ni, no, ks=1, s=1, p=0):
    '''
    Build a 2-D convolution layer (1x1, stride 1, no padding by default).

    ni: number of input channels
    no: number of output channels
    ks, s, p: kernel size, stride and padding forwarded to nn.Conv2d
    '''
    conv_kwargs = dict(kernel_size=ks, stride=s, padding=p)
    return nn.Conv2d(ni, no, **conv_kwargs)

class ConvBlock(nn.Module):
    '''
    ConvBlock: Conv -> ReLU -> BatchNorm

    ni: number of input channels
    no: number of output channels (also the BatchNorm feature count)

    The convolution is built by conv3x3 with its default arguments.
    '''
    def __init__(self, ni, no):
        super().__init__()
        # Registration order (conv, then bn) is kept: it fixes the order of
        # sub-modules in repr() and in the state dict.
        self.conv = conv3x3(ni, no)
        self.bn = nn.BatchNorm2d(num_features=no)

    def forward(self, x):
        # ReLU runs in place on the fresh convolution output, then the
        # result is batch-normalised.
        activated = F.relu_(self.conv(x))
        return self.bn(activated)
        
class Net(nn.Module):
    '''
    Small CNN: 5x5 stem conv -> stack of ConvBlocks -> global average pool
    -> linear head -> sigmoid.

    layers: channel widths, e.g. [16, 32, 64]. layers[0] is the width
            produced by the stem conv; each consecutive pair
            (layers[i], layers[i+1]) becomes one ConvBlock.
    c:      number of output units, e.g. 16.

    forward() returns a (bs, c) tensor of values in [0, 1]; every output
    unit gets its own sigmoid.
    '''
    def __init__(self, layers, c):
        super().__init__()
        # The stem must emit layers[0] channels so that it feeds the first
        # ConvBlock (or the linear head when there are no blocks). This used
        # to be hard-coded to 16, which broke any other layers[0].
        self.conv1 = nn.Conv2d(3, layers[0], kernel_size=5, stride=1, padding=2)
        self.layers = nn.ModuleList([ConvBlock(layers[i], layers[i+1]) for i in range(len(layers) - 1)])
        self.linear = nn.Linear(layers[-1], c)

    def forward(self, x):
        # x: bs x 3 x H x W (e.g. 64 x 64); the stem keeps H and W
        x = self.conv1(x)
        # x: bs x layers[0] x H x W
        for layer in self.layers:
            x = layer(x)
        # x: bs x layers[-1] x h x w -> global average pool to 1 x 1
        x = F.adaptive_avg_pool2d(x, 1)
        # x: bs x layers[-1] x 1 x 1 -> drop the trailing 1 x 1 dims
        x = x.view(x.size(0), -1)
        # x: bs x layers[-1] -> bs x c
        x = self.linear(x)

        return torch.sigmoid(x)

In [35]:
# Channel widths 16 -> 32 -> 64 (two ConvBlocks) and 16 output units.
net = Net(layers=[16, 32, 64], c=16)

In [36]:
# Display the module tree to check the channel wiring of each layer.
net

Net(
  (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (layers): ModuleList(
    (0): ConvBlock(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ConvBlock(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (linear): Linear(in_features=64, out_features=16, bias=True)
)

In [28]:
import torch

In [36]:
# Try BCELoss on random data shaped like a batch of 10 with 16 outputs each.
criterion = torch.nn.BCELoss()

# Random normal values squashed by a sigmoid so they are valid probabilities.
out = torch.sigmoid(torch.randn(10, 16))
# Random 0/1 targets with the same shape and dtype as `out`.
target = torch.randint_like(out, low=0, high=2)

In [37]:
# BCELoss takes probabilities (here the sigmoid outputs) and targets of the
# same shape; with the default reduction it averages over every element.
loss = criterion(out, target)

In [38]:
# Show the scalar loss. No seed is set in the visible cells, so the value
# changes from run to run.
loss

tensor(0.8433)

In [None]:
# NOTE(review): unfinished, never-executed cell. As written it rebinds
# `criterion` to the `torch` module, replacing the BCELoss instance created
# in an earlier cell. Complete the expression or delete this cell.
criterion = torch