In [1]:
import os
import sys
import time
import glob
import json
import pytz
import numpy as np
import logging
import seaborn as sns
import argparse
import datetime
import networkx as nx
import matplotlib.pyplot as plt

import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torchvision import datasets, models, transforms #画像用データセット諸々
import torch.backends.cudnn as cudnn

In [2]:
SERVER = False

In [3]:
if SERVER:
  from tqdm import tqdm

In [4]:
if not SERVER:
  from google.colab import output
  from tqdm.notebook import tqdm
  import graphviz

In [5]:
if not SERVER:
  from google.colab import drive
  drive.mount('/content/drive')

Mounted at /content/drive


# utils

## other

In [6]:
def argspace(*funcs, **kwds):
  def deco(f):
    def inner(**ikwds):
      for g in funcs:
        kwds.update(g())
      kwds.update(ikwds)
      return f(argparse.Namespace(**kwds))
    return inner
  return deco

In [7]:
def set_seed(seed):
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)

In [8]:
class Singleton(object):
  def __new__(cls, *args, **kargs):
    if not hasattr(cls, "_instance"):
      cls._instance = super(Singleton, cls).__new__(cls)
    return cls._instance

In [9]:
class Experiment(Singleton):
  def __init__(self):
    if not hasattr(self, 'func'):
      self.func = {}
  
  def _store(self, key, f, order):
    self.func.setdefault(key, {})
    self.func[key].update({f.__name__:(order, f)})

  def __call__(self, key):
    def f(*args, **kwds):
      if key not in self.func: return None

      funcs = sorted(self.func[key].values(), key=lambda x: x[0])
      return [g(*args, **kwds) for _, g in funcs]
    return f

  def event(*key, order=0):
    def d(f):
      for k in key:
        Experiment()._store(k, f, order)
      def inner(*args, **kwds):
        return f(*args, **kwds)
      return inner
    return d
  
  def reset():
    Experiment().func = {}

## metrics

In [65]:
def accuracy(output, target, topk=(1,)):
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].reshape(-1).float().sum(0)
    res.append(correct_k.mul_(100.0/batch_size))
  return res

In [11]:
class AvgrageMeter(object):

  def __init__(self):
    self.reset()

  def reset(self):
    self.avg = 0
    self.sum = 0
    self.cnt = 0

  def update(self, val, n=1):
    self.sum += val * n
    self.cnt += n
    self.avg = self.sum / self.cnt

In [12]:
class Metrics():
  def __init__(self, model, dataloader):
    self._run(model, dataloader)
    self.epsilon = 1e-7
  
  @torch.no_grad()
  def _run(self, model, dataloader):
    data, target = next(iter(dataloader))
    device =  next(model.parameters()).device
    print(device)
    num = target.max().long() + 1
    cm = torch.zeros(num, num).to(device)

    for i, (data, target) in enumerate(dataloader):
      data = data.to(device)
      target = target.to(device)
      outputs = model(data)
      _, preds = torch.max(outputs, 1)
      for t, p in zip(target.view(-1), preds.view(-1)):
        cm[t.long(), p.long()] += 1

    self.matrix = cm
    self.dim = num
    self.sum0 = self.matrix.sum(0)
    self.sum1 = self.matrix.sum(1)
    self.sum = self.matrix.sum()

  def confusion_matrix(self):
    return self.matrix

  def TP(self, index):
    return self.matrix[index][index]

  def FN(self, index):
    return self.sum1[index] - self.TP(index)

  def FP(self, index):
    return self.sum0[index] - self.TP(index)

  def TN(self, index):
    return self.sum - self.TP(index) - self.FN(index) - self.FP(index)

  def _sum(self, F):
    return sum(F(i) for i in range(self.dim))
    
  def _micro(self, F, G):
    return self._sum(F) / (self._sum(F) + self._sum(G) + self.epsilon)

  def _macro(self, F, G):
    return sum(F(i) / (F(i) + G(i) + self.epsilon) for i in range(self.dim)) / self.dim

  def _switch(self, F, G, micro):
    return (self._micro(F, G) if micro else self._macro(F, G))

  def accuracy(self, micro=True):
    return (self._sum(self.TP) / self.sum if micro else 
            (sum(self.TP(i) / self.sum1 for i in range(self.dim)) / self.dim).mean())

  def precision(self, micro=True):
    return self._switch(self.TP, self.FP, micro)

  def recall(self, micro=True):
    return self._switch(self.TP, self.FN, micro)

  def specificity(self, micro=True):
    return self._switch(self.TN, self.FP, micro)

  def f_measure(self, micro=True):
    p, r = self.precision(micro), self.recall(micro)
    return 2 * p * r / (p + r + self.epsilon)

  def print(self):
    print(self.confusion_matrix())
    print("accuracy ", self.accuracy(), self.accuracy(micro=False))
    print("precision ", self.precision(), self.precision(micro=False))
    print("recall ", self.recall(), self.recall(micro=False))
    print("specificity ", self.specificity(), self.specificity(micro=False))
    print("f_measure ", self.f_measure(), self.f_measure(micro=False))

