In [4]:
import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torch import Tensor
from torchvision import datasets, models, transforms #画像用データセット諸々

import numpy as np
import argparse
import json
from logging.config import dictConfig
from logging import getLogger
import os
import time
from google.colab import files

In [3]:
# class MyNAS(nn.Module):
#     def __init__(self):
#         super(MyNAS, self).__init__()
#         self.theta = torch.randn(10, 10)
        

#     def forward(self, x):

#     def sampling(self, theta):
#       # onehot
#       a = torch.zeros(theta.shape)
#       return a

In [4]:

class Zero(nn.Module):
  """Candidate operation that maps every input to an all-zero tensor.

  Any positional/keyword constructor arguments are accepted and ignored
  (presumably so Zero can be built with the same call as other candidate ops).
  """

  def __init__(self, *args, **kwargs):
    super(Zero, self).__init__()

  def forward(self, input: Tensor) -> Tensor:
    # Return a real tensor (same shape, dtype and device as the input) rather
    # than None, so callers can scale and sum this op like any other candidate.
    return torch.zeros_like(input)

In [6]:
# operators: list of candidate nn.Module operations; None entries are skipped in forward
class Edge(nn.Module):
  """Weighted mixture of candidate operations applied to a single input.

  forward(x) returns sum_i theta[i] * operators[i](x), skipping operators that
  are None.

  Args:
    operators: list of candidate nn.Module operations (None entries are skipped).
    theta: optional 1-D tensor/sequence of mixing weights, one per operator.
      When omitted, weights are drawn at random and passed through softmax so
      they start positive and summing to 1. theta is used as-is in forward (no
      re-normalisation), so after optimisation it need not stay a distribution.

  Raises:
    ValueError: if theta does not have exactly one weight per operator.
  """

  def __init__(self, operators, theta=None):
    super(Edge, self).__init__()
    # ModuleList (not a plain list) registers the operators as submodules so
    # their parameters show up in .parameters(), .to(device) and state_dict.
    self.operators = nn.ModuleList(operators)

    if theta is None:
      # randn / randn.sum() could go negative and explodes when the random sum
      # is near zero; softmax gives a proper probability distribution.
      theta = torch.softmax(torch.randn(len(self.operators)), dim=0)
    theta = torch.as_tensor(theta, dtype=torch.get_default_dtype()).detach().clone()
    if theta.shape != (len(self.operators),):
      raise ValueError(
          f"theta must have shape ({len(self.operators)},), got {tuple(theta.shape)}")
    # A leaf Parameter: it is optimisable, and repeated backward passes do not
    # re-traverse a graph that was built (and freed) once in __init__.
    self.theta = nn.Parameter(theta)

  def forward(self, input: Tensor) -> Tensor:
    output = None
    for theta_i, operator in zip(self.theta, self.operators):
      if operator is None:
        continue
      term = theta_i * operator(input)
      # Start from the first term instead of a preallocated zeros tensor, so
      # operators that change the shape or live on another device still work.
      output = term if output is None else output + term

    if output is None:
      # No usable operator: contribute nothing, on the input's device/dtype.
      return torch.zeros_like(input)
    return output

In [7]:
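# Cell layout:
#   node_num    : number of nodes in the cell's DAG (int)
#   edges       : one Edge per (source, destination) pair in `ref`
#   input nodes : 0 and 1
#   output node : -1 (the last node)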
class Cell(nn.Module):
  """Small DAG of Edge modules.

  Node 0 holds the first input and node 1 optionally holds a second input.
  Every (src, dst) pair in `self.ref` adds edges[k](nodes[src]) into nodes[dst],
  and the last node is returned.

  Args:
    operators: candidate operations handed to every Edge.
      NOTE(review): the same module instances are shared by all edges, so any
      weights they hold are tied across edges. Confirm this is intended.
  """

  def __init__(self, operators):
    super(Cell, self).__init__()

    self.node_num = 3

    # (source node, destination node) for each edge.
    # self.ref = [(0, 1), (0, 2), (1, 2)]
    self.ref = [(0, 2)]
    # ModuleList (not a plain list) so the edges and their theta are registered.
    self.edges = nn.ModuleList([Edge(operators) for _ in self.ref])

  def forward(self, input, ct1=None) -> Tensor:
    """Run the cell.

    Args:
      input: tensor placed on node 0.
      ct1: optional tensor placed on node 1 (e.g. a previous cell's output).
        When omitted, node 1 stays all zeros shaped like `input`.

    Returns:
      The tensor accumulated on the last node.
    """
    # Zero-initialised nodes on the same device/dtype as the input.
    nodes = [torch.zeros_like(input) for _ in range(self.node_num)]
    nodes[0] = input
    if ct1 is not None:
      nodes[1] = ct1

    for (src, dst), edge in zip(self.ref, self.edges):
      nodes[dst] = nodes[dst] + edge(nodes[src])

    return nodes[-1]

In [44]:
# Scratch check: gradients reach leaf `b` through an out-of-place addition.
shape = [3, 2, 5]
b = torch.ones(*shape, requires_grad=True)
a = torch.zeros(*shape, requires_grad=True) + b  # `a` is now a non-leaf sum

a[0, 0, 0].backward()  # afterwards b.grad is 1 at [0, 0, 0] and 0 elsewhere

In [30]:
class MyModel(nn.Module):
  """Apply one Cell `depth` times, then classify with a linear layer.

  Each step feeds the cell the outputs of the two previous steps. The first two
  steps use the model input in place of the missing history.

  Args:
    depth: number of cell applications.
    in_features: flattened size of the last cell output (3 * 32 * 32 for the
      16x3x32x32 batches used in this notebook; TODO confirm for other inputs).
    num_classes: number of output logits.
  """

  def __init__(self, depth=4, in_features=3 * 32 * 32, num_classes=10):
    super(MyModel, self).__init__()
    # operators = [nn.Conv2d(3, 3, 3, padding=1), nn.Identity(), None]
    operators = [nn.Identity()]
    # NOTE(review): a single Cell instance is reused at every depth step, so all
    # steps share weights. Confirm this is intended.
    self.cell = Cell(operators)
    self.depth = depth
    # Built once here (not inside forward) so its weights are registered,
    # trained, and not re-randomised on every call.
    self.linear = nn.Linear(in_features, num_classes)

  def forward(self, input) -> Tensor:
    # states[-2] and states[-1] are the two most recent cell inputs. Seeding both
    # with the model input makes the output actually depend on the data (zero
    # seeds discarded it entirely).
    states = [input, input]
    for _ in range(self.depth):
      # NOTE(review): Cell.forward must accept a second (previous-state) input here.
      states.append(self.cell(states[-2], states[-1]))

    # Flatten per sample instead of assuming a batch size of 16.
    features = states[-1].flatten(start_dim=1)
    return self.linear(features)

In [12]:
class Package(nn.Module):
  """Wrap feature layers with global average pooling and a linear classifier.

  Args:
    layers: list of nn.Module applied in sequence to the input. Presumably they
      produce a 4-D (N, C, H, W) tensor; TODO confirm.
    class_num: number of output logits.
    in_features: channel count produced by `layers`. When None, the classifier
      is created on the first forward pass from the observed width. In that
      case, build the optimizer after one forward pass so it sees those weights.
  """

  def __init__(self, layers, class_num=10, in_features=None):
    super().__init__()
    self.layers = layers
    self.model = nn.Sequential(*layers, nn.AdaptiveAvgPool2d(1))
    self.class_num = class_num
    self.classifier = (
        nn.Linear(in_features, class_num) if in_features is not None else None)

  def forward(self, input) -> Tensor:
    x = self.model(input)
    # Pooling to 1x1 leaves one value per channel; flatten to (batch, features).
    x = x.flatten(start_dim=1)
    if self.classifier is None:
      # Build the classifier once and keep it as a registered submodule, instead
      # of creating a fresh, untrained, randomly initialised Linear per call.
      self.classifier = nn.Linear(x.shape[1], self.class_num).to(
          device=x.device, dtype=x.dtype)
    return self.classifier(x)

# Smoke test: one forward/backward pass through a single-cell model on a random
# batch of 16 images (3x32x32) with 10 classes.
torch.manual_seed(0)  # make the random init, data and loss value reproducible

operators = [nn.Identity()]
layers = [Cell(operators)]
model = Package(layers)
input = torch.randn(16, 3, 32, 32)  # NOTE: shadows the builtin `input`
labels = torch.randint(0, 10, (16,))  # uniformly random class indices
criterion = nn.CrossEntropyLoss()

output = model(input)
loss = criterion(output, labels)
loss.backward()
loss

tensor(2.4969, grad_fn=<NllLossBackward>)

In [23]:
# Scratch check: reshape `i` to `p`'s shape (both hold 6 elements).
p = torch.zeros(2, 1, 1, 3)
i = torch.randn(3, 2, 1).view(p.shape)
i.shape

torch.Size([2, 1, 1, 3])

In [None]:
import unittest

def tensor_equal(x, y):
  """Return True iff `x` and `y` have the same shape and element-wise equal values.

  The shape check stops broadcast-compatible but differently-shaped tensors
  (e.g. shapes (3,) and (1,)) from being reported equal. Comparing with
  `.all()` instead of `.view(-1)` also works on non-contiguous tensors.
  Empty tensors of the same shape compare equal.
  """
  return x.shape == y.shape and bool((x == y).all())

class TestEdge(unittest.TestCase):
  """Unit tests for Edge."""

  def test_id(self):
    """With all weight on the identity operator, Edge returns its input unchanged."""
    input = torch.randn(1, 3, 32, 32)
    # Edge takes its candidate operators positionally. The mixing weights are
    # then set directly so that only the middle (identity) operator contributes.
    model = Edge([nn.ReLU(), nn.Identity(), None])
    model.theta = nn.Parameter(torch.tensor([0.0, 1.0, 0.0]))
    output = model(input)
    self.assertTrue(tensor_equal(input, output))

if __name__ == '__main__':
    # Run the TestCase classes defined in this notebook. The dummy argv keeps
    # unittest from parsing the Jupyter kernel's own command-line arguments, and
    # exit=False stops unittest.main from calling sys.exit (which would raise
    # SystemExit in the notebook).
    unittest.main(argv=['first-arg-is-ignored'], exit=False)