# Definitions: candidate operations, cell, graph and search network

In [1]:
import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torch import Tensor
from torchvision import datasets, models, transforms #画像用データセット諸々

import numpy as np
import argparse
import json
from logging.config import dictConfig
from logging import getLogger
import os
import time
from google.colab import files
import itertools

In [2]:
class Zero(nn.Module):
  """Candidate operation that maps its input to an all-zero tensor of the same shape.

  Any constructor arguments (the OPS table passes a stride) are accepted and ignored.
  """

  def __init__(self, *args, **kwargs):
    super().__init__()

  def forward(self, input: Tensor) -> Tensor:
    # Multiply instead of allocating fresh zeros so the result stays on the
    # input's device/dtype and inside the autograd graph.
    return input.mul(0)

In [3]:
class ReLUConvBN(nn.Module):
  """ReLU -> Conv2d (no bias) -> BatchNorm2d, applied in that order.

  Args:
    C_in, C_out: input / output channels of the convolution.
    kernel_size, stride, padding: forwarded to nn.Conv2d.
    affine: whether the BatchNorm has learnable scale/shift.
  """

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    activation = nn.ReLU(inplace=False)
    convolution = nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False)
    normalization = nn.BatchNorm2d(C_out, affine=affine)
    self.op = nn.Sequential(activation, convolution, normalization)

  def forward(self, x):
    return self.op(x)

In [36]:
class Edge(nn.Module):
  """Mixed operation on one DAG edge: a theta-weighted sum of candidate operations.

  Args:
    operators: candidate ops, all applied to the same input. If this is a plain
      iterable (not an nn.Module) whose items are all nn.Modules, it is wrapped
      in nn.ModuleList. That registers their parameters with this Edge, so they
      are seen by .parameters(), moved by .to(device) and saved in state_dict().
      Other iterables (e.g. containing plain callables) are stored as a list.
    theta: unused; kept for backward compatibility. The weights are given to forward().
  """
  def __init__(self, operators, theta=None):
    super(Edge, self).__init__()
    if not isinstance(operators, nn.Module):
      operators = list(operators)
      if all(isinstance(op, nn.Module) for op in operators):
        operators = nn.ModuleList(operators)
    self.operators = operators

  def forward(self, input: Tensor, theta: Tensor) -> Tensor:
    """Return sum_i theta[i] * operators[i](input).

    zip() stops at the shorter of theta / operators, so surplus entries are ignored.
    """
    return sum(t * op(input) for t, op in zip(theta, self.operators))

In [48]:
# TODO : multiple input/output
# TODO : stride 2
# TODO : replace from += to .add_()
class Cell(nn.Module):
  """One searchable cell: a DAG whose edges are mixed operations (Edge).

  Args:
    names: candidate operation names; each must be a key of OPS.
    graph: a Graph; graph.edges() gives the (source, target) node pairs and
      graph.node_num - 1 is the number of DAG nodes.
    cs: (c_prev, c): channels arriving from the previous layer, and the channel
      width used inside the cell (a 1x1 ReLUConvBN maps c_prev -> c first).
    multi: how many of the last DAG nodes are concatenated along the channel
      axis to form the output.
    reduce: stored flag used by the caller to select weights; stride-2
      reduction is not implemented yet (see TODO), so every op uses stride 1.
  """
  def __init__(self, names, graph, cs, multi, reduce=False):
    super(Cell, self).__init__()
    (c_prev, c) = cs
    self.pre = ReLUConvBN(c_prev, c, 1, 1, 0, affine=False)
    self.reduce = reduce
    self.graph = graph
    self.ref = graph.edges()
    self.multi = multi
    self.edges = nn.ModuleList([Edge(self._make_modules(names, c, r)) for r in self.ref])

  def _make_modules(self, names, c, r):
    """Instantiate one candidate op per name for edge r (non-affine, stride 1)."""
    # stride = 2 if self.reduce and r[0] == 0 else 1
    stride = 1
    return nn.ModuleList([OPS[name](c, stride, False) for name in names])

  def forward(self, input: Tensor, theta: Tensor) -> Tensor:
    """Run the DAG and concatenate its last ``multi`` nodes along dim 1.

    theta[idx] holds the operator weights of the idx-th edge of graph.edges().
    """
    input = self.pre(input)
    # Node 0 is the preprocessed input; the other nodes start at zero and
    # accumulate their incoming edges' outputs. The placeholders do not need
    # requires_grad: adding an edge's output attaches autograd history when
    # needed, so gradient-tracking leaf buffers were only extra autograd work.
    nodes = [input] + [torch.zeros_like(input) for _ in range(self.graph.node_num - 2)]

    # Edges are listed with source < target, so each source node is complete
    # before it is read.
    for idx, (s, e) in enumerate(self.ref):
      nodes[e] = nodes[e] + self.edges[idx](nodes[s], theta[idx])

    return torch.cat(nodes[-self.multi:], dim=1)