In [13]:
def count_param(model : nn.Module):
  param = model.named_parameters()
  itr = [np.prod(v.size()) for name, v in param if "auxiliary" not in name]
  return np.sum(itr)

In [14]:
from typing import Union
from collections import OrderedDict
@torch.no_grad()
def square_error(m : Union[nn.Module, OrderedDict], n : Union[nn.Module, OrderedDict]):
  m = m.parameters() if isinstance(m, nn.Module) else m.values()
  n = n.parameters() if isinstance(n, nn.Module) else n.values()
  s = [torch.sum((x - y) * (x - y)) for x, y in zip(m, n)]
  return sum(s)

## save

In [15]:
def path_with_time(path : str) -> str:
  dt_now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
  return '{}-{}'.format(path, dt_now.strftime('%Y-%m-%d_%H-%M-%S'))

In [16]:
def create_exp_dir(path, scripts_to_save=None):
  if not os.path.exists(path):
    os.mkdir(path)
  print('Experiment dir : {}'.format(path))

  if scripts_to_save is not None:
    os.mkdir(os.path.join(path, 'scripts'))
    for script in scripts_to_save:
      dst_file = os.path.join(path, 'scripts', os.path.basename(script))
      shutil.copyfile(script, dst_file)

In [17]:
def init_logging(save_path):
  class Formatter(logging.Formatter):
      """override logging.Formatter to use an aware datetime object"""
      def converter(self, timestamp):
          dt = datetime.datetime.fromtimestamp(timestamp)
          tzinfo = pytz.timezone('Asia/Tokyo')
          return tzinfo.localize(dt)
          
      def formatTime(self, record, datefmt=None):
          dt = self.converter(record.created)
          if datefmt:
              s = dt.strftime(datefmt)
          else:
              try:
                  s = dt.isoformat(timespec='milliseconds')
              except TypeError:
                  s = dt.isoformat()
          return s

  log_format = '%(asctime)s %(message)s'
  logging.basicConfig(stream=sys.stdout, level=logging.INFO,
      format=log_format, datefmt='%m/%d %I:%M:%S %p')
  fh = logging.FileHandler(os.path.join(save_path, 'log.txt'))
  fh.setFormatter(Formatter(log_format))
  logging.getLogger().addHandler(fh)

In [18]:
def save_dir(dir : str, drivepath = './drive/My Drive/ml'):
  if SERVER: return
  if not dir: return

  import subprocess
  res = subprocess.run(["cp", "-r", "./" + dir, drivepath], stdout=subprocess.PIPE)
  sys.stdout.write(res.stdout)

In [19]:
class Store():
  def __init__(self, dir="result", name="log", fig=[]):
    self.dict = {}
    self.dir = dir
    self.name = name
    self.fig = fig
  
  def add(self, name, value):
    if not name in self.dict:
      self.dict[name] = []
    self.dict[name].append(value)
  
  def apply(self, name, func):
    if not name in self.dict:
      self.dict[name] = []
    return func(self.dict[name])

  def update(self, store : 'Store'):
    self.dict.update(store.dict)

  def save(self):
    self.save_log()
    for metrix, x, y in self.fig:
      self.save_fig(metrix, x, y)

  def save_log(self, name=None):
    name = name if name else self.name
    path = os.path.join(self.dir, name + ".txt")
    with open(path, mode='w') as f:
      f.write("%s" % self.dict)

  def save_fig(self, metrix, xlabel, ylabel, show=True):
    fig = plt.figure()

    if type(metrix) is str:
      times = len(self.dict[metrix])
      plt.plot(np.arange(times), self.dict[metrix])
    else :
      times = len(self.dict[metrix[0]])
      for m in metrix:
        plt.plot(np.arange(times), self.dict[m], label=m)
      metrix = "_".join(metrix)
    
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    if show and not SERVER: plt.show()
    fig.savefig(os.path.join(self.dir, "%s_%d.png" % (metrix, times)))

  def __repr__(self):
    return "store in %s" % self.dict

