In [2]:
"""
Simple intro. to pytorch's NN layers
"""

import torch
from torch import nn # Neural network layers
from torch.autograd import Variable

# Create a 2x3x10x10 input: batch size 2, 3 channels, 10x10 spatial dims.
# (Variable is required by PyTorch < 0.4; in later versions it is a
# deprecated no-op that simply returns the tensor.)
x = Variable(torch.rand(2, 3, 10, 10))

# 2D conv layer: 3 input channels -> 1 output channel, 3x3 kernel, stride 1.
# Without padding each spatial dim shrinks to 10 - 3 + 1 = 8.
conv_layer = nn.Conv2d(3, 1, kernel_size=3, stride=1)

# Push the input through the conv layer.
conv_out = conv_layer(x)
print(conv_out.size())

# Flatten everything except the batch dimension (2x1x8x8 -> 2x64).
# -1 lets view infer 1*8*8 = 64 instead of hard-coding it.
out = conv_out.view(conv_out.size(0), -1)

# Linear layer 64 -> 1; in_features is taken from the flattened tensor so the
# cell keeps working if the input size or conv parameters change.
fc = nn.Linear(out.size(1), 1)
fc(out)

torch.Size([2, 1, 8, 8])


Variable containing:
 0.1518
 0.1643
[torch.FloatTensor of size 2x1]

In [4]:
# Spatial max pooling with a square 2x2 window and stride 2: each spatial
# dim is halved (10x10 -> 5x5), the channel count is unchanged.
pooling_layer = nn.MaxPool2d(kernel_size=2, stride=2)

x = Variable(torch.randn(1, 3, 10, 10))
output = pooling_layer(x)
print(output.size())

torch.Size([1, 3, 5, 5])


In [7]:
# Convolutional feature extractor of a simplified AlexNet; the fully
# connected classifier head is intentionally left out.
alexnet_feature_layers = [
    # stage 1: large 11x11 stride-4 conv, then overlapping 3x3/2 max pool
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # stage 2: 5x5 conv (padding 2 keeps the spatial size), then pool
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # stage 3: three 3x3 convs (padding 1 keeps the spatial size), then pool
    nn.Conv2d(192, 384, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
]
net = nn.Sequential(*alexnet_feature_layers)

In [8]:
# Random input batch: 10 images of 3x128x128.
x = Variable(torch.randn(10, 3, 128, 128))
# Forward pass through the convolutional feature extractor `net`.
feature_maps = net(x)
print(feature_maps.size())

# Flatten each sample's feature map into a vector (view infers 256*3*3 = 2304
# for 128x128 inputs via -1) and map it to 100 outputs. in_features is taken
# from the data, so other input resolutions (large enough to survive the
# pooling stages) also work without editing constants.
out = feature_maps.view(feature_maps.size(0), -1)
lin = nn.Linear(out.size(1), 100)
print(lin(out).size())

torch.Size([10, 256, 3, 3])
torch.Size([10, 100])