In [47]:
# Candidate-operation factories. Each takes (C, stride, affine) and returns an
# nn.Module; `affine` is accepted for a uniform signature but unused here.
# Note: 'skip_connect' ignores stride (a stride-2 variant would need e.g. a
# FactorizedReduce(C, C, affine=affine)), and 'none' passes stride to Zero,
# which ignores it.
OPS = dict(
  none=lambda C, stride, affine: Zero(stride),
  skip_connect=lambda C, stride, affine: nn.Identity(),
  avg_pool_3x3=lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
  max_pool_3x3=lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
  conv_3x3=lambda C, stride, affine: nn.Conv2d(C, C, 3, stride=stride, padding=1),
  conv_5x5=lambda C, stride, affine: nn.Conv2d(C, C, 5, stride=stride, padding=2),
)

# Possible refinement (as in DARTS): follow pooling ops with a non-affine BN, i.e.
#   if 'pool' in primitive:
#     op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))

# Search space, in the column order used by the architecture weights.
CANDIDATE = 'conv_3x3 conv_5x5 avg_pool_3x3 max_pool_3x3 skip_connect none'.split()

In [10]:
class Graph:
  """Fully connected DAG over ``node_num - 1`` nodes labelled 0 .. node_num - 2.

  Every pair (s, t) with s < t is an edge; edges are kept in
  itertools.combinations order, and an edge's position in that list is its index.
  """
  def __init__(self, node_num : int):
    self.node_num = node_num
    self.input = 1
    self.output = 1
    self._graph = self._make_graph(node_num)

  def edges(self):
    """Return the list of (source, target) pairs."""
    return self._graph

  def ordered_edges(self):
    """For each node in range(node_num)[input:-output], list the indices of edges ending there."""
    inner_nodes = range(self.node_num)[self.input:-self.output]
    grouped = []
    for node in inner_nodes:
      grouped.append([pos for pos, (_, target) in enumerate(self._graph) if target == node])
    return grouped

  def size(self):
    """Return the number of edges."""
    return len(self._graph)

  def _make_graph(self, num : int):
    return list(itertools.combinations(range(num - 1), 2))