In [20]:
def render_graph(graph, path):
  if SERVER: return
  format = dict(
      format='png', 
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot' # circo, dot, fdp, neato, osage, sfdp, twopi
  )

  dg = graphviz.Digraph(**format)

  dg.attr('node', fillcolor='dodgerblue4', fontcolor='white', fontsize='15') # coral, 
  for node in graph.nodes():
    attr = graph.nodes[node]
    label = attr['name'] if 'name' in attr else str(node)
    label += '\n(%s, %d, %d)' % (attr['channel'], 32 / attr['stride'], 32 / attr['stride'])
    dg.node(str(node), label=label)

  for (i, j) in graph.edges():
    attr = graph.edges[i, j]
    label = attr['module']
    label = ""
    style = 'bold' if attr['module'] == 'forward' else 'dashed'
    dg.edge(str(i), str(j), label=label, style=style)

  dg.render(path)
  return dg

In [21]:
def save_heatmap(data : torch.tensor, path, format='1.2f'):
  plt.figure()
  data = data.detach().cpu().clone().numpy()
  sns.heatmap(data, annot=True, fmt=format)
  plt.savefig(path)
  plt.close('all')

# dataset

In [22]:
class Cutout(object):
  def __init__(self, length):
    self.length = length

  def __call__(self, img):
    h, w = img.size(1), img.size(2)
    mask = np.ones((h, w), np.float32)
    y = np.random.randint(h)
    x = np.random.randint(w)

    y1 = np.clip(y - self.length // 2, 0, h)
    y2 = np.clip(y + self.length // 2, 0, h)
    x1 = np.clip(x - self.length // 2, 0, w)
    x2 = np.clip(x + self.length // 2, 0, w)

    mask[y1: y2, x1: x2] = 0.
    mask = torch.from_numpy(mask)
    mask = mask.expand_as(img)
    img *= mask
    return img

In [23]:
def load_dataset(train=2000, test=500, valid=0, cutout=False, cutout_length=16):
  #画像の変形処理
  transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
  ])

  if cutout:
    transform.transforms.append(Cutout(cutout_length))

  transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  ])

  dataset = torchvision.datasets.CIFAR10
  kwargs = {"root" : "./data", "download" : True}

  #CIFAR-10のtrain, testsetのロード
  trainset = dataset(train=True, transform=transform, **kwargs)
  testset = dataset(train=False, transform=transform_test, **kwargs)
  
  trainset, validset, _ = torch.utils.data.random_split(trainset, [train, valid, 50000-train-valid])
  testset, _ = torch.utils.data.random_split(testset, [test, 10000-test])
  return argparse.Namespace(train=trainset, test=testset, valid=validset)

In [24]:
def load_dataloader(args):
  kwargs = {'num_workers': 2, 'pin_memory': True} if args.use_cuda else {}
  args.valid_size = args.valid_size if args.valid_size else 0
  sets = load_dataset(train=args.train_size, test=args.test_size, valid=args.valid_size, cutout=args.cutout, cutout_length=args.cutout_length)
  trainloader = torch.utils.data.DataLoader(sets.train, batch_size=args.batch_size, shuffle=True, **kwargs)
  validloader = torch.utils.data.DataLoader(sets.valid, batch_size=args.batch_size, shuffle=True, **kwargs) if args.valid_size else None
  testloader = torch.utils.data.DataLoader(sets.test, batch_size=args.batch_size, shuffle=False, **kwargs)
  return argparse.Namespace(train=trainloader, test=testloader, valid=validloader)

# model

## sampler

In [25]:
class ArchitectureSampler():
  def __call__(self, graph : nx.DiGraph, alpha : torch.Tensor) -> nx.DiGraph:
    return self.graph(graph, alpha)

In [26]:
class MaxSampler(ArchitectureSampler):
  def graph(self, graph, alpha):
    G = nx.DiGraph(graph)
    n = G.number_of_nodes()

    for j in range(1, n):
      edges = [(i, j) for i in G.predecessors(j)]
      alphas = [alpha[i, j].item() for i, j in edges]
      edge_num = round(sum(alphas))
      disable = sorted(zip(edges, alphas), key=lambda x: x[-1])[:-edge_num]
      G.remove_edges_from([i for i, _ in disable])

    return G

In [27]:
class EdgewiseSampler(ArchitectureSampler):
  def graph(self, graph, alpha):
    G = nx.DiGraph(graph)
    n = G.number_of_nodes()

    for j in range(1, n):
      edges = [(i, j) for i in G.predecessors(j)]
      alphas = [alpha[i, j].item() for i, j in edges]
      disable = [(e, a) for e, a in zip(edges, alphas) if round(a) < 1]
      G.remove_edges_from([i for i, _ in disable])

    return G

