In [1]:
from functools import partial
from collections import OrderedDict

In [2]:
%config InlineBackend.figure_format = 'retina'

import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision as tv

In [28]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch import optim

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
import requests
import io

In [6]:
def get_weights(bit_variant, timeout=60):
  """Download a pre-trained BiT checkpoint and open it as a NumPy archive.

  Args:
    bit_variant: Model name inserted into the download URL, e.g. 'BiT-M-R50x1'.
    timeout: Seconds `requests` waits for the server before giving up.
      Pass `None` to wait indefinitely (the previous behaviour).

  Returns:
    Whatever `np.load` returns for the downloaded `.npz` payload.

  Raises:
    requests.HTTPError: if the server answers with an error status code.
    requests.Timeout: if the server does not respond within `timeout`.
  """
  url = f'https://storage.googleapis.com/bit_models/{bit_variant}.npz'
  # Without an explicit timeout a stalled connection would block the kernel forever.
  response = requests.get(url, timeout=timeout)
  response.raise_for_status()
  # The archive is kept in memory; nothing is written to disk.
  return np.load(io.BytesIO(response.content))

In [7]:
class StdConv2d(nn.Conv2d):
  """`nn.Conv2d` variant that standardizes its kernel on every forward pass.

  Each output filter is shifted to zero mean and scaled to unit (population)
  variance before the convolution is applied; the stored `weight` parameter
  itself is left untouched.
  """

  def forward(self, x):
    kernel = self.weight
    # Statistics are taken per output filter, i.e. over dims 1..3 of the kernel.
    variance, mean = torch.var_mean(kernel, dim=[1, 2, 3], keepdim=True, unbiased=False)
    # The small constant keeps the division finite for constant filters.
    standardized = (kernel - mean) / torch.sqrt(variance + 1e-10)
    return F.conv2d(
        x, standardized, self.bias,
        self.stride, self.padding, self.dilation, self.groups,
    )

In [8]:
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
  """Return a 3x3 `StdConv2d` from `cin` to `cout` channels with padding 1."""
  return StdConv2d(
      cin, cout,
      kernel_size=3,
      padding=1,
      stride=stride,
      groups=groups,
      bias=bias,
  )

def conv1x1(cin, cout, stride=1, bias=False):
  """Return a 1x1 `StdConv2d` from `cin` to `cout` channels with no padding."""
  return StdConv2d(
      cin, cout,
      kernel_size=1,
      padding=0,
      stride=stride,
      bias=bias,
  )

In [9]:
def tf2th(conv_weights):
  """Wrap a NumPy array as a torch tensor, permuting 4-D arrays from HWIO to OIHW.

  Arrays of any other rank are passed to `torch.from_numpy` unchanged.
  """
  is_conv_kernel = conv_weights.ndim == 4
  array = conv_weights.transpose(3, 2, 0, 1) if is_conv_kernel else conv_weights
  return torch.from_numpy(array)

In [10]:
class PreActBottleneck(nn.Module):
  """Pre-activation bottleneck residual unit (GroupNorm -> ReLU -> conv, three times).

  Modelled on "Identity Mappings in Deep Residual Networks":
  https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

  The one deviation is that the stride is applied by the 3x3 convolution
  (the original ResNetV2 applies it in the first 1x1 convolution).
  """

  def __init__(self, cin, cout=None, cmid=None, stride=1):
    super().__init__()
    # Falsy values fall back to defaults: same width out, quarter width in the middle.
    if not cout:
      cout = cin
    if not cmid:
      cmid = cout // 4

    self.gn1 = nn.GroupNorm(32, cin)
    self.conv1 = conv1x1(cin, cmid)
    self.gn2 = nn.GroupNorm(32, cmid)
    self.conv2 = conv3x3(cmid, cmid, stride)
    self.gn3 = nn.GroupNorm(32, cmid)
    self.conv3 = conv1x1(cmid, cout)
    self.relu = nn.ReLU(inplace=True)

    # A projection shortcut exists only when the identity would not match
    # the main branch in resolution or channel count.
    needs_projection = not (stride == 1 and cin == cout)
    if needs_projection:
      self.downsample = conv1x1(cin, cout, stride)

  def forward(self, x):
    # Shared pre-activation feeding both branches.
    preact = self.relu(self.gn1(x))

    # Shortcut: projected pre-activation when a projection exists, raw input otherwise.
    if hasattr(self, 'downsample'):
      shortcut = self.downsample(preact)
    else:
      shortcut = x

    # Main branch; the first conv consumes the pre-activation computed above.
    hidden = self.conv1(preact)
    hidden = self.conv2(self.relu(self.gn2(hidden)))
    hidden = self.conv3(self.relu(self.gn3(hidden)))

    return hidden + shortcut

  def load_from(self, weights, prefix=''):
    """Copy this unit's parameters from `weights`, looking keys up under `prefix`.

    Returns `self` so calls can be chained.
    """
    stages = (
        ('a', self.conv1, self.gn1),
        ('b', self.conv2, self.gn2),
        ('c', self.conv3, self.gn3),
    )
    with torch.no_grad():
      for tag, conv, _ in stages:
        conv.weight.copy_(tf2th(weights[prefix + tag + '/standardized_conv2d/kernel']))
      for tag, _, norm in stages:
        norm.weight.copy_(tf2th(weights[prefix + tag + '/group_norm/gamma']))
      for tag, _, norm in stages:
        norm.bias.copy_(tf2th(weights[prefix + tag + '/group_norm/beta']))
      if hasattr(self, 'downsample'):
        projection_key = prefix + 'a/proj/standardized_conv2d/kernel'
        self.downsample.weight.copy_(tf2th(weights[projection_key]))
    return self