In [46]:
# TODO : multiple input/output
# TODO : random theta 1e-3 * randn
class CellNetwork(nn.Module):
  """Stem -> stack of searchable Cells -> global average pool -> linear classifier.

  Architecture weights live in ``self.thetas = [normal_theta, reduce_theta]``,
  each of shape (number of graph edges, len(CANDIDATE)). Cells flagged by
  ``is_reduce`` use reduce_theta, the others normal_theta.

  NOTE(review): ``self.thetas`` is a plain Python list, not an nn.ParameterList.
  The tensors are therefore NOT moved by ``.to(device)``, NOT returned by
  ``.parameters()`` and NOT saved in ``state_dict()``; callers must move them
  to the device themselves.

  Args:
    depth: number of cells.
    node_num: passed to Graph (each cell has node_num - 1 DAG nodes).
    class_num: number of output logits.
    multi: how many DAG nodes each cell concatenates into its output.
    channels: number of input image channels (default 3).
  """
  def __init__(self, depth = 4, node_num = 4, class_num = 10, multi = 3, channels = 3):
    super(CellNetwork, self).__init__()
    # Explicit argument: this used to read an undefined global `channels`,
    # which raised NameError on a fresh kernel.
    self.channels = channels
    self.depth = depth
    self.graph = Graph(node_num)
    print(self.graph.edges())
    self.init_modules(self.channels, multi, class_num)
    self.onehot = False

  def is_reduce(self, idx):
    """Cells 2, 5, 8, ... are marked as reduction cells."""
    return idx % 3 == 2

  def init_modules(self, c, multi, class_num):
    """Create the thetas, the stem, the cells and the classifier head.

    Channel flow: the stem maps c -> c * multi. Cell i maps c_n1 input channels
    to multi * c_n output channels (internal width c_n); both then become
    multi * c_n for the next cell. The classifier reads c_n1 channels after
    global average pooling.
    """
    # parameter & module
    normal_theta = torch.zeros(self.graph.size(), len(CANDIDATE), requires_grad=True)
    reduce_theta = torch.zeros(self.graph.size(), len(CANDIDATE), requires_grad=True)
    self.thetas = [normal_theta, reduce_theta]

    c_n = c * multi
    self.stem = nn.Sequential(
      nn.Conv2d(c, c_n, 3, padding=1, bias=False),
      nn.BatchNorm2d(c_n)
    )

    c_n1, c_n = c_n, c
    self.cells = nn.ModuleList()
    for i in range(self.depth):
      cell = Cell(CANDIDATE, self.graph, (c_n1, c_n), multi, reduce=self.is_reduce(i))
      self.cells += [cell]
      c_n1, c_n = multi * c_n, multi * c_n

    self.pooling = nn.AdaptiveAvgPool2d(1)
    self.linear = nn.Linear(c_n1, class_num)

  def forward(self, input) -> Tensor:
    """Return class logits. Before ``sampling()`` each cell receives the softmax
    of its theta over the candidate axis; afterwards the one-hot theta itself."""
    s = self.stem(input)

    for idx, cell in enumerate(self.cells):
      theta = self.thetas[1] if cell.reduce else self.thetas[0]
      weights = theta if self.onehot else F.softmax(theta, dim=-1)
      s = cell(s, weights)

    out = self.pooling(s)
    return self.linear(out.view(out.size(0), -1))

  def learn_theta(self, is_learn: bool):
    """Enable or disable gradient tracking on all architecture weights."""
    for theta in self.thetas:
      theta.requires_grad = is_learn

  def sampling(self):
    """Discretize the architecture weights in place, then freeze and log them.

    For every edge, the best candidate other than 'none' (largest theta value)
    is found. For every intermediate node only the ``graph.input`` incoming
    edges with the largest such value are kept; their rows become one-hot
    (1.0 at the chosen op) and every other row becomes all zero. Afterwards
    forward() uses the thetas directly instead of their softmax.
    """
    ignore = CANDIDATE.index('none')
    with torch.no_grad():
      for theta in self.thetas:
        masked = theta.clone()
        masked[:, ignore] = float('-inf')  # 'none' can never be chosen
        # Determine each edge's best op/value BEFORE clearing theta. The old
        # version re-read argmax after zeroing the other entries, which picked
        # the wrong op whenever the best value was <= 0.
        best_vals, best_ops = masked.max(dim=1)
        theta.zero_()
        for edges in self.graph.ordered_edges():
          kept = sorted(edges, key=lambda e: -best_vals[e].item())[:self.graph.input]
          for e in kept:
            theta[e, int(best_ops[e])] = 1.0

    self.learn_theta(False)
    self.onehot = True
    self.log()

  def log(self):
    """Print the current architecture weights (returns None)."""
    print("Network")
    for theta in self.thetas:
      print(theta)

# Develop

In [15]:

# model = CellNetwork(depth = 1)
# input = torch.randn(16, 3, 32, 32)
# labels = torch.randn(16, 10).argmax(1)
# criterion = nn.CrossEntropyLoss()

# output = model(input)
# loss = criterion(output, labels)
# loss.backward()
# loss

In [16]:
# class Package(nn.Module):
#   def __init__(self, layers):
#     super().__init__()
#     self.layers = layers
#     self.model = nn.Sequential(*layers, nn.AdaptiveAvgPool2d(1), )
#     self.class_num = 10