In [28]:
class ForwardSampler(ArchitectureSampler):
  def graph(self, graph, alpha):
    G = nx.DiGraph(graph)
    n = G.number_of_nodes()

    for j in range(1, n):
      edges = [(i, j) for i in G.predecessors(j) if not i + 1 == j]
      G.remove_edges_from(edges)

    return G

In [29]:
class RandomSampler(ArchitectureSampler):
  def __init__(self, shortcut_num=0):
    self.num = shortcut_num

  def graph(self, graph, alpha):
    G = nx.DiGraph(graph)

    edges = [(e, f) for e, f in G.edges() if not e + 1 == f]
    np.random.shuffle(edges)

    for e, f in edges[:len(edges)-self.num]:
      G.remove_edge(e, f)

    return G

In [30]:
class StrideCutSampler(ArchitectureSampler):
  def __init__(self, stride_max=8):
    assert stride_max >= 1
    self.stride = stride_max

  def graph(self, graph, alpha):
    G = nx.DiGraph(graph)
    n = G.number_of_nodes()

    for j in range(1, n):
      edges = [(i, j) for i in G.predecessors(j) if G.edges[i, j]['stride'] > self.stride]
      G.remove_edges_from(edges)

    return G

In [31]:
# args = {'gene':'VGG19', 'stride_max':2}
# model = load_model(dir="", **args)
# g = model.graph
# sampler = RandomSampler(10)
# h = model.sampled_graph(sampler)
# render_graph(h, 'graph')
# # nx.graph_edit_distance(g, h)

## component

In [32]:
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

