In [7]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

In [8]:
class CNN(nn.Module):
    """Small convolutional classifier: 3 conv+pool stages followed by 3 linear layers.

    Every conv uses kernel 5 / stride 1 / padding 2, which preserves spatial size, and
    each of the three 2x2 max-pools halves it. ``layer1`` expects ``4*4*8`` features,
    so inputs must be single-channel images whose height and width reduce to 4 after
    three halvings (e.g. a batch of shape ``(N, 1, 32, 32)``).

    Args:
        n_classes: number of output logits produced by ``forward``.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        # conv layers: (in_channels, out_channels, kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(1, 32, 5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 16, 5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(16, 8, 5, stride=1, padding=2)

        # max pooling (kernel_size, stride); shared by all three conv stages
        self.pool = nn.MaxPool2d(2, 2)

        # fully connected layers
        self.layer1 = nn.Linear(4 * 4 * 8, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, n_classes)

    def forward(self, x, training=None):
        """Compute class logits for a batch ``x`` of shape (N, 1, H, W).

        Args:
            x: input batch (see the class docstring for the expected spatial size).
            training: whether to apply dropout. ``None`` (the default) follows the
                module's own mode, so ``model.eval()`` disables dropout and
                ``model.train()`` enables it. Pass True/False to override.

        Returns:
            Unnormalised logits of shape (N, n_classes).
        """
        if training is None:
            training = self.training
        # Three conv -> ReLU -> 2x2 max-pool stages.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample (batch dim kept), so a wrongly sized input raises an
        # error instead of silently mixing features from different samples.
        x = torch.flatten(x, 1)
        x = F.dropout(F.relu(self.layer1(x)), 0.5, training=training)
        x = F.dropout(F.relu(self.layer2(x)), 0.5, training=training)
        # No activation on the output layer: callers receive raw logits.
        return self.layer3(x)

    def predict(self, x):
        """Return softmax class probabilities for the batch ``x``, with dropout disabled.

        Returns:
            Tensor of shape (N, n_classes) whose rows sum to 1.
        """
        return F.softmax(self.forward(x, training=False), dim=1)

    def accuracy(self, x, y):
        """Percentage (0-100) of samples in ``x`` whose predicted class equals ``y``.

        Args:
            x: a batch of inputs.
            y: the true labels associated with ``x``, one per sample, on the same
               device as ``x`` (presumably integer class indices; both sides are
               compared after casting to float).

        Returns:
            The accuracy as a Python float.
        """
        with torch.no_grad():  # evaluation only; no autograd graph is needed
            prediction = self.predict(x)
            _, indices = torch.max(prediction, 1)
            correct = torch.eq(indices.float(), y.float()).float().sum()
            acc = 100 * correct / y.size(0)
        # .item() replaces the pre-0.4 ``.data[0]`` idiom, which raises on 0-dim tensors.
        return acc.item()


In [9]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 10 output classes, matching layer3 of the checkpoint loaded below.
model = CNN(10).to(device)
model

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer1): Linear(in_features=128, out_features=64, bias=True)
  (layer2): Linear(in_features=64, out_features=64, bias=True)
  (layer3): Linear(in_features=64, out_features=10, bias=True)
)

In [11]:
# map_location moves the checkpoint tensors onto the model's device. Without it, a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state = torch.load('cnn.pth', map_location=next(model.parameters()).device)
model.load_state_dict(state)

<All keys matched successfully>

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer1): Linear(in_features=128, out_features=64, bias=True)
  (layer2): Linear(in_features=64, out_features=64, bias=True)
  (layer3): Linear(in_features=64, out_features=10, bias=True)
)

In [36]:
# Qualified names of every submodule, in registration order ('' is the root model).
modules = [module_name for module_name, _ in model.named_modules()]
print(modules)

['', 'conv1', 'conv2', 'conv3', 'pool', 'layer1', 'layer2', 'layer3']


In [39]:
# modules[1] is the first registered child of the model (conv1 in the list above).
first_layer = getattr(model, modules[1])
first_layer.weight

Parameter containing:
tensor([[[[ 1.9111e-03, -7.5466e-02, -8.3581e-02, -6.7995e-02, -2.1533e-01],
          [ 1.0316e-01,  2.1668e-01, -1.2940e-01,  2.4202e-01,  2.8774e-02],
          [ 2.1586e-01, -3.9543e-02,  2.9735e-01,  2.2145e-01,  1.7842e-01],
          [ 1.8348e-01,  6.5618e-02,  3.5390e-01,  8.3087e-02,  1.2447e-01],
          [ 1.6329e-01,  1.1724e-01,  2.5607e-01,  7.4599e-02,  1.0450e-01]]],


        [[[ 8.3825e-02,  5.9209e-02,  9.6223e-02,  2.2185e-01, -7.3424e-02],
          [-1.4050e-01,  5.5635e-02,  2.8236e-01, -1.1224e-01,  1.8013e-01],
          [-3.0268e-02,  1.7202e-01,  1.6379e-02,  2.7759e-01, -7.8112e-02],
          [ 7.4248e-02,  2.9403e-01,  9.7572e-02,  5.8801e-02,  1.0877e-01],
          [ 2.5658e-01,  2.5329e-02,  2.1847e-01,  1.8750e-01, -7.9955e-02]]],


        [[[ 1.1861e-01, -8.9525e-03, -1.6314e-01, -2.7319e-01,  1.1180e-01],
          [ 1.4242e-01, -1.8169e-01,  1.3020e-01, -1.9418e-01,  7.4939e-02],
          [-7.9088e-02,  4.6304e-02,  5.7971e-