In [1]:
import os
import argparse
import socket
import time

import tensorboard_logger as tb_logger
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn


from models import model_dict
from models.util import Embed, ConvReg, LinearEmbed
from models.util import Connector, Translator, Paraphraser

from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample

from helper.util import adjust_learning_rate

from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from crd.criterion import CRDLoss
from ntk import NTKLoss

from helper.loops import train_distill as train, validate
from helper.pretrain import init

import numpy as np

# Utilities

In [2]:
# returns the number of parameters in the model or a part of the model
def num_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters]) 

In [10]:
# Load the teacher
resnet50_rp = model_dict['resnet50_rp'](num_classes=100)
resnet50 = model_dict['ResNet50'](num_classes=100)

model_path = "save/models/ResNet50_vanilla/ckpt_epoch_240.pth"

resnet50.load_state_dict(torch.load(model_path)['model'])

<All keys matched successfully>

In [6]:
num_parameters(resnet50_rp)

23705252

In [9]:
num_parameters(resnet50)

23705252

In [11]:
resnet50_rp.load_state_dict(torch.load(model_path)['model'])

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.conv.weight", "conv1.bn.weight", "conv1.bn.bias", "conv1.bn.running_mean", "conv1.bn.running_var", "layer1.0.conv1.conv.weight", "layer1.0.conv1.bn.weight", "layer1.0.conv1.bn.bias", "layer1.0.conv1.bn.running_mean", "layer1.0.conv1.bn.running_var", "layer1.0.conv2.conv.weight", "layer1.0.conv2.bn.weight", "layer1.0.conv2.bn.bias", "layer1.0.conv2.bn.running_mean", "layer1.0.conv2.bn.running_var", "layer1.0.conv3.conv.weight", "layer1.0.conv3.bn.weight", "layer1.0.conv3.bn.bias", "layer1.0.conv3.bn.running_mean", "layer1.0.conv3.bn.running_var", "layer1.0.downsample.conv.weight", "layer1.0.downsample.bn.weight", "layer1.0.downsample.bn.bias", "layer1.0.downsample.bn.running_mean", "layer1.0.downsample.bn.running_var", "layer1.1.conv1.conv.weight", "layer1.1.conv1.bn.weight", "layer1.1.conv1.bn.bias", "layer1.1.conv1.bn.running_mean", "layer1.1.conv1.bn.running_var", "layer1.1.conv2.conv.weight", "layer1.1.conv2.bn.weight", "layer1.1.conv2.bn.bias", "layer1.1.conv2.bn.running_mean", "layer1.1.conv2.bn.running_var", "layer1.1.conv3.conv.weight", "layer1.1.conv3.bn.weight", "layer1.1.conv3.bn.bias", "layer1.1.conv3.bn.running_mean", "layer1.1.conv3.bn.running_var", "layer1.2.conv1.conv.weight", "layer1.2.conv1.bn.weight", "layer1.2.conv1.bn.bias", "layer1.2.conv1.bn.running_mean", "layer1.2.conv1.bn.running_var", "layer1.2.conv2.conv.weight", "layer1.2.conv2.bn.weight", "layer1.2.conv2.bn.bias", "layer1.2.conv2.bn.running_mean", "layer1.2.conv2.bn.running_var", "layer1.2.conv3.conv.weight", "layer1.2.conv3.bn.weight", "layer1.2.conv3.bn.bias", "layer1.2.conv3.bn.running_mean", "layer1.2.conv3.bn.running_var", "layer2.0.conv1.conv.weight", "layer2.0.conv1.bn.weight", "layer2.0.conv1.bn.bias", "layer2.0.conv1.bn.running_mean", "layer2.0.conv1.bn.running_var", "layer2.0.conv2.conv.weight", "layer2.0.conv2.bn.weight", "layer2.0.conv2.bn.bias", "layer2.0.conv2.bn.running_mean", "layer2.0.conv2.bn.running_var", "layer2.0.conv3.conv.weight", "layer2.0.conv3.bn.weight", "layer2.0.conv3.bn.bias", "layer2.0.conv3.bn.running_mean", "layer2.0.conv3.bn.running_var", "layer2.0.downsample.conv.weight", "layer2.0.downsample.bn.weight", "layer2.0.downsample.bn.bias", "layer2.0.downsample.bn.running_mean", "layer2.0.downsample.bn.running_var", "layer2.1.conv1.conv.weight", "layer2.1.conv1.bn.weight", "layer2.1.conv1.bn.bias", "layer2.1.conv1.bn.running_mean", "layer2.1.conv1.bn.running_var", "layer2.1.conv2.conv.weight", "layer2.1.conv2.bn.weight", "layer2.1.conv2.bn.bias", "layer2.1.conv2.bn.running_mean", "layer2.1.conv2.bn.running_var", "layer2.1.conv3.conv.weight", "layer2.1.conv3.bn.weight", "layer2.1.conv3.bn.bias", "layer2.1.conv3.bn.running_mean", "layer2.1.conv3.bn.running_var", "layer2.2.conv1.conv.weight", "layer2.2.conv1.bn.weight", "layer2.2.conv1.bn.bias", "layer2.2.conv1.bn.running_mean", "layer2.2.conv1.bn.running_var", "layer2.2.conv2.conv.weight", "layer2.2.conv2.bn.weight", "layer2.2.conv2.bn.bias", "layer2.2.conv2.bn.running_mean", "layer2.2.conv2.bn.running_var", "layer2.2.conv3.conv.weight", "layer2.2.conv3.bn.weight", "layer2.2.conv3.bn.bias", "layer2.2.conv3.bn.running_mean", "layer2.2.conv3.bn.running_var", "layer2.3.conv1.conv.weight", "layer2.3.conv1.bn.weight", "layer2.3.conv1.bn.bias", "layer2.3.conv1.bn.running_mean", "layer2.3.conv1.bn.running_var", "layer2.3.conv2.conv.weight", "layer2.3.conv2.bn.weight", "layer2.3.conv2.bn.bias", "layer2.3.conv2.bn.running_mean", "layer2.3.conv2.bn.running_var", "layer2.3.conv3.conv.weight", "layer2.3.conv3.bn.weight", "layer2.3.conv3.bn.bias", "layer2.3.conv3.bn.running_mean", "layer2.3.conv3.bn.running_var", "layer3.0.conv1.conv.weight", "layer3.0.conv1.bn.weight", "layer3.0.conv1.bn.bias", "layer3.0.conv1.bn.running_mean", "layer3.0.conv1.bn.running_var", "layer3.0.conv2.conv.weight", "layer3.0.conv2.bn.weight", "layer3.0.conv2.bn.bias", "layer3.0.conv2.bn.running_mean", "layer3.0.conv2.bn.running_var", "layer3.0.conv3.conv.weight", "layer3.0.conv3.bn.weight", "layer3.0.conv3.bn.bias", "layer3.0.conv3.bn.running_mean", "layer3.0.conv3.bn.running_var", "layer3.0.downsample.conv.weight", "layer3.0.downsample.bn.weight", "layer3.0.downsample.bn.bias", "layer3.0.downsample.bn.running_mean", "layer3.0.downsample.bn.running_var", "layer3.1.conv1.conv.weight", "layer3.1.conv1.bn.weight", "layer3.1.conv1.bn.bias", "layer3.1.conv1.bn.running_mean", "layer3.1.conv1.bn.running_var", "layer3.1.conv2.conv.weight", "layer3.1.conv2.bn.weight", "layer3.1.conv2.bn.bias", "layer3.1.conv2.bn.running_mean", "layer3.1.conv2.bn.running_var", "layer3.1.conv3.conv.weight", "layer3.1.conv3.bn.weight", "layer3.1.conv3.bn.bias", "layer3.1.conv3.bn.running_mean", "layer3.1.conv3.bn.running_var", "layer3.2.conv1.conv.weight", "layer3.2.conv1.bn.weight", "layer3.2.conv1.bn.bias", "layer3.2.conv1.bn.running_mean", "layer3.2.conv1.bn.running_var", "layer3.2.conv2.conv.weight", "layer3.2.conv2.bn.weight", "layer3.2.conv2.bn.bias", "layer3.2.conv2.bn.running_mean", "layer3.2.conv2.bn.running_var", "layer3.2.conv3.conv.weight", "layer3.2.conv3.bn.weight", "layer3.2.conv3.bn.bias", "layer3.2.conv3.bn.running_mean", "layer3.2.conv3.bn.running_var", "layer3.3.conv1.conv.weight", "layer3.3.conv1.bn.weight", "layer3.3.conv1.bn.bias", "layer3.3.conv1.bn.running_mean", "layer3.3.conv1.bn.running_var", "layer3.3.conv2.conv.weight", "layer3.3.conv2.bn.weight", "layer3.3.conv2.bn.bias", "layer3.3.conv2.bn.running_mean", "layer3.3.conv2.bn.running_var", "layer3.3.conv3.conv.weight", "layer3.3.conv3.bn.weight", "layer3.3.conv3.bn.bias", "layer3.3.conv3.bn.running_mean", "layer3.3.conv3.bn.running_var", "layer3.4.conv1.conv.weight", "layer3.4.conv1.bn.weight", "layer3.4.conv1.bn.bias", "layer3.4.conv1.bn.running_mean", "layer3.4.conv1.bn.running_var", "layer3.4.conv2.conv.weight", "layer3.4.conv2.bn.weight", "layer3.4.conv2.bn.bias", "layer3.4.conv2.bn.running_mean", "layer3.4.conv2.bn.running_var", "layer3.4.conv3.conv.weight", "layer3.4.conv3.bn.weight", "layer3.4.conv3.bn.bias", "layer3.4.conv3.bn.running_mean", "layer3.4.conv3.bn.running_var", "layer3.5.conv1.conv.weight", "layer3.5.conv1.bn.weight", "layer3.5.conv1.bn.bias", "layer3.5.conv1.bn.running_mean", "layer3.5.conv1.bn.running_var", "layer3.5.conv2.conv.weight", "layer3.5.conv2.bn.weight", "layer3.5.conv2.bn.bias", "layer3.5.conv2.bn.running_mean", "layer3.5.conv2.bn.running_var", "layer3.5.conv3.conv.weight", "layer3.5.conv3.bn.weight", "layer3.5.conv3.bn.bias", "layer3.5.conv3.bn.running_mean", "layer3.5.conv3.bn.running_var", "layer4.0.conv1.conv.weight", "layer4.0.conv1.bn.weight", "layer4.0.conv1.bn.bias", "layer4.0.conv1.bn.running_mean", "layer4.0.conv1.bn.running_var", "layer4.0.conv2.conv.weight", "layer4.0.conv2.bn.weight", "layer4.0.conv2.bn.bias", "layer4.0.conv2.bn.running_mean", "layer4.0.conv2.bn.running_var", "layer4.0.conv3.conv.weight", "layer4.0.conv3.bn.weight", "layer4.0.conv3.bn.bias", "layer4.0.conv3.bn.running_mean", "layer4.0.conv3.bn.running_var", "layer4.0.downsample.conv.weight", "layer4.0.downsample.bn.weight", "layer4.0.downsample.bn.bias", "layer4.0.downsample.bn.running_mean", "layer4.0.downsample.bn.running_var", "layer4.1.conv1.conv.weight", "layer4.1.conv1.bn.weight", "layer4.1.conv1.bn.bias", "layer4.1.conv1.bn.running_mean", "layer4.1.conv1.bn.running_var", "layer4.1.conv2.conv.weight", "layer4.1.conv2.bn.weight", "layer4.1.conv2.bn.bias", "layer4.1.conv2.bn.running_mean", "layer4.1.conv2.bn.running_var", "layer4.1.conv3.conv.weight", "layer4.1.conv3.bn.weight", "layer4.1.conv3.bn.bias", "layer4.1.conv3.bn.running_mean", "layer4.1.conv3.bn.running_var", "layer4.2.conv1.conv.weight", "layer4.2.conv1.bn.weight", "layer4.2.conv1.bn.bias", "layer4.2.conv1.bn.running_mean", "layer4.2.conv1.bn.running_var", "layer4.2.conv2.conv.weight", "layer4.2.conv2.bn.weight", "layer4.2.conv2.bn.bias", "layer4.2.conv2.bn.running_mean", "layer4.2.conv2.bn.running_var", "layer4.2.conv3.conv.weight", "layer4.2.conv3.bn.weight", "layer4.2.conv3.bn.bias", "layer4.2.conv3.bn.running_mean", "layer4.2.conv3.bn.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "linear.weight", "linear.bias", "conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.shortcut.0.weight", "layer1.0.shortcut.1.weight", "layer1.0.shortcut.1.bias", "layer1.0.shortcut.1.running_mean", "layer1.0.shortcut.1.running_var", "layer1.0.conv1.weight", "layer1.0.conv2.weight", "layer1.0.conv3.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.conv1.weight", "layer1.1.conv2.weight", "layer1.1.conv3.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.conv2.weight", "layer1.2.conv3.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.shortcut.0.weight", "layer2.0.shortcut.1.weight", "layer2.0.shortcut.1.bias", "layer2.0.shortcut.1.running_mean", "layer2.0.shortcut.1.running_var", "layer2.0.conv1.weight", "layer2.0.conv2.weight", "layer2.0.conv3.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.conv1.weight", "layer2.1.conv2.weight", "layer2.1.conv3.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.conv2.weight", "layer2.2.conv3.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.conv2.weight", "layer2.3.conv3.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.shortcut.0.weight", "layer3.0.shortcut.1.weight", "layer3.0.shortcut.1.bias", "layer3.0.shortcut.1.running_mean", "layer3.0.shortcut.1.running_var", "layer3.0.conv1.weight", "layer3.0.conv2.weight", "layer3.0.conv3.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.conv1.weight", "layer3.1.conv2.weight", "layer3.1.conv3.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.conv2.weight", "layer3.2.conv3.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.conv2.weight", "layer3.3.conv3.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.conv2.weight", "layer3.4.conv3.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.conv2.weight", "layer3.5.conv3.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.shortcut.0.weight", "layer4.0.shortcut.1.weight", "layer4.0.shortcut.1.bias", "layer4.0.shortcut.1.running_mean", "layer4.0.shortcut.1.running_var", "layer4.0.conv1.weight", "layer4.0.conv2.weight", "layer4.0.conv3.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.conv1.weight", "layer4.1.conv2.weight", "layer4.1.conv3.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.conv2.weight", "layer4.2.conv3.weight". 

In [13]:
resnet50_rp

ResNet(
  (conv1): ConvBlock(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=False, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=False, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv3): ConvBlock(
        (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=False, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
      (downsample): ConvBlock(
        (conv

In [14]:
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (