In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

In [7]:
# Three independently initialised 3x3 convs (16 -> 16 channels, padding 1),
# constructed in the order a, b, c and collected into a ModuleList.
a, b, c = (
    nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
    for _ in range(3)
)

g = nn.ModuleList([a, b, c])

In [10]:
# Feed the same all-ones batch through every module in ``g`` and accumulate
# the results into a zero-initialised tensor of the same shape.
dummy_shape = (24, 16, 32, 32)
dummy_input = torch.ones(dummy_shape)
dummy_output = torch.zeros(dummy_shape)
for layer in g:
    dummy_output += layer(dummy_input)
    

In [45]:
# NOTE(review): OPS is only defined in a later cell (the ``OPS = {...}`` cell
# below); run that cell first, or move this one after it, so a fresh
# Restart & Run All does not raise NameError here.
op_name = [key for key in OPS]
print(op_name)

['none', 'conv3', 'conv5']


In [33]:
# Candidate edge operations keyed by name. Every factory takes
# (C_in, C_out, stride) and returns a freshly constructed nn.Module.
# NOTE(review): 'none' ignores C_in/C_out, so its output keeps the input's
# channel count; it only lines up with the convs when C_in == C_out.
OPS = {
    'none': lambda C_in, C_out, stride: Zero(stride),
    'conv3': lambda C_in, C_out, stride: nn.Conv2d(
        in_channels=C_in, out_channels=C_out, kernel_size=3, stride=stride, padding=1
    ),
    'conv5': lambda C_in, C_out, stride: nn.Conv2d(
        in_channels=C_in, out_channels=C_out, kernel_size=5, stride=stride, padding=2
    ),
}