#   def forward(self, input) -> Tensor:
#     x = self.model(input)
#     x = x.view(x.shape[0], -1)
#     x = nn.Linear(x.shape[1], self.class_num)(x)
#     return x

# operators = [nn.Identity(), None]
# layers = [Cell(operators)]
# model = Package(layers)
# input = torch.randn(16, 3, 32, 32)
# labels = torch.randn(16, 10).argmax(1)
# criterion = nn.CrossEntropyLoss()

# output = model(input)
# loss = criterion(output, labels)
# loss.backward()
# loss

# Unit Test

In [17]:
import unittest

def tensor_equal(x, y):
  return (torch.sum(x == y) == x.view(-1).shape[0]).item()

class TestEdge(unittest.TestCase):
  def test_id(self):
    input = torch.randn(1, 3, 32, 32)
    operators = [nn.Identity(), None]
    model = Edge(operators, theta=torch.tensor([1.0, 0.0]))
    output = model(input)
    self.assertEqual(tensor_equal(input, output), True)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

E
ERROR: test_id (__main__.TestEdge)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-17-7c4fb3f2b357>", line 11, in test_id
    output = model(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'theta'

----------------------------------------------------------------------
Ran 1 test in 0.014s

FAILED (errors=1)


# Training

In [18]:
def train(model, device, train_loader, optimizer, optimizerB, criterion, logger, class_array, class_num=10):
    """Run one training epoch.

    For each batch: forward pass, fold the model's output columns into class
    logits (output column j is added to class ``class_array[j]``), apply
    ``criterion`` to the folded logits, backpropagate, then step ``optimizer``
    and, if given, ``optimizerB``.

    Args:
        class_array: output-column -> class mapping (identity for range(class_num)).
        class_num: number of classes in the folded logits (previously fixed at 10).
        logger: unused; kept for signature compatibility.
    Returns:
        (None, mean training loss over the epoch's batches). The loss is NaN
        when the loader yields no batches.
    """
    model.train()
    class_array = torch.as_tensor(class_array, dtype=torch.long, device=device)
    batch_losses = []
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if optimizerB is not None:
            optimizerB.zero_grad()
        output = model(data)
        # new_zeros keeps output's dtype/device, which index_add_ requires.
        reg = output.new_zeros(output.shape[0], class_num)
        reg.index_add_(1, class_array, output)
        loss = criterion(reg, target)
        # No retain_graph: each iteration builds a fresh graph.
        loss.backward()
        optimizer.step()
        if optimizerB is not None:
            optimizerB.step()
        batch_losses.append(loss.detach())

    if not batch_losses:
        return (None, float('nan'))
    return (None, torch.stack(batch_losses).mean().item())

In [19]:
def test(model, device, test_loader, criterion, logger, class_array):
    model.eval()
    test_loss = []
    correct = 0
    class_array = torch.tensor(class_array).to(device)
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss.append(criterion(output, target).item())
            pred = class_array[output.argmax(dim=1, keepdim=True)]
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss = np.mean(np.array(test_loss))
    accuracy = 100. * correct / len(test_loader.dataset)
    
    return (None, (test_loss, accuracy))

In [20]:
from argparse import Namespace
def dictspace(f):
  """Decorator: let ``f(namespace)`` be called with keyword arguments.

  ``wrapped(a=1, b=2)`` calls ``f(Namespace(a=1, b=2))`` and returns its result.
  """
  def wrapper(**kwds):
    namespace = Namespace(**kwds)
    return f(namespace)
  return wrapper

In [21]:
def load_dataset(train=2000, test=500):
  """Load CIFAR-10 (downloading to ./data if needed) and return random subsets.

  Training images get random-crop (32, padding 4) and horizontal-flip
  augmentation. Both splits are converted to tensors and normalized per channel.
  The subsets are drawn with torch.utils.data.random_split (unseeded).

  Args:
    train: number of training images to keep.
    test: number of test images to keep.
  Returns:
    (trainset, testset) as the first subsets produced by random_split.
  Raises:
    ValueError: if a requested size is negative or larger than its split.
  """
  # Image transforms (with augmentation for training)
  transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
  ])

  transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  ])

  # Load the CIFAR-10 train / test sets
  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform_test)

  # Validate against the real split sizes instead of hardcoded 50000 / 10000;
  # an oversized request used to be silently truncated to the whole split.
  for label, size, dataset in (("train", train, trainset), ("test", test, testset)):
    if not 0 <= size <= len(dataset):
      raise ValueError("%s size must be in [0, %d], got %d" % (label, len(dataset), size))

  trainset, _ = torch.utils.data.random_split(trainset, [train, len(trainset) - train])
  testset, _ = torch.utils.data.random_split(testset, [test, len(testset) - test])
  return trainset, testset