In [11]:
class ResNetV2(nn.Module):
  """Pre-activation ResNet built from `PreActBottleneck` units.

  `forward` returns the output of `body` only; the `head` sub-module is
  constructed and can be loaded from a checkpoint, but it is not applied.
  """

  # Number of units in each of the four blocks, keyed by architecture name.
  BLOCK_UNITS = {
      'r50': [3, 4, 6, 3],
      'r101': [3, 4, 23, 3],
      'r152': [3, 8, 36, 3],
  }

  def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
    super().__init__()
    wf = width_factor  # every channel count below is scaled by this.

    # Explicit zero padding followed by an unpadded max-pool; this is subtly
    # different from nn.MaxPool2d(kernel_size=3, stride=2, padding=1).
    stem = OrderedDict()
    stem['conv'] = StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)
    stem['padp'] = nn.ConstantPad2d(1, 0)
    stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    self.root = nn.Sequential(stem)

    # (cin of first unit, cout, cmid, stride of first unit) for block1..block4.
    block_specs = [
        (64*wf, 256*wf, 64*wf, 1),
        (256*wf, 512*wf, 128*wf, 2),
        (512*wf, 1024*wf, 256*wf, 2),
        (1024*wf, 2048*wf, 512*wf, 2),
    ]
    blocks = OrderedDict()
    for idx, (cin, cout, cmid, first_stride) in enumerate(block_specs):
      units = OrderedDict()
      # Only the first unit of a block changes width and resolution.
      units['unit01'] = PreActBottleneck(cin=cin, cout=cout, cmid=cmid, stride=first_stride)
      for unit_no in range(2, block_units[idx] + 1):
        units[f'unit{unit_no:02d}'] = PreActBottleneck(cin=cout, cout=cout, cmid=cmid)
      blocks[f'block{idx + 1}'] = nn.Sequential(units)
    self.body = nn.Sequential(blocks)

    self.zero_head = zero_head
    head_layers = OrderedDict()
    head_layers['gn'] = nn.GroupNorm(32, 2048*wf)
    head_layers['relu'] = nn.ReLU(inplace=True)
    head_layers['avg'] = nn.AdaptiveAvgPool2d(output_size=1)
    head_layers['conv'] = nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)
    self.head = nn.Sequential(head_layers)

  def forward(self, x):
    # Only root and body are run; the result still has its spatial dimensions.
    stem_out = self.root(x)
    return self.body(stem_out)

  def load_from(self, weights, prefix='resnet/'):
    """Copy parameters from `weights` (keys looked up under `prefix`); returns `self`.

    When `zero_head` was set, the head convolution is zero-initialized
    instead of being read from `weights`.
    """
    def fetch(name):
      return tf2th(weights[f'{prefix}{name}'])

    with torch.no_grad():
      self.root.conv.weight.copy_(fetch('root_block/standardized_conv2d/kernel'))

      head_norm = self.head.gn
      head_norm.weight.copy_(fetch('group_norm/gamma'))
      head_norm.bias.copy_(fetch('group_norm/beta'))

      classifier = self.head.conv
      if not self.zero_head:
        classifier.weight.copy_(fetch('head/conv2d/kernel'))
        classifier.bias.copy_(fetch('head/conv2d/bias'))
      else:
        nn.init.zeros_(classifier.weight)
        nn.init.zeros_(classifier.bias)

      # Each unit loads itself from '<prefix><block name>/<unit name>/'.
      for block_name, block in self.body.named_children():
        for unit_name, unit in block.named_children():
          unit.load_from(weights, prefix=f'{prefix}{block_name}/{unit_name}/')
    return self

