In [45]:
import warnings
from collections import namedtuple
from functools import partial
from typing import Optional, Tuple, List, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import networkx as nx
import numpy as np

In [4]:
# Operation labels that a node of the cell graph can carry.
INPUT = 'input'
OUTPUT = 'output'
CONV3X3 = 'conv3x3-bn-relu'
CONV1X1 = 'conv1x1-bn-relu'
MAXPOOL3X3 = 'maxpool3x3'

# Label of every node, indexed by node id (i.e. by row/column of `matrix`).
new_labels = [INPUT, CONV1X1] + [CONV3X3] * 3 + [MAXPOOL3X3, OUTPUT]

# Adjacency matrix of the cell graph: matrix[i][j] == 1 is an edge i -> j.
matrix = np.array([
    [0, 1, 1, 1, 0, 1, 0],  # node 0: input layer
    [0, 0, 0, 0, 0, 0, 1],  # node 1: 1x1 conv
    [0, 0, 0, 0, 0, 0, 1],  # node 2: 3x3 conv
    [0, 0, 0, 0, 1, 0, 0],  # node 3: 5x5 conv (replaced by two 3x3's), first 3x3
    [0, 0, 0, 0, 0, 0, 1],  # node 4: 5x5 conv (replaced by two 3x3's), second 3x3
    [0, 0, 0, 0, 0, 0, 1],  # node 5: 3x3 max-pool
    [0, 0, 0, 0, 0, 0, 0],  # node 6: output layer
])

In [3]:
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch norm and ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``padding``). ``bias`` is fixed to ``False`` by this class.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super(BasicConv2d, self).__init__()
        # The batch norm that follows has its own shift, so the conv has no bias.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bn(conv(x)))``; the ReLU runs in place on the BN output."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)

In [17]:
class Cell(nn.Module):
    """A cell whose internal wiring is given by an adjacency matrix.

    Node 0 is the cell input and the last node is the cell output; every node
    in between applies the operation named by its label. Each intermediate
    node sums the outputs of its incoming nodes and feeds the sum to its
    layer; the cell output is the channel-wise concatenation of the nodes
    wired into the output node.

    Args:
        in_channels: Number of channels of the tensor entering the cell.
        matrix: Square adjacency matrix; ``matrix[i][j] != 0`` is an edge
            from node ``i`` to node ``j``.
        labels: Operation label per node. Defaults to the module-level
            ``new_labels`` list.

    Raises:
        ValueError: If ``labels`` does not have one entry per node, or the
            output node has no incoming edge.

    Notes:
        * Every node fed directly by the input first projects the input to
          ``in_channels // fan_in(output)`` channels with a 1x1 conv. For a
          1x1-conv node that projection is the node's only operation.
        * NOTE(review): a node that receives both the raw input and another
          node's output sums tensors before the projection, and an edge from
          the input straight to the output is concatenated unprojected. The
          module-level ``matrix`` has neither pattern; confirm before using
          matrices that do.
    """

    def __init__(self, in_channels, matrix, labels=None):
        super().__init__()
        # `from_numpy_matrix` was removed in networkx 3.0; `from_numpy_array`
        # builds the same directed graph from a 2-D array.
        self.G = nx.from_numpy_array(np.asarray(matrix), create_using=nx.DiGraph)
        num_nodes = self.G.number_of_nodes()

        if labels is None:
            labels = new_labels
        if len(labels) != num_nodes:
            raise ValueError(
                "expected %d node labels, got %d" % (num_nodes, len(labels))
            )

        for i in range(num_nodes):
            node = self.G.nodes[i]
            node['label'] = i
            node['op_label'] = labels[i]
            # Read predecessors directly instead of copying the whole graph
            # with `reverse()` for every node; sorted for a stable order.
            node['incoming'] = sorted(self.G.predecessors(i))
            node['outgoing'] = sorted(self.G.successors(i))

        output_fan_in = len(self.G.nodes[num_nodes - 1]['incoming'])
        if output_fan_in == 0:
            raise ValueError("the output node has no incoming edges")
        # Branch width chosen so the concatenated branches add back up to
        # (at most) `in_channels` channels.
        proj_depth = int(in_channels // output_fan_in)

        # One entry per node that owns at least one module, in node order, so
        # that intermediate node n is stored at index n - 1.
        self.layers = nn.ModuleList()
        for n in range(num_nodes):
            node = self.G.nodes[n]
            modules = []
            op_label = node['op_label']
            if 0 in node['incoming']:
                modules.append(BasicConv2d(in_channels, proj_depth, kernel_size=1))
            if op_label == CONV1X1 and len(modules) == 0:
                modules.append(BasicConv2d(proj_depth, proj_depth, kernel_size=1))
            elif op_label == CONV3X3:
                modules.append(BasicConv2d(proj_depth, proj_depth, kernel_size=3, padding=1))
            elif op_label == MAXPOOL3X3:
                modules.append(nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True))
            if len(modules):
                self.layers.append(nn.Sequential(*modules))

    def forward(self, x):
        """Run the cell on ``x`` and return the concatenated branch outputs."""
        output_node = self.G.number_of_nodes() - 1
        # Activations live in a local dict rather than on the graph so that
        # nothing is retained on the module after the call returns.
        node_outputs = {0: x}
        for n in range(1, output_node):
            summed = sum(node_outputs[src] for src in self.G.nodes[n]['incoming'])
            node_outputs[n] = self.layers[n - 1](summed)
        branches = [node_outputs[src] for src in self.G.nodes[output_node]['incoming']]
        return torch.cat(branches, 1)

In [21]:
class CustomPool(nn.Module):
    """Downsampling stage: 2x2 max pool, then a 1x1 conv that doubles the channels.

    Args:
        in_channels: Number of channels of the incoming tensor; the output
            has ``2 * in_channels`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        out_channels = 2 * in_channels
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv1x1 = BasicConv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Pool ``x`` spatially, then widen it with the 1x1 conv block."""
        return self.conv1x1(self.pool(x))

In [42]:
class CNN(nn.Module):
    """Image classifier built from three stacks of identical cells.

    Layout: a 3x3 conv stem (3 -> 128 channels), three stacks of three
    ``Cell`` modules at 128 / 256 / 512 channels with a ``CustomPool``
    downsampling stage after the first and second stack, a global max pool,
    and two fully connected layers producing 10 logits.

    Args:
        matrix: Adjacency matrix passed unchanged to every ``Cell``.

    The spatial sizes in the inline comments assume 32x32 inputs. The global
    pool is adaptive, so other input sizes that survive the two 2x
    downsamplings are accepted as well.
    """

    def __init__(self, matrix):
        super().__init__()
        self.convstem = BasicConv2d(3, 128, kernel_size=3, padding=1)  # 32x32

        self.stack1 = nn.ModuleList([Cell(128, matrix) for _ in range(3)])
        self.stack1.append(CustomPool(128))  # 16x16

        self.stack2 = nn.ModuleList([Cell(256, matrix) for _ in range(3)])
        self.stack2.append(CustomPool(256))  # 8x8

        self.stack3 = nn.ModuleList([Cell(512, matrix) for _ in range(3)])

        # Adaptive pooling always reduces to 1x1. For the 8x8 map produced by
        # 32x32 inputs this equals the previous fixed MaxPool2d(8).
        self.global_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return the 10 class logits for a batch of 3-channel images."""
        x = self.convstem(x)
        # Iterate the stacks themselves so the loop cannot drift out of sync
        # with the number of modules registered in __init__.
        for stack in (self.stack1, self.stack2, self.stack3):
            for layer in stack:
                x = layer(x)
        x = self.global_pool(x).flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [43]:
# Instantiate the classifier using the cell adjacency matrix defined earlier.
network = CNN(matrix=matrix)