class Zero(nn.Module):
    """Edge op that outputs zeros shaped like its (optionally subsampled) input.

    When ``stride`` is 1 the output has exactly the input's shape. Otherwise the
    third and fourth dims are subsampled with step ``stride`` before zeroing.
    The result is produced by multiplying by 0, so dtype and device follow the
    input, and NaN/inf entries become NaN rather than 0.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        step = self.stride
        if step != 1:
            # Subsampling indexes four dims, so x needs at least 4 dims here
            # (presumably N, C, H, W).
            x = x[:, :, ::step, ::step]
        return x.mul(0.)

In [25]:
# Build one standalone 3x3 conv edge op. OPS factories take (C_in, C_out, stride);
# the previous 2-argument call raised TypeError.
# NOTE(review): this rebinds ``g`` (earlier the ModuleList of convs), so the
# cell that iterates ``for module in g`` fails if it is re-run after this one.
g = OPS['conv3'](16, 16, 1)

In [58]:
class NAS201SearchCell(nn.Module):
    """Search cell over a densely connected DAG of ``max_nodes`` nodes.

    Node 0 is the cell input. Every later node ``i`` (``1 <= i < max_nodes``)
    is the sum, over all predecessors ``j < i``, of a weighted mixture of the
    candidate ops on edge ``"i<-j"``. ``forward`` returns the last node.

    Args:
        C_in: channels of the cell input (node 0).
        C_out: channels produced by every edge, and hence by every node ``i >= 1``.
        stride: stride of the ops on edges leaving node 0. All other edges use
            stride 1, so every input to a node has the same spatial size.
        max_nodes: total number of nodes, including the input node.
        op_names: keys into ``OPS``. Each edge holds one op per name, in this order.

    Attributes:
        alphas: learnable ``(num_edges, len(op_names))`` mixing weights. Row
            ``edge2index[key]`` weights the ops on edge ``key``.

    NOTE(review): the 'none' op in OPS (``Zero``) keeps its input's channel
    count, so ``C_in != C_out`` only works when 'none' is not in ``op_names``.
    """

    def __init__(
        self,
        C_in,
        C_out,
        stride,
        max_nodes,
        op_names):

        super(NAS201SearchCell, self).__init__()

        self.op_names = deepcopy(op_names)
        self.edges = nn.ModuleDict()
        self.max_nodes = max_nodes
        self.in_dim = C_in
        self.out_dim = C_out

        for i in range(1, max_nodes):
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                if j == 0:
                    # Edges from the cell input see C_in channels and may downsample.
                    xlists = [
                        OPS[op_name](C_in, C_out, stride)
                        for op_name in op_names
                    ]
                else:
                    # Intermediate nodes already carry C_out channels at the
                    # post-stride resolution (previously built with C_in,
                    # which broke whenever C_in != C_out).
                    xlists = [
                        OPS[op_name](C_out, C_out, 1)
                        for op_name in op_names
                    ]
                self.edges[node_str] = nn.ModuleList(xlists)

        # Keys are sorted as strings. edge2index / index2key are the single
        # mapping between an edge key and its row of ``alphas``.
        self.edge_keys = sorted(list(self.edges.keys()))
        self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
        self.index2key = {i: key for i, key in enumerate(self.edge_keys)}
        self.num_edges = len(self.edges)
        # Registered as a Parameter (not a buffer), so optimizers and
        # get_arch_parameters() can see it. The state_dict key stays 'alphas'.
        self.alphas = nn.Parameter(1e-3 * torch.randn(self.num_edges, len(op_names)))

    def extra_repr(self):
        """Summarise node count and channel sizes for ``print(module)``."""
        string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
            **self.__dict__
        )
        return string

    def get_arch_parameters(self):
        """Return the architecture parameters: every parameter named ``alphas``.

        The previous ``endswith('alpha')`` check never matched ``'alphas'``.
        Because ``named_parameters`` recurses, ``alphas`` of nested submodules
        are included too.

        Returns:
            list of nn.Parameter.
        """
        return [
            param
            for name, param in self.named_parameters()
            if name.rsplit('.', 1)[-1] == 'alphas'
        ]

    def forward(self, inputs):
        """Run the cell on ``inputs`` and return the last node's tensor.

        Edge ``"i<-j"`` contributes ``sum_k alphas[e, k] * op_k(nodes[j])``,
        with ``e = edge2index["i<-j"]``.
        NOTE(review): ``alphas`` are used raw (no softmax over ops), so the
        mixture weights are unnormalised and can be negative. Confirm this is
        intended.
        """
        nodes = [inputs]
        for i in range(1, self.max_nodes):
            inter_nodes = []
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                weights = self.alphas[self.edge2index[node_str]]
                inter_nodes.append(
                    sum(
                        layer(nodes[j]) * w
                        for layer, w in zip(self.edges[node_str], weights)
                    )
                )
            nodes.append(sum(inter_nodes))
        return nodes[-1]


In [59]:
# 32 -> 32 channels, no downsampling, 5 nodes (input + 4), one op per name in op_name.
nas_cell = NAS201SearchCell(
    C_in=32,
    C_out=32,
    stride=1,
    max_nodes=5,
    op_names=op_name,
)

In [60]:
# Row index of ``nas_cell.alphas`` -> "dst<-src" edge key (keys sorted as strings).
nas_cell.index2key

{0: '1<-0',
 1: '2<-0',
 2: '2<-1',
 3: '3<-0',
 4: '3<-1',
 5: '3<-2',
 6: '4<-0',
 7: '4<-1',
 8: '4<-2',
 9: '4<-3'}

In [31]:
# Standalone version of the edge construction in NAS201SearchCell.__init__.
# The previous call ``OPS[op_name](16, 3)`` passed 2 arguments to the
# 3-argument (C_in, C_out, stride) factories and raised TypeError.
# NOTE(review): (16, 3) is read here as (channels=16, stride=3), matching the
# 2-argument pattern of ``OPS['conv3'](16, 1)``. Confirm the intended values.
max_len = 6
channels = 16
stride = 3
op_names = sorted(OPS.keys())
edges = nn.ModuleDict()
for i in range(1, max_len):
    for j in range(i):
        node_str = "{:}<-{:}".format(i, j)
        # Only edges leaving the input node downsample. Later edges use
        # stride 1 so all inputs to a node share one spatial size (a uniform
        # stride would shrink each hop and make node sums mismatch).
        edge_stride = stride if j == 0 else 1
        xlists = [
            OPS[op_name](channels, channels, edge_stride)
            for op_name in op_names
        ]
        edges[node_str] = nn.ModuleList(xlists)

edges_keys = sorted(list(edges.keys()))
edge2index = {key: i for i, key in enumerate(edges_keys)}
num_edge = len(edges_keys)


In [39]:
# Toy learnable 3x4 matrix of ones, used to inspect indexing/iteration of a Parameter.
weights = nn.Parameter(torch.ones(3, 4))

In [41]:
# Third row of ``weights``; indexing a Parameter yields a tensor with a grad_fn (SelectBackward0).
weights[2]

tensor([1., 1., 1., 1.], grad_fn=<SelectBackward0>)

In [38]:
# Pair the labels 1..3 with successive slices of ``weights`` along dim 0.
# NOTE(review): the saved output shows scalar slices, so ``weights`` was 1-D
# when this ran. With the 3x4 Parameter defined above, each slice is a row of 4.
for label, row in zip(range(1, 4), weights):
    print(label, row)

1 tensor(1., grad_fn=<UnbindBackward0>)
2 tensor(1., grad_fn=<UnbindBackward0>)
3 tensor(1., grad_fn=<UnbindBackward0>)