In [33]:
class FactorizedReduce(nn.Module):
  def __init__(self, channel_in, channel_out, stride, affine=True):
    super(FactorizedReduce, self).__init__()
    assert channel_out % stride == 0
    
    self.convs = nn.ModuleList([
      nn.Conv2d(channel_in, channel_out // stride, 1, stride=stride, padding=0, bias=False)
      for _ in range(stride)
    ])
    self.bn = nn.BatchNorm2d(channel_out, affine=affine)

  def forward(self, x):
    # strideの偶奇による情報ロスを防ぐ
    out = torch.cat([conv(x[:,:,i:,i:]) for i, conv in enumerate(self.convs)], dim=1)
    out = self.bn(out)
    return out

In [34]:
class Shortcut(nn.Module):
  def __init__(self, in_channel, out_channel, stride):
    super(Shortcut, self).__init__()
    self.f = self._shortcut(in_channel, out_channel, stride)

  def forward(self, x):
    return self.f(x)

  def _shortcut(self, channel_in, channel_out, stride):
    if stride > 1:
      return FactorizedReduce(channel_in, channel_out, stride)
    elif channel_in != channel_out:
      return nn.Conv2d(channel_in, channel_out, 
                       kernel_size=1, stride=stride, padding=0)
    else:
      return lambda x: x

In [35]:
def drop_path(x, drop_prob):
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    b = torch.ones(x.size(0), 1, 1, 1, device=x.device) * keep_prob
    mask = Variable(torch.bernoulli(b))
    x = x / keep_prob * mask
  return x

In [36]:
# x = torch.randn(4, 3, 8, 8).to(torch.device('cuda'))
# keep_prob = 0.5
# torch.bernoulli(torch.ones(x.size(0), 1, 1, 1, device=x.device) * keep_prob)

In [37]:
class Block(nn.Module):
  def __init__(self, graph, index):
    super(Block, self).__init__()
    node = graph.nodes[index]
    edges = [(i, index, graph.edges[i, index]) for i in graph.predecessors(index)]

    self.index = index
    self.indices = [i for i, _, _ in edges]
    self.edges = nn.ModuleList([self._build_module(s) for i, j, s in edges])

    process = [nn.ReLU(inplace=True)]
    if node['pool']: process += [nn.MaxPool2d(kernel_size=2, stride=2)]
    self.post_process = nn.Sequential(*process)

  def _build_module(self, setting):
    module = setting['module']
    in_channel, out_channel = setting['channel']
    stride = setting['stride']
    if module == 'forward':
      return nn.Sequential(
          nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
          nn.BatchNorm2d(out_channel)
          )
    elif module == 'shortcut':
      return Shortcut(in_channel, out_channel, stride)
    else:
      raise ValueError("module name")

  def forward(self, inputs, alpha, drop_prob=0.):
    if self.training and drop_prob > 0.:
      # f = labmda x: drop_path(F(x), drop_prob)
      output = sum(alpha[i] * drop_path(F(inputs[i]), drop_prob) for i, F in zip(self.indices, self.edges))
    else:
      output = sum(alpha[i] * F(inputs[i]) for i, F in zip(self.indices, self.edges))
    return self.post_process(output)

## network

In [78]:
# 前提条件 : make graph, modules(Block, pool, ...)
# 拘束条件 : alpha sofmax
class Network(nn.Module):
  def __init__(self, gene, graph=None, preprocess : ArchitectureSampler=None):
    super(Network, self).__init__()
    self.gene = gene
    self.evaluate = bool(graph)
    self.graph = graph if graph else self._make_graph(gene)
    self.graph = preprocess(self.graph, None) if preprocess else self.graph

    self.blocks = nn.ModuleList(self._make_blocks(self.graph))
    self.pool = nn.AvgPool2d(kernel_size=1, stride=1)
    self.classifier = nn.Linear(512, 10)
    self.drop_path_prob = 0.

  def _make_graph(self, gene, color_channel=3):
    
    def _decode_gene(gene):
      ch, st = [], []
      for g in gene:
        if g == 'M':
          st[-1] *= 2
        else :
          ch += [g]
          st += [1]
      return ch, st

    def __pi(array):
      r = []
      s = 1
      for q in array:
        s *= q
        r += [s]
      return r

    channel, stride = _decode_gene(gene)
    channel = [color_channel] + channel
    stride = [1] + stride
    s_stride = __pi(stride)

    n = len(channel)
    nodes = [(i, {'channel':channel[i], 'stride':s_stride[i], 'pool':stride[i]>1}) for i in range(n)]
    nodes[0][-1].update({'name':'input'})
    edges = [(i, j, {}) for i in range(n) for j in range(n) if i < j and not (i == 0 and j > 1)]
    for (i, j, d) in edges:
      d.update({
        'module' : 'forward' if i + 1 == j else 'shortcut', 
        'channel' : (nodes[i][-1]['channel'], nodes[j][-1]['channel']),
        'stride' : int(nodes[j-1][-1]['stride'] / nodes[i][-1]['stride'])
      })

    G = nx.DiGraph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    return G

  def _make_blocks(self, graph):
    return [Block(graph, i) for i in graph.nodes() if i > 0]

  def init_alpha(self, device):
    def _init_alpha(node_num, device, delta=1e-3):
      noise = delta * torch.randn(node_num, node_num, device=device)
      alpha = noise.clone().detach().requires_grad_(True)
      return [alpha]

    def _mask(node_num, device, name):
      mask = torch.zeros(node_num, node_num, device=device)

      for i, j in self.graph.edges():
        op = self.graph.edges[i, j]['module']
        if not op == name: continue

        mask[i, j] = 1

      return mask.t() > 0

    n = self.graph.number_of_nodes()
    self.alphas = _init_alpha(n, device)
    self.mask_s = _mask(n, device, 'shortcut')
    self.mask_f = _mask(n, device, 'forward')

    # self.alphas += [torch.ones(n, device=device, requires_grad=True)]
    self.alphas += [torch.zeros(n, device=device, requires_grad=True)]

    return self

  def normalized_alpha(self):
    alpha = torch.zeros_like(self.alphas[0])
    if self.evaluate:
      for i, j in self.graph.edges():
        alpha[j, i] = 1.0
    else:
      alpha[self.mask_f] = 1.0
      for a, raw, mask, b in zip(alpha, self.alphas[0], self.mask_s, self.normalized_beta()):
        print("a", a[mask])
        print("b", b * F.softmax(raw[mask], dim=0))
        print("l", len(a[mask]))
        if len(a[mask]) <= 0: continue
        print("???", len(a[mask]))
        a[mask] = b * F.softmax(raw[mask], dim=0)
    return alpha

  def normalized_beta(self):
    # f = lambda x : math.exp(x - 1) if x < 1 else math.log(x) + 1
    # return self.beta().apply_(f)
    x = self.beta()
    m = x>1
    beta = torch.zeros_like(x)
    beta[m] = torch.log(x[m]) + 1
    beta[~m] = torch.exp(x[~m] - 1)
    return beta

  def beta(self):
    return self.alphas[1]

  def sampled_graph(self, sampler : ArchitectureSampler):
    return sampler(self.graph, self.matrix_alpha())

  @torch.no_grad()
  def matrix_alpha(self, normalize=True):
    return (self.normalized_alpha() if normalize else self.alphas[0]).t()


  def forward(self, x):
    assert self.evaluate or self.drop_path_prob <= 0.
    state = [x]
    alpha = self.normalized_alpha()
    # alpha = self.alphas[0]

    for block in self.blocks:
      x = block(state, alpha[block.index], self.drop_path_prob)
      state += [x]

    out = self.pool(x)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

## architecture

In [39]:
class Architect():
  def __init__(self, valid_loader, model, criterion, lr, device):
    self.valid_loader = valid_loader
    self.model = model
    self.criterion = criterion
    self.optimizer = optim.Adam(model.alphas, lr=lr, betas=(0.5, 0.999), weight_decay=1e-3)
    self.device = device
    self.train = True

  def step(self):
    if not self.valid_loader: return
    if not self.train: return
    
    data_v, target_v = next(iter(self.valid_loader))
    data_v, target_v = data_v.to(self.device), target_v.to(self.device)

    self.optimizer.zero_grad()
    output = self.model(data_v)
    loss = self.criterion(output, target_v)
    loss.backward()
    self.optimizer.step()

# learning

## events

In [40]:
@Experiment.event('setup')
def setup(args):
  args.save = path_with_time(args.save)
  create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
  init_logging(args.save)
  logging.info('kwargs %s' % args)
  
  args.start_epoch = 0

  # log init
  fig = [(["train_acc", "test_acc"], "epochs", "accuracy[%]"),
         (["train_loss", "test_loss"], "epochs", "loss")]
  store = Store(dir=args.save, name="store", fig=fig)
  args.store = store

  set_seed(args.seed)

In [41]:
@Experiment.event('setup')
def set_device(args):
  # cuda init
  args.use_cuda = torch.cuda.is_available()
  args.device = torch.device("cuda" if args.use_cuda else "cpu")

In [42]:
@Experiment.event('start')
def set_tdqm(args):
  args.bar = tqdm(total = args.epochs - args.start_epoch)

In [43]:
@Experiment.event('epoch_end')
def update_tdqm(args):
  args.bar.update()

In [44]:
@Experiment.event('epoch_start')
def logging_alpha(args):
  model = args.model
  if not model.evaluate:
    logging.info('raw %s', model.matrix_alpha(normalize=False))
    logging.info('alpha %s', model.matrix_alpha())
    logging.info('beta %s', model.beta())

In [45]:
@Experiment.event('train_end')
def train_end(args, data):
  train_acc, train_obj = data
  logging.info('train_acc %f', train_acc)
  args.store.add("train_loss", train_obj)
  args.store.add("train_acc", train_acc)

In [46]:
@Experiment.event('test_end')
def test_end(args, data):
  valid_acc, valid_obj = data
  logging.info('valid_acc %f', valid_acc)
  args.store.add("test_loss", valid_obj)
  args.store.add("test_acc", valid_acc)

In [47]:
@Experiment.event('checkpoint', 'end', order=1)
def save_checkpoint(args):
  args.store.save()
  save_dir(args.save)

In [48]:
@Experiment.event('checkpoint', 'end')
def save_model(args):
  state = {
    'model': args.model.state_dict(),
    'graph': args.model.graph,
    'alpha': args.model.alphas,
    'store': args.store,
    'epoch': args.epoch,
  }
  torch.save(state, os.path.join(args.save, 'checkpoint.pth'))

In [49]:
@Experiment.event('epoch_end')
def save_graph(args):
  path = os.path.join(args.save, 'graph')
  if not os.path.exists(path):
    os.mkdir(path)
  sampler = MaxSampler()
  graph = args.model.sampled_graph(sampler)
  torch.save(graph, os.path.join(path, 'graph_%d.pth' % args.epoch))
  render_graph(graph, os.path.join(path, 'graph_%d' % args.epoch))

In [50]:
@Experiment.event('epoch_end')
def save_alpha(args):
  path = os.path.join(args.save, 'alpha')
  if not os.path.exists(path):
    os.mkdir(path)
  save_heatmap(args.model.matrix_alpha(), os.path.join(path, 'alpha_%d.png' % args.epoch))

In [51]:
@Experiment.event('epoch_end')
def save_weight(args):
  if not args.save_weight: return

  path = os.path.join(args.save, 'model')
  if not os.path.exists(path):
    os.mkdir(path)
  torch.save(args.model.state_dict(), os.path.join(path, 'model_%d.pth' % args.epoch))

In [52]:
@Experiment.event('end')
def aggregate_data(args):
  store = args.store
  m = store.apply('test_acc', max)
  store.add('test_acc' + '_best', m)
  logging.info('best acc %s' % m)

## learning

In [53]:
@argspace(retain_graph=True)
def train(args):
  objs = AvgrageMeter()
  top1 = AvgrageMeter()
  top5 = AvgrageMeter()
  args.model.train()

  for step, (input, target) in enumerate(args.dataset):
    n = input.size(0)

    input = Variable(input, requires_grad=False).to(args.device)
    target = Variable(target, requires_grad=False).to(args.device)

    args.architect.step()

    args.optimizer.zero_grad()
    logits = args.model(input)
    loss = args.criterion(logits, target)

    loss.backward(retain_graph=args.retain_graph)
    # nn.utils.clip_grad_norm(args.model.parameters(), args.grad_clip)
    args.optimizer.step()

    prec1, prec5 = accuracy(logits, target, topk=(1, 5))
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg

In [54]:
@argspace()
def infer(args):
  objs = AvgrageMeter()
  top1 = AvgrageMeter()
  top5 = AvgrageMeter()
  args.model.eval()

  for step, (input, target) in enumerate(args.dataset):
    input = Variable(input, requires_grad=False).to(args.device)
    target = Variable(target, requires_grad=False).to(args.device)

    logits = args.model(input)
    loss = args.criterion(logits, target)

    prec1, prec5 = accuracy(logits, target, topk=(1, 5))
    n = input.size(0)
    objs.update(loss.item(), n)
    top1.update(prec1.item(), n)
    top5.update(prec5.item(), n)

    if step % args.report_freq == 0:
      logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

  return top1.avg, objs.avg

## setting

In [55]:
def default_setting():
  default_args = argparse.Namespace(
      gene='VGG19', momentum=0.9, report=100, checkpoint=10,
      stride_max=2, batch_size=64, dir="", graph=None, lr_alpha=0.0,
      save_weight=False, seed=41, scheduler=None)
  return vars(default_args)

## experiment

In [56]:
@argspace(default_setting)
def load_model(args):
  
  # init
  set_device(args)
  device = args.device

  # model setup
  gene = cfg[args.gene]
  sampler = StrideCutSampler(args.stride_max) if args.stride_max > 0 else None
  model = Network(gene, graph=args.graph, preprocess=sampler).to(device).init_alpha(device)

  # resume
  if args.dir:
    state = torch.load(os.path.join(args.dir, 'checkpoint.pth'))
    model.load_state_dict(state['model'])
    model.alphas = state['alpha']

  return model

In [57]:
@argspace(default_setting)
def main(args):

  # init
  exp = Experiment()
  exp('setup')(args)
  store = args.store
  device = args.device


  # model setup
  gene = cfg[args.gene]
  logging.info('gene %s', gene)
  sampler = StrideCutSampler(args.stride_max) if args.stride_max > 0 else None
  graph = args.graph if args.graph else None
  model = Network(gene, graph=graph, preprocess=sampler).to(device).init_alpha(device)
  logging.info('model param %s', count_param(model))

  # load cuda
  if device == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

  # resume
  if args.dir:
    state = torch.load(os.path.join(args.dir, 'checkpoint.pth'))
    model.load_state_dict(state['model'])
    model.alphas = state['alpha']
    args.start_epoch = state['epoch']
    store.update(state['store'])
    logging.info('Resuming from epoch %d in %s' % (args.start_epoch, args.dir))

  args.model = model

  # env
  loader = load_dataloader(args)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
  criterion = nn.CrossEntropyLoss()
  architect = Architect(loader.valid, model, criterion, args.lr_alpha, device)
  if args.scheduler:
    scheduler = args.scheduler(optimizer, **args.scheduler_args)

  exp('start')(args)

  for epoch in range(args.start_epoch + 1, args.epochs + 1):
    logging.info('epoch %d', epoch)
    args.epoch = epoch
    
    exp('epoch_start')(args)
          
    if not model.evaluate:
      architect.train = epoch >= args.weight_epoch

    if model.evaluate:
      model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    # training
    train_result = train(dataset=loader.train, model=model, 
                         criterion=criterion, optimizer=optimizer, 
                         architect=architect, 
                         device=device, report_freq=args.report)
    exp('train_end')(args, train_result)

    # validation
    valid_result = infer(dataset=loader.test, model=model, 
                                 criterion=criterion,
                                 device=device, report_freq=args.report)
    exp('test_end')(args, valid_result)

    if args.scheduler:
      scheduler.step()

    exp('epoch_end')(args)

    if epoch % args.checkpoint == 0:
      exp('checkpoint')(args)

  exp('end')(args)

  return model

In [58]:
@argspace(**default_setting())
def evaluate(args):
  model = load_model(dir=args.dir, gene=args.gene, stride_max=args.stride_max)
  
  sampler : ArchitectureSampler = args.sampler
  graph = model.sampled_graph(sampler)

  args.origin_dir = args.dir
  args.dir = ""
  args.graph = graph
  main(**vars(args))

In [59]:
@argspace(**default_setting())
def evaluate_random(args):
  model = load_model(dir=args.dir, gene=args.gene, stride_max=args.stride_max)
  
  sampler : ArchitectureSampler = args.sampler
  graph = model.sampled_graph(sampler)

  n = len([(e, f) for e, f in graph.edges() if not e + 1 == f])
  sampler : ArchitectureSampler = RandomSampler(n)
  graph = model.sampled_graph(sampler)
  
  args.origin_dir = args.dir
  args.dir = ""
  args.graph = graph
  main(**vars(args))

# main

In [60]:
model = load_model()

In [61]:
def shortcut(graph):
  edges = [(e, f) for e, f in graph.edges() if not e + 1 == f]
  return len(edges)

In [62]:
shortcut(model.graph)

61

In [79]:
if __name__ == '__main__':
  for i in [7]:
    main(save="exp_vgg19_search%s" % i, lr=0.01, lr_alpha=0.003, epochs=50, weight_epoch=10,
        train_size=25000, valid_size=25000, test_size=5000, save_weight=False, seed=(i+40)
        , cutout=True, cutout_length=16, drop_path_prob=0.2)

Experiment dir : exp_vgg19_search7-2020-11-05_11-32-45
11/05 02:32:45 AM kwargs Namespace(batch_size=64, checkpoint=10, cutout=True, cutout_length=16, dir='', drop_path_prob=0.2, epochs=50, gene='VGG19', graph=None, lr=0.01, lr_alpha=0.003, momentum=0.9, origin_dir='drive/My Drive/ml/exp_vgg19_search5-2020-10-13_11-30-21', report=100, sampler=<__main__.MaxSampler object at 0x7f40f1fb1630>, save='exp_vgg19_search7-2020-11-05_11-32-45', save_weight=False, scheduler=None, scheduler_args={'T_max': 150, 'eta_min': 0.001}, seed=47, stride_max=2, test_size=5000, train_size=25000, valid_size=25000, weight_epoch=10)
11/05 02:32:45 AM gene [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
11/05 02:32:45 AM model param 26298058
Files already downloaded and verified
Files already downloaded and verified


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

11/05 02:32:47 AM epoch 1
11/05 02:32:47 AM raw tensor([[-1.0459e-03, -1.3514e-03, -7.9233e-04,  5.5662e-04,  1.7803e-03,
          9.2536e-04,  1.0310e-03,  5.6825e-04,  1.2906e-03, -9.4952e-04,
          1.2895e-03, -1.6469e-05,  6.8789e-04,  2.1701e-03, -2.1035e-03,
          1.8069e-03, -3.0924e-04],
        [ 2.0111e-04, -1.2426e-03,  1.7095e-03,  1.5458e-03,  9.5412e-04,
         -7.5734e-04,  6.7336e-04, -5.9593e-04,  1.2117e-03, -7.5062e-04,
         -5.5599e-04, -6.2765e-04, -4.6884e-04, -5.0285e-04,  3.4693e-04,
         -2.3207e-03,  7.5527e-04],
        [ 6.1144e-05,  2.4026e-04,  2.8717e-04, -1.2635e-03,  5.9574e-04,
         -1.0865e-03, -2.8315e-04,  4.5383e-04,  6.6502e-04,  1.2397e-03,
         -6.6251e-06,  8.4269e-04, -1.0770e-03,  9.8608e-04,  6.2118e-04,
          3.5590e-04,  2.1226e-05],
        [-9.3898e-04,  1.0582e-03,  3.4518e-04,  5.3831e-04, -1.1974e-03,
          1.0898e-03,  5.2106e-04,  1.5409e-05, -1.1069e-03,  3.3361e-04,
         -5.0325e-05,  1.3131e

RuntimeError: ignored

In [None]:
if __name__ == '__main__':
  paths = [(5, 'exp_vgg19_search5-2020-10-13_11-30-21')]
  for i, path in paths:
    evaluate(save="exp_vgg19_eval_r%d" % i, lr=0.025, epochs=150,
        scheduler=optim.lr_scheduler.CosineAnnealingLR,
        scheduler_args={'T_max':150, 'eta_min':0.001},
        train_size=50000, valid_size=0, test_size=10000, sampler=MaxSampler(),
        dir=os.path.join("drive/My Drive/ml/", path), cutout=True, cutout_length=16, drop_path_prob=0.2)
    output.clear()

In [None]:
if __name__ == '__main__':
  path = 'exp_vgg19_search5-2020-10-13_11-30-21'
  for i in [2, 3, 4]:
    evaluate(save="exp_vgg19_base%d" % i, lr=0.0090131, epochs=150,
        scheduler=optim.lr_scheduler.StepLR,
        scheduler_args={'gamma':0.23440, 'step_size':100},
        train_size=50000, valid_size=0, test_size=10000, sampler=ForwardSampler(),
        dir=os.path.join("drive/My Drive/ml/", path))
    output.clear()