In [None]:
import os
import argparse
import socket
import time

import tensorboard_logger as tb_logger
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn


from models import model_dict
from models.util import Embed, ConvReg, LinearEmbed
from models.util import Connector, Translator, Paraphraser

from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample

from helper.util import adjust_learning_rate

from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from crd.criterion import CRDLoss
from ntk import NTKLoss

from helper.loops import train_distill as train, validate
from helper.pretrain import init

import numpy as np

# Utilities

In [None]:
def num_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    sub-modules are excluded.

    Args:
        model: a ``torch.nn.Module`` (a whole model or any sub-module).

    Returns:
        int: total element count over all trainable parameter tensors
        (0 when there are none).
    """
    # numel() is an exact Python int; np.prod(p.size()) returns the float 1.0
    # for 0-dim parameters, which silently turned the whole total into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Load the teacher
resnet50_rp = model_dict['resnet50_rp'](num_classes=100)
resnet50 = model_dict['ResNet50'](num_classes=100)

model_path = "save/models/ResNet50_vanilla/ckpt_epoch_240.pth"

# map_location='cpu' lets a checkpoint that was saved with CUDA tensors load on a
# CPU-only kernel; load_state_dict then copies the values into resnet50's parameters.
# NOTE: torch.load unpickles the file -- only point model_path at trusted checkpoints.
checkpoint = torch.load(model_path, map_location='cpu')
resnet50.load_state_dict(checkpoint['model'])

In [None]:
# Trainable-parameter count of the grad-projection variant (resnet50_rp).
num_parameters(model=resnet50_rp)

In [None]:
# Trainable-parameter count of the loaded teacher (ResNet50).
num_parameters(model=resnet50)

# RP Tests

In [None]:
import torch
import torch.autograd as autograd
from models.modules import *

In [None]:
##### TEST 1: Conv2d, ReLU, AvgPool2d and Linear #####
# Checks that the Jacobian-vector product (JVP) propagated forward by the
# grad_proj=True layers equals a brute-force JVP: sum_k <dy/dtheta_k, rv_k>, built
# from reverse-mode parameter gradients and each layer's weight_rv / bias_rv.
torch.manual_seed(0)  # make the run reproducible
x = torch.randn(2, 3, 8, 8)
conv1 = Conv2d(3, 8, kernel_size=3, stride=1, padding=1, grad_proj=True)
conv2 = Conv2d(8, 8, kernel_size=3, stride=2, padding=1, grad_proj=True)
relu = ReLU(inplace=True, grad_proj=True)
pool = AvgPool2d(kernel_size=4, grad_proj=True)
fc = Linear(8, 2, grad_proj=True)

# forward pass: each grad_proj layer takes and returns an (output, jvp) pair
y, jvp = pool(*relu(*conv2(*relu(*conv1(x)))))
y, jvp = fc(torch.flatten(y, 1), torch.flatten(jvp, 1))

# backward pass: gradient of one output entry w.r.t. every parameter
layers = [conv1, conv2, fc]
params = [p for layer in layers for p in layer.parameters()]
grad = autograd.grad(y[0, 0], params)

# random vectors, listed in the same (weight, bias) order as layer.parameters()
rvs = [rv for layer in layers for rv in (layer.weight_rv, layer.bias_rv)]
assert len(grad) == len(rvs), 'expected exactly one random vector per parameter tensor'

# brute-force calculation of Jacobian-vector product
true_jvp = sum((g * rv).sum() for g, rv in zip(grad, rvs))

if abs(jvp[0, 0] - true_jvp) < 1e-5:
    print('TEST 1 PASSED')
else:
    print('TEST 1 FAILED')
    print('ground truth: %.3f \t result: %.3f' % (true_jvp, jvp[0, 0]))
    # Raise rather than exit(): in Jupyter exit() shuts the kernel down, and in a
    # plain script it would exit with status 0, reporting success on failure.
    raise AssertionError('TEST 1 FAILED')

In [None]:
##### TEST 3: Conv2d, BatchNorm2d (train), ReLU, AvgPool2d and Linear #####
# Same check as the Conv2d/ReLU/Linear test, now with BatchNorm2d layers in
# training mode: the JVP carried forward by the grad_proj layers must equal
# sum_k <dy/dtheta_k, rv_k> computed from reverse-mode parameter gradients.
torch.manual_seed(0)  # make the run reproducible
x = torch.randn(16, 3, 8, 8)
conv1 = Conv2d(3, 8, kernel_size=3, stride=1, padding=1, grad_proj=True)
bn1 = BatchNorm2d(8, grad_proj=True)
conv2 = Conv2d(8, 8, kernel_size=3, stride=2, padding=1, grad_proj=True)
bn2 = BatchNorm2d(8, grad_proj=True)
relu = ReLU(inplace=True, grad_proj=True)
pool = AvgPool2d(kernel_size=4, grad_proj=True)
fc = Linear(8, 4, grad_proj=True)

# set to training mode
bn1.train()
bn2.train()

# forward pass: each grad_proj layer takes and returns an (output, jvp) pair
y, jvp = pool(*relu(*bn2(*conv2(*relu(*bn1(*conv1(x)))))))
y, jvp = fc(torch.flatten(y, 1), torch.flatten(jvp, 1))

# backward pass: gradient of one output entry w.r.t. every parameter
layers = [conv1, bn1, conv2, bn2, fc]
params = [p for layer in layers for p in layer.parameters()]
grad = autograd.grad(y[0, 0], params)

# random vectors, listed in the same (weight, bias) order as layer.parameters()
rvs = [rv for layer in layers for rv in (layer.weight_rv, layer.bias_rv)]
assert len(grad) == len(rvs), 'expected exactly one random vector per parameter tensor'

# brute-force calculation of Jacobian-vector product
true_jvp = sum((g * rv).sum() for g, rv in zip(grad, rvs))

if abs(jvp[0, 0] - true_jvp) < 1e-5:
    print('TEST 3 PASSED')
else:
    print('TEST 3 FAILED')
    print('ground truth: %.3f \t result: %.3f' % (true_jvp, jvp[0, 0]))
    # Raise rather than exit(): in Jupyter exit() shuts the kernel down, and in a
    # plain script it would exit with status 0, reporting success on failure.
    raise AssertionError('TEST 3 FAILED')

In [2]:
import torch
import torch.autograd as autograd
from models.modules import *
from models.resnetv2_rp import *
from models.resnet_sample_rp import *

In [2]:
# model = resnet18(grad_proj=True).cuda()

In [3]:
# ResNet18 with grad_proj enabled and num_classes=50, constructed for and moved to the CPU.
model = ResNet18(num_classes=50, grad_proj=True, device='cpu')
model = model.to('cpu')

In [18]:
# Random (16, 100) CPU tensor with a trailing size-1 dimension appended -> (16, 100, 1).
X = torch.randn(16, 100, device='cpu').unsqueeze(2)
X.shape

torch.Size([16, 100, 1])

In [19]:
X.size()

torch.Size([16, 100, 1])

In [73]:
# Two independent (1, 100) standard-normal row vectors (a is drawn first).
a, b = torch.randn(1, 100), torch.randn(1, 100)

In [58]:
b.size()

torch.Size([16, 100])

In [76]:
a.select(0, 0).shape

torch.Size([100])

In [66]:
# es[i, n, j] = a[n, i] * b[n, j]: per-row outer products, feature index first.
es = torch.einsum('bi,bj->ibj', a, b)
es.size()

torch.Size([100, 16, 100])

In [67]:
def kronecker(A, B):
    """Kronecker product of two 2-D tensors.

    Element ``(i * B.size(0) + k, j * B.size(1) + l)`` of the result equals
    ``A[i, j] * B[k, l]``.

    Args:
        A: 2-D tensor of shape (m, n).
        B: 2-D tensor of shape (p, q).

    Returns:
        Tensor of shape (m * p, n * q).
    """
    rows = A.size(0) * B.size(0)
    cols = A.size(1) * B.size(1)
    # einsum's output is not guaranteed to be contiguous, so .view() can raise;
    # .reshape() returns a view when it can and copies otherwise.
    return torch.einsum("ab,cd->acbd", A, B).reshape(rows, cols)

In [71]:
kronecker(A=a, B=a).size()

torch.Size([256, 10000])

In [5]:
# Two independent random tensors of identical shape (8, 8, 100, 100); a is drawn first.
a, b = (torch.randn(8, 8, 100, 100) for _ in range(2))

In [9]:
torch.sub(a, b).sum()

tensor(1051.9144)

In [83]:
# Benchmark: build a (64, 64, 100, 100) zero tensor on the default device (normally the
# CPU) and copy it to the current CUDA device with .cuda().
# NOTE(review): this times host allocation plus the host-to-device copy, not device-side
# allocation alone (torch.zeros(..., device='cuda') would measure that). Requires a CUDA GPU.
%timeit torch.zeros(64, 64, 100, 100).cuda()

116 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
# Smoke test: a one-element zero tensor on the current CUDA device (requires a GPU).
torch.zeros(1, device='cuda')

tensor([0.], device='cuda:0')