In [22]:
class EarlyStopping:
  """Patience-based early stopping on a monitored value.

  Args:
    dir: "max" if larger values are better; anything else (default "min")
      means smaller is better, e.g. for a loss.
    patent: patience, i.e. the number of consecutive steps without strict
      improvement after which is_stop() returns True. (Name kept for
      backward compatibility.)
  """
  def __init__(self, dir="min", patent=5):
    self.list = []      # every value passed to step(), in order
    self.best = 0       # best score so far, sign-adjusted so larger is better
    self.patent = patent
    self.count = 0      # consecutive steps without strict improvement
    self.order = dir == "max"

  def step(self, item):
    """Record ``item`` and update the non-improvement counter."""
    score = item if self.order else -item
    self.list.append(item)
    if len(self.list) == 1 or self.best < score:
      # The first value only sets the baseline; it used to be counted as a
      # non-improvement, silently reducing the patience by one.
      self.best = score
      self.count = 0
    else:
      self.count += 1

  def is_stop(self):
    """True once ``patent`` consecutive non-improving steps have occurred."""
    return self.patent <= self.count

In [23]:
def SaveModel(name, model, dir="result"):
  """Save ``model.state_dict()`` to ``<dir>/<name>.pt`` and return that path.

  ``dir`` is created if missing, including parent directories. An existing
  file of the same name is overwritten.
  """
  path = os.path.join(dir, name + ".pt")
  # makedirs(exist_ok=True) handles nested paths and avoids the exists()/mkdir()
  # race; an empty dir means "current directory" and needs no creation.
  if dir:
    os.makedirs(dir, exist_ok=True)

  torch.save(model.state_dict(), path)
  # To download from Colab: files.download(os.path.join("/content", path))
  return path

