In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

In [2]:
# code in this cell mostly from torchvision/models/resnet.py

def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                 dilation=dilation, padding=dilation, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class AtrousBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1,  downsample=None, dilation=1):
        super(AtrousBottleneck, self).__init__()
        self.dilation = dilation
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride, dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        
        

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
    



### ASPP

In [3]:
class ASPP(nn.Module):
    def __init__(self):
        super(ASPP, self).__init__()
        self.conv1 = conv1x1(2048, 256)
        self.conv2 = conv3x3(2048, 256, dilation=12)
        self.conv3 = conv3x3(2048, 256, dilation=24)
        self.conv4 = conv3x3(2048, 256, dilation=36)
        
        # Operations for last feature map
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = conv1x1(2048, 256)
        self.bn = nn.BatchNorm2d(256)
    
    def forward(self, x):
        # x is feature map
        out1 = F.relu(self.bn(self.conv1(x)))
        out2 = F.relu(self.bn(self.conv2(x)))
        out3 = F.relu(self.bn(self.conv3(x)))
        out4 = F.relu(self.bn(self.conv4(x)))

        out5 = F.relu(self.bn(self.conv(self.gap(x))))
        out5 = F.interpolate(out5, size=x.shape[-2:], mode="bilinear")                
                       
        out = torch.cat([out1,out2,out3,out4, out5])
        return out
        

In [4]:
class Multitask(nn.Module):
    def __init__(self):
        super(Multitask, self).__init__()
        rn101 = models.resnet101()
        self.truncated_rn101 = nn.Sequential(*rn101.children())[:-4]
        
        self.inplanes = 512
        print(self.inplanes)
        
        # TODO: replace these with atrous layers
        # Dilation choices of 2 and 4
        self.layer3 = self._make_layer(AtrousBottleneck, 256, 23, stride=1, dilation=2)
        self.layer4 = self._make_layer(AtrousBottleneck, 512, 3, stride=1, dilation=4)
        self.aspp = ASPP()
        
    # from torchvision.models.resnet.ResNet
    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        
        # If a stride=2 is passed to the block, input doesn't match the output
        # We need to downsample so we can add them together
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
        
    def forward(self, x):
        x = self.truncated_rn101(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.aspp(x)
        return x

### Shape test

In [5]:
model = Multitask()

512


In [6]:
model

Multitask(
  (truncated_rn101): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv2d(64, 

In [7]:
test = torch.zeros(size=(2,3,256,512))

In [8]:
result = model.forward(test)

  "See the documentation of nn.Upsample for details.".format(mode))


In [9]:
result.shape

torch.Size([10, 256, 32, 64])

### CIFAR10
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

In [10]:
class CIFAR(nn.Module):
    def __init__(self):
        super(CIFAR, self).__init__()
        self.Multitask = Multitask()
        self.Pool = nn.AdaptiveAvgPool2d((1,1))
        self.Linear = nn.Linear(in_features=512, out_features=10)
   
    def forward(self, x):
        out = self.Multitask(x)
        out = self.Pool(out)
        out = out.view(-1, 512)
        out = self.Linear(out)
        return out
        

In [11]:
CIFAR10model = CIFAR()
#CIFAR10model = nn.Sequential(Multitask(), nn.AdaptiveAvgPool2d((1, 1)), nn.Linear(in_features=2048, out_features=10))
#truncatedCIFAR10model = nn.Sequential(Multitask(), nn.AdaptiveAvgPool2d((1, 1)), nn.View((-1, 512))

512


TODO: check value of in_features for Linear layer

In [16]:
CIFAR10test = torch.zeros(size=(2,3,256,256))

In [17]:
CIFAR10result = CIFAR10model.forward(CIFAR10test)
#truncatedCIFAR10result = truncatedCIFAR10model.forward(CIFAR10test)

In [18]:
CIFAR10result.shape
#truncatedCIFAR10result.shape

torch.Size([5, 10])

In [19]:
CIFAR10model

CIFAR(
  (Multitask): Multitask(
    (truncated_rn101): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          

In [20]:
import torchvision
import torchvision.transforms as transforms

# TODO: fix image resizing
transform = transforms.Compose(
    [transforms.Resize((256,256)), transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [21]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(CIFAR10model.parameters(), lr=0.001, momentum=0.9)

In [22]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = CIFAR10model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        #if i % 20 == 19:    # print every 2000 mini-batches
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, i + 1, running_loss))
        running_loss = 0.0

print('Finished Training')

ValueError: Expected input batch_size (10) to match target batch_size (4).