In [12]:
from IPython.display import HTML, display

def progress(value, max=100):
    """Build an IPython `HTML` object containing a full-width `<progress>` bar.

    Args:
        value: Current progress; also used as the element's fallback text.
        max: Value that corresponds to a completed bar.

    Returns:
        An `HTML` display object; pass it to `display(...)` or to a display
        handle's `update(...)`.
    """
    # HTML attributes are separated by whitespace only, so the comma that
    # used to follow the `max` attribute has been removed.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

In [13]:
def stairs(s, v, *svs):
    """Piecewise-constant ("staircase") schedule, e.g. for learning rates.

    `svs` alternates step thresholds and values. For example
    stairs(s, 0.1, 10, 0.01, 20, 0.001)
    yields 0.1 while s<10, 0.01 while 10<=s<20, and 0.001 once 20<=s.
    A trailing threshold without a value is ignored.
    """
    thresholds = svs[0::2]
    values = svs[1::2]
    current = v
    for idx in range(min(len(thresholds), len(values))):
        # Stop at the first threshold that has not been reached yet.
        if s < thresholds[idx]:
            return current
        current = values[idx]
    return current

def rampup(s, peak_s, peak_lr):
  """Linear warm-up: scale `peak_lr` by `s/peak_s` until step `peak_s`, then hold it."""
  if s >= peak_s:
    return peak_lr
  # Still warming up.
  return s/peak_s * peak_lr

def schedule(s):
  """Learning rate at step `s`: 100-step linear warm-up onto a staircase decay.

  The staircase starts at 3e-3 and drops tenfold at steps 200, 300 and 400.
  From step 500 on the value is `None`.
  """
  plateau_lr = stairs(s, 3e-3, 200, 3e-4, 300, 3e-5, 400, 3e-6, 500, None)
  warmed_up = rampup(s, 100, plateau_lr)
  return warmed_up

In [14]:
import PIL

In [15]:
# Training pipeline: upscale, then augment with a random crop and a horizontal flip.
train_steps = [
    # BILINEAR is the default; it is spelled out for the reader.
    tv.transforms.Resize((160, 160), interpolation=PIL.Image.BILINEAR),
    tv.transforms.RandomCrop((128, 128)),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    # Mean 0.5 / std 0.5 per channel brings the data into [-1, 1].
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
preprocess_train = tv.transforms.Compose(train_steps)

# Evaluation pipeline: deterministic resize straight to the crop size, same normalization.
eval_steps = [
    tv.transforms.Resize((128, 128), interpolation=PIL.Image.BILINEAR),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
preprocess_eval = tv.transforms.Compose(eval_steps)

# CIFAR-100 train and test splits, downloaded into ./data when missing.
cifar_location = dict(root='./data', download=True)
trainset = tv.datasets.CIFAR100(train=True, transform=preprocess_train, **cifar_location)
testset = tv.datasets.CIFAR100(train=False, transform=preprocess_eval, **cifar_location)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [16]:
# Downloads the 'BiT-M-R50x1' archive via get_weights.
# NOTE(review): the variable name suggests CIFAR-10 weights, but the notebook
# works with CIFAR-100 and loads these weights into a 21843-way head — the
# name looks like a leftover; confirm before renaming (it is used by the next cell).
weights_cifar10 = get_weights('BiT-M-R50x1')

In [17]:
# Pre-trained backbone. Its forward pass returns body features, which the
# ClassificationHead defined further down turns into class logits.
res = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=21843)  # NOTE: No new head.
res.load_from(weights_cifar10)
# The backbone's parameters are never handed to an optimizer, so freeze them:
# otherwise every backward pass computes (and stores) gradients nobody uses.
res.requires_grad_(False)
res.eval()
res.to(device);

In [24]:
class ClassificationHead(nn.Module):
    """Classifier applied to backbone feature maps with `2048*wf` channels.

    Pipeline: GroupNorm -> ReLU -> global average pool -> dropout -> 1x1 conv.
    `forward` returns raw logits (no softmax is applied) with the two
    trailing singleton spatial dimensions indexed away.
    """

    def __init__(self, wf, classes, dropout=0.1):
        super().__init__()
        channels = 2048*wf
        self.wf = wf
        self.classes = classes
        self.gn = nn.GroupNorm(32, channels)
        self.relu = nn.ReLU(inplace=True)
        self.avg = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Conv2d(channels, classes, kernel_size=1, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        activated = self.relu(self.gn(inp))
        pooled = self.avg(activated)        # spatial size becomes 1x1
        logits = self.conv(self.dropout(pooled))
        # Drop the two trailing 1x1 spatial dimensions.
        return logits[..., 0, 0]

In [19]:
# Optimisation and data-loading settings for the head-training cell.
lr = 0.001
weight_decay = 0.0001
dropout = 0.3
batch_size = 64
num_workers = 2
shuffle = True
# NOTE(review): patch_size, max_len, embed_dim, layers and heads are not
# referenced by any other cell visible in this notebook; they look like
# leftovers from a patch/token-based model — confirm before removing.
patch_size = 4
max_len = ((32//patch_size) * (32//patch_size)) + 1 # +1 for the class token
embed_dim = 128
# Number of output classes handed to ClassificationHead.
classes = 100
layers = 6
heads = 8
# Number of passes over the training set.
epochs = 30

In [20]:
# Colab-only: mount Google Drive at /content/gdrive so the checkpoint path
# defined in the next cell points into Drive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [21]:
# Checkpoint file name and its full location inside the mounted Drive folder.
model_save_name = 'resnet.pt'
path = "/content/gdrive/My Drive/{}".format(model_save_name)

In [29]:
# Train only the new classification head on features produced by the `res` backbone.
model = ClassificationHead(1, classes, dropout).to(device)
# Honour the `shuffle` setting from the hyperparameter cell instead of hard-coding it.
dataloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# OneCycleLR is stepped once per batch, hence steps_per_epoch=len(dataloader).
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(dataloader), epochs=epochs)

model.train()
best_acc = 0.0
curr_acc = 0.0
num_batches = len(dataloader)

for epoch in range(epochs):

    running_loss = 0.0
    running_accuracy = 0.0

    for data, target in tqdm(dataloader):
        data = data.to(device)
        target = target.to(device)

        # The optimizer only updates the head, so no autograd graph is needed
        # for the backbone forward pass. Skipping it saves memory and time.
        with torch.no_grad():
            res_output = res(data)

        optimizer.zero_grad()
        output = model(res_output)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()
        scheduler.step()

        # .item() keeps the running statistics as plain Python floats, so the
        # checkpoint below does not embed device tensors.
        acc = (output.argmax(dim=1) == target).float().mean().item()
        running_accuracy += acc / num_batches
        running_loss += loss.item() / num_batches

    # NOTE(review): the "best" checkpoint is chosen by *training* accuracy,
    # measured with dropout active; consider a held-out split instead.
    curr_acc = running_accuracy
    if best_acc < curr_acc:
        torch.save({
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict(),
            'train_loss': running_loss,
            'train_acc': running_accuracy
        }, path)
        best_acc = curr_acc

    print(f"Epoch : {epoch+1} - loss : {running_loss:.4f} - acc: {running_accuracy:.4f}\n")

100%|██████████| 782/782 [03:29<00:00,  3.73it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 1 - loss : 3.4445 - acc: 0.3658



100%|██████████| 782/782 [03:36<00:00,  3.61it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 2 - loss : 1.6810 - acc: 0.6559



100%|██████████| 782/782 [03:38<00:00,  3.58it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 3 - loss : 1.0910 - acc: 0.7109



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 4 - loss : 0.9188 - acc: 0.7355



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 5 - loss : 0.8583 - acc: 0.7482



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 6 - loss : 0.8212 - acc: 0.7545



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 7 - loss : 0.8055 - acc: 0.7592



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 8 - loss : 0.7864 - acc: 0.7646



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 9 - loss : 0.7648 - acc: 0.7681



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 10 - loss : 0.7415 - acc: 0.7752



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 11 - loss : 0.7138 - acc: 0.7837



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 12 - loss : 0.6975 - acc: 0.7886



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 13 - loss : 0.6751 - acc: 0.7935



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 14 - loss : 0.6575 - acc: 0.7972



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 15 - loss : 0.6364 - acc: 0.8022



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 16 - loss : 0.6259 - acc: 0.8070



100%|██████████| 782/782 [03:38<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 17 - loss : 0.6079 - acc: 0.8114



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 18 - loss : 0.5840 - acc: 0.8175



100%|██████████| 782/782 [03:38<00:00,  3.58it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 19 - loss : 0.5819 - acc: 0.8188



100%|██████████| 782/782 [03:38<00:00,  3.58it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 20 - loss : 0.5628 - acc: 0.8229



100%|██████████| 782/782 [03:38<00:00,  3.58it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 21 - loss : 0.5514 - acc: 0.8262



100%|██████████| 782/782 [03:38<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 22 - loss : 0.5392 - acc: 0.8301



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 23 - loss : 0.5314 - acc: 0.8318



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 24 - loss : 0.5182 - acc: 0.8361



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 25 - loss : 0.5121 - acc: 0.8383



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 26 - loss : 0.5081 - acc: 0.8393



100%|██████████| 782/782 [03:38<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 27 - loss : 0.5065 - acc: 0.8394



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 28 - loss : 0.4963 - acc: 0.8444



100%|██████████| 782/782 [03:37<00:00,  3.59it/s]
  0%|          | 0/782 [00:00<?, ?it/s]

Epoch : 29 - loss : 0.4997 - acc: 0.8419



100%|██████████| 782/782 [03:38<00:00,  3.58it/s]

Epoch : 30 - loss : 0.4943 - acc: 0.8446






In [30]:
def load_checkpoint(filepath):
    """Rebuild the classification head from a checkpoint saved by the training cell.

    Args:
        filepath: Path of a file written with `torch.save` whose dict has a
            'model' entry holding the head's state dict.

    Returns:
        A `ClassificationHead(1, 100, dropout=0.3)` on `device`, with the
        saved weights loaded, all parameters frozen and eval mode enabled.
    """
    # map_location moves tensors saved from another device (e.g. a GPU that
    # is not available in this session) onto the current `device`.
    checkpoint = torch.load(filepath, map_location=device)
    model = ClassificationHead(1, 100, dropout=0.3).to(device)
    model.load_state_dict(checkpoint['model'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()

    return model

In [None]:
# Rebinds `model` to the head restored from the checkpoint at `path`
# (frozen and in eval mode, per load_checkpoint).
model = load_checkpoint(path)

In [31]:
def eval_cifar10(model, bs=100, progressbar=True):
  """Return the accuracy of `res` + `model` on the global `testset`.

  Args:
    model: Head applied to the features produced by the global `res` backbone.
      It is switched to eval mode as a side effect.
    bs: Batch size of the evaluation loader.
    progressbar: `True` to create and update a notebook progress bar,
      `False`/`None` to run silently, or an existing display handle (any
      object with an `update` method) to reuse.

  Returns:
    Fraction of test samples whose arg-max prediction equals the label.

  NOTE(review): despite the function name, `testset` is built from
  CIFAR-100 in this notebook.
  """
  loader_test = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=2)

  model.eval()

  if progressbar is True:
    progressbar = display(progress(0, len(loader_test)), display_id=True)
  elif progressbar is False:
    # Previously `False` reached `progressbar.update(...)` and raised AttributeError.
    progressbar = None

  preds = []
  with torch.no_grad():
    for i, (x, t) in enumerate(loader_test):
      x, t = x.to(device), t.numpy()
      features = res(x)
      logits = model(features)
      y = logits.argmax(dim=1)
      # One boolean per sample: was the prediction correct?
      preds.extend(y.cpu().numpy() == t)
      if progressbar is not None:
        progressbar.update(progress(i+1, len(loader_test)))

  return np.mean(preds)

In [32]:
# NOTE(review): the hard-coded "Expected" figure cannot be derived from anything
# in this notebook; the evaluation below runs on the CIFAR-100 test split, so the
# figure may refer to a different dataset or training setup — confirm or update.
print("Expected: 97.61%")
print(f"Accuracy: {eval_cifar10(model):.2%}")

Expected: 97.61%


Accuracy: 79.60%