In [50]:
def main(description, option=""):
  """Run architecture search on a CIFAR-10 subset, then retrain the discretized network.

  Phase 1 (search): a CellNetwork's weights are trained with SGD. From epoch
  ``switch`` on, the architecture weights (thetas) are also updated with Adam
  on the same batches. The phase ends after ``epochs`` epochs or once
  ``minutes`` of wall time have elapsed, and is saved to result/cell_search.pt.
  Phase 2 (retrain): the thetas are discretized with ``sampling()`` into a fresh
  network, which is trained with SGD until EarlyStopping on the test loss fires
  or ``epochs`` is reached. It is saved to result/cell.pt.

  ``description`` and ``option`` are currently unused.
  """
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
  print("device is %s" % device)

  # Datasets are loaded inside learning() (sized by args.train_size); loading
  # them here too only duplicated the work.
  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

  class_array = [i for i in range(10)]

  @dictspace
  def learning(args):
    """One search + retrain run. Returns the test loss of the last retraining epoch."""

    # Phase 1: architecture search.
    model = CellNetwork(depth = args.depth, node_num = args.node)
    model.to(device)

    # TODO: move into a helper. model.thetas is a plain list, which
    # model.to(device) does not move, so the architecture weights are moved here.
    with torch.no_grad():
      for idx, theta in enumerate(model.thetas):
        model.thetas[idx] = theta.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=3e-4)
    optimizer_theta = optim.Adam(model.thetas, lr=args.lr_theta, betas=(0.5, 0.999), weight_decay=1e-3)
    criterion = nn.CrossEntropyLoss()

    trainset, testset = load_dataset(train=args.train_size)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=False, **kwargs)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, **kwargs)

    time_sta = time.time()
    accuracy, loss = 0, 1e10
    # Thetas stay frozen until epoch `args.switch` (weight warm-up).
    model.learn_theta(False)
    for epoch in range(args.epochs):

      if epoch == args.switch:
        model.learn_theta(True)
      # log() prints and returns None, so call it on its own rather than
      # printing its return value.
      print(">>>> epoch %d" % epoch)
      model.log()
      (_, loss_train) = train(model, device, trainloader, optimizer, optimizer_theta, criterion, None, class_array)
      (_, (loss_test, acc)) = test(model, device, testloader, criterion, None, class_array)

      print('epoch %d, acc %s' % (epoch, acc))

      accuracy, loss = acc, loss_test
      if time.time() - time_sta >= 60 * args.minutes:
        break

    print("\naccuracy ", accuracy, end=", ")
    print("loss ", loss)
    # Distinct name: the retrained model below used to overwrite this checkpoint.
    SaveModel("cell_search", model)

    # Phase 2: retrain a fresh network on the discretized architecture.
    stop = EarlyStopping(patent=15)
    accuracy, loss = 0, 1e10
    remodel = CellNetwork(depth = args.depth, node_num = args.node).to(device)
    # Reuse the searched thetas (already on `device`). sampling() makes them
    # one-hot in place and freezes them, so a single call is enough.
    remodel.thetas = model.thetas
    remodel.sampling()
    model = remodel
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=3e-4)
    for epoch in range(args.epochs):

      print(">>>> epoch %d" % epoch)
      model.log()

      (_, loss_train) = train(model, device, trainloader, optimizer, None, criterion, None, class_array)
      (_, (loss_test, acc)) = test(model, device, testloader, criterion, None, class_array)

      print('epoch %d, acc %s' % (epoch, acc))

      accuracy, loss = acc, loss_test
      stop.step(loss_test)
      if stop.is_stop():
        break

    print("\naccuracy ", accuracy, end=", ")
    print("loss ", loss)
    SaveModel("cell", model)

    return loss

  # learning(lr=0.020, lr_theta=3e-4, batch_size=64, train_size=20000, momentum=0.9,
  #          epochs=50, switch=10, minutes=180, depth=6, node=4)
  learning(lr=0.0050, lr_theta=0.0005, batch_size=64, train_size=8000, momentum=0.9,
           epochs=90, switch=30, minutes=180, depth=4, node=7)
  # learning(lr=0.001, lr_theta=0.0005, batch_size=16, train_size=4000, momentum=0.9,
  #          epochs=10, switch=5, minutes=180, depth=4, node=5)

In [51]:
if __name__ == '__main__':
  # Run the full search + retrain pipeline when executed as a script / notebook.
  main(description="", option=None)

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
        [-0.0201,  0.0340, -0.0337,  0.0165, -0.0511,  0.0006],
        [-0.0117,  0.0187, -0.0173,  0.0004, -0.0275,  0.0169],
        [ 0.0062,  0.0222, -0.0208, -0.0105, -0.0360,  0.0231],
        [ 0.0212,  0.0411, -0.0280, -0.0520, -0.0387,  0.0396],
        [-0.0299,  0.0292, -0.0503,  0.0866, -0.0551, -0.0653],
        [-0.0286,  0.0243, -0.0378,  0.0576, -0.0437, -0.0383],
        [-0.0025,  0.0230, -0.0436,  0.0464, -0.0577, -0.0108],
        [ 0.0130,  0.0472, -0.0514, -0.0070, -0.0594,  0.0251],
        [-0.0221,  0.0341, -0.0458,  0.0580, -0.0530, -0.0590],
        [-0.0136,  0.0162, -0.0377,  0.0535, -0.0514, -0.0192],
        [ 0.0207,  0.0521, -0.0549, -0.0117, -0.0625,  0.0160],
        [-0.0192,  0.0317, -0.0339,  0.0367, -0.0443, -0.0402],
        [ 0.0151,  0.0595, -0.0478, -0.0205, -0.0540, -0.0075],
        [ 0.0123,  0.0674, -0.0533, -0.0393, -0.0571, -0.0342]],
       device='cuda:0', requires_grad=True)
>>>>  None
ep