In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from easydict import EasyDict as edict

import sys
#sys.path.append('./src')

from src.dataset import get_loader
from src.trainer import Trainer
import src.distrib as distrib
from model.quant_model import resnet18_quant, resnet20_cifar, mv1_quant

In [2]:
args =edict()
args.quant = edict()


args['db'] = edict()
args.db.name = 'cifar10'
args.db.root = '/dataset/cifar10'

args.quant['arch'] = 'resnet20_quant' if args.db.name == 'cifar10' else 'resnet18_quant'
args.quant['QWeightFlag'] = True
args.quant['QActFlag'] = True
args.quant['bkwd_scaling_factorW'] = 1.0
args.quant['bkwd_scaling_factorA'] = 1.0
args.quant['groups']=4
args.quant.bit_list = ['2','4','6','8']

args.lr_sched = None
args.device = 'cuda:0'
args.epochs= 200
args.optim= 'sgd'
args.lr= 0.1
args.momentum= 0.9
args.w_decay= 5e-4
args.batch_size= 128
args.mixed= True  # if true, uses mixed precision training
args.beta2= 0.999
args.max_norm= 5
args.nesterov= True
args.alpha= 0.9
args.continue_from=False
args.checkpoint=False
args.history_file=None
args.pre_load_pretrained = True
args.restart= False # Ignore existing checkpoints
args.checkpoint_file= 'checkpoint.th'
args.history_file= 'history.json'
args.num_prints= 10

In [3]:
img_size = 32 if 'cifar' in args.db.name else 224 
trainset, testset, num_classes = get_loader(args, img_size)
criterion = nn.CrossEntropyLoss()
data = edict()
data['tr'] = distrib.loader(trainset, args.batch_size, shuffle=True, num_workers=4)
data['tt'] = distrib.loader(testset, args.batch_size, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
args.quant.arch = 'resnet20_quant'
initialize_model = False
if args.quant.arch == 'resnet20_quant':
    model = resnet20_cifar(args.quant)
    initialize_model = True
    from model.quant_modules import QConv
else:
    if args.quant.arch == 'resnet18_quant':
        model = resnet18_quant(args.quant)
        checkpoint_path = './r18-2468/checkpoint.pth.tar'
    elif args.quant.arch == 'mobilenetv1':
        model = mv1_quant(args.quant)
        checkpoint_path = './mv1-2468/model_best.pth.tar'
    
    checkpoint = torch.load(checkpoint_path)
    woddp_checkpoint = {}
    for key, value in checkpoint['state_dict'].items():
        woddp_checkpoint[key.replace('module.','')] = value
    #model.load_state_dict(woddp_checkpoint, strict=False)
    from model.quant_modules import QConv

model.cuda()

print("forward test")
if initialize_model:
    for layers in model.modules():
        if hasattr(layers, 'init'):
            layers.init.data.fill_(1)

inputs, label = next(iter(data['tr']))
inputs, label = inputs.cuda(), label.cuda()

with torch.no_grad():
    for bit in args.quant.bit_list:
        print("bit : ", bit)
        for name, layers in model.named_modules():
            if hasattr(layers, 'act_bit'):
                setattr(layers, "act_bit", int(bit))
            if hasattr(layers, 'weight_bit'):
                setattr(layers, "weight_bit", int(bit))    
        model(inputs)

forward test
bit :  2
before shape :  torch.Size([128, 256, 32, 32])
after shape :  torch.Size([128, 16, 32, 32])
out shape :  torch.Size([128, 16, 32, 32])
before shape :  torch.Size([128, 16, 32, 32])
after shape :  torch.Size([128, 16, 32, 32])
out shape :  torch.Size([128, 16, 32, 32])
before shape :  torch.Size([128, 16, 32, 32])
after shape :  torch.Size([128, 16, 32, 32])
out shape :  torch.Size([128, 16, 32, 32])
before shape :  torch.Size([128, 16, 32, 32])
after shape :  torch.Size([128, 16, 32, 32])
out shape :  torch.Size([128, 32, 16, 16])
before shape :  torch.Size([128, 32, 16, 16])
after shape :  torch.Size([128, 32, 16, 16])
out shape :  torch.Size([128, 32, 16, 16])
before shape :  torch.Size([128, 32, 16, 16])
after shape :  torch.Size([128, 32, 16, 16])
out shape :  torch.Size([128, 32, 16, 16])
before shape :  torch.Size([128, 32, 16, 16])
after shape :  torch.Size([128, 32, 16, 16])
out shape :  torch.Size([128, 64, 8, 8])
before shape :  torch.Size([128, 64, 8, 8

In [None]:
model

QResNet4Cifar(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): QBasicBlock4Cifar(
      (conv1): QConv(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False
        (uW): ParameterList(
            (0): Parameter containing: [torch.float32 of size  (cuda:0)]
            (1): Parameter containing: [torch.float32 of size  (cuda:0)]
            (2): Parameter containing: [torch.float32 of size  (cuda:0)]
            (3): Parameter containing: [torch.float32 of size  (cuda:0)]
        )
        (lW): ParameterList(
            (0): Parameter containing: [torch.float32 of size  (cuda:0)]
            (1): Parameter containing: [torch.float32 of size  (cuda:0)]
            (2): Parameter containing: [torch.float32 of size  (cuda:0)]
            (3): Parameter containing:

In [None]:
bit_list = [2, 4, 6, 8]
for n, m in model.named_modules():
    if isinstance(m, QConv):
        print(n)
        test_module = m
        
        first_weight_dict = {}
        for bit in [2, 8]:
            test_module.act_bit = bit
            test_module.weight_bit = bit
            print(bit)
            FWeight, FAct = test_module.select(test_module.weight, torch.randn(4 ,test_module.weight.shape[1], 32, 32))
            QWeight = test_module.group_weight_quantization(FWeight)
            
            first_channel = QWeight.shape[0] // (bit // 2)
            first_weight = QWeight[:first_channel]
            print("first weight shape :", first_weight.shape)
            first_weight_dict[f'{bit}_weight'] = first_weight

        print("diff : ", torch.nn.functional.mse_loss(first_weight_dict[f'2_weight'], first_weight_dict['8_weight']), 0 == torch.nn.functional.mse_loss(first_weight_dict[f'2_weight'], first_weight_dict['8_weight']))



layer1.0.conv1
2
first weight shape : torch.Size([16, 16, 3, 3])
8
first weight shape : torch.Size([4, 16, 3, 3])


  print("diff : ", torch.nn.functional.mse_loss(first_weight_dict[f'2_weight'], first_weight_dict['8_weight']), 0 == torch.nn.functional.mse_loss(first_weight_dict[f'2_weight'], first_weight_dict['8_weight']))


RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0

In [None]:
FWeight,_ , FAct, _ = test_module.select(test_module.weight, torch.randn(4 ,test_module.weight.shape[1], 32, 32))

In [None]:
test_module.act_bit

In [None]:
test_module.uW[1]

In [None]:
QWeight = test_module.group_weight_quantization(FWeight)

In [None]:
single_length = QWeight.shape[0] // test_module.groups
first_weight = QWeight[:single_length, :, :, :]
for i in range(test_module.groups):
    print("diff : ", i, torch.nn.functional.mse_loss(first_weight, QWeight[single_length * i:single_length * (i+1), :, :, :]))



In [None]:
trainer = Trainer(data, model, criterion, None, args)

In [None]:
trainer.evaluate()

In [None]:
bit_info = edict()
bit_info['1'] = 2
bit_info['2'] = 4
bit_info['3'] = 6

sum([bit_info[f'{i}'] for i in range(1,4)])


In [None]:
test_weights = torch.randn(3, 5)
test1_tensor = nn.Linear(5, 3, bias=False)
test1_tensor.weight.data = test_weights
inputs = torch.randn(5)

test1_pc_tensor = nn.Linear(5, 3, bias=False)
test1_pc_tensor.weight.data = test_weights

def pc(z, tau=1e-4):
    if z.requires_grad:
        print('grad is true')
        z.register_hook(lambda grad, z = z.detach().clone(): grad + tau * z)
    return z

out = test1_tensor(inputs)
out2 = test1_pc_tensor(inputs)
out2_pc = pc(out2)


out.sum().backward(retain_graph=True)
out2_pc.sum().backward(retain_graph=True)

In [None]:
test1_tensor.weight.grad

In [None]:
test1_pc_tensor.weight.grad

In [None]:
import torch

In [None]:
s = torch.randn(16, 3, 32, 32)
torch.nn.AdaptiveAvgPool2d((1,1))(s).squeeze().shape

In [None]:
g = torch.randn(4, 10, 1)
f = torch.sigmoid(g)
ff = 1 - f 
q = (f > 0.5).float() * 0.9 + (ff >= 0.5).float() * 0.1
h = (f > 0.5).float() * 0.1 + (ff >= 0.5).float() * 0.9

print(q.shape)
t = torch.cat([q, h, q, h], dim=2)
print(t.shape)
print(t)
print(f.view(4, -1))
dist = torch.distributions.Categorical(t)
dist.sample()

In [None]:
dist.probs

In [None]:
g = torch.randn(4, 5)
print(g)
out = torch.nn.functional.log_softmax(g, dim=1)
print(out)
print(out.exp())
out.sum(dim=-1)