In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from easydict import EasyDict as edict

import sys
#sys.path.append('./src')

from src.dataset import get_loader
from src.trainer import Trainer
import src.distrib as distrib
from cifar_model.quant_resnet4cifar_my import resnet8_quant, resnet20_quant, QBasicBlock4Cifar
from cifar_model.quant_conv import QConv

from imagenet_model.quant_resnet_my import resnet18_quant
from imagenet_model.quant_mv1_my import Model as MobileNetV1

In [2]:
args =edict()
args.quant = edict()


args['db'] = edict()
args.db.name = 'imagenet'
args.db.root = '/dataset/imagenet'

args.quant['arch'] = 'resnet20_quant' if args.db.name == 'cifar10' else 'resnet18_quant'
args.quant['QWeightFlag'] = True
args.quant['QActFlag'] = True
args.quant['bkwd_scaling_factorW'] = 1.0
args.quant['bkwd_scaling_factorA'] = 1.0
args.quant['groups']=4
args.quant.bit_list = ['2','4','6','8']

args.lr_sched = None
args.device = 'cuda:0'
args.epochs= 200
args.optim= 'sgd'
args.lr= 0.1
args.momentum= 0.9
args.w_decay= 5e-4
args.batch_size= 128
args.mixed= True  # if true, uses mixed precision training
args.beta2= 0.999
args.max_norm= 5
args.nesterov= True
args.alpha= 0.9
args.continue_from=False
args.checkpoint=False
args.history_file=None
args.pre_load_pretrained = True
args.restart= False # Ignore existing checkpoints
args.checkpoint_file= 'checkpoint.th'
args.history_file= 'history.json'
args.num_prints= 10

In [3]:
img_size = 32 if 'cifar' in args.db.name else 224 
trainset, testset, num_classes = get_loader(args, img_size)
criterion = nn.CrossEntropyLoss()
data = edict()
data['tr'] = distrib.loader(trainset, args.batch_size, shuffle=True, num_workers=4)
data['tt'] = distrib.loader(testset, args.batch_size, shuffle=False, num_workers=4)

In [4]:
args.quant.arch = 'mobilenetv1'

if args.quant.arch == 'resnet20_quant':
    model = resnet20_quant(args.quant)
else:
    if args.quant.arch == 'resnet18_quant':
        model = resnet18_quant(args.quant)
        checkpoint_path = './r18-2468/checkpoint.pth.tar'
    elif args.quant.arch == 'mobilenetv1':
        model = MobileNetV1(args.quant)
        checkpoint_path = './mv1-2468/model_best.pth.tar'
    
    checkpoint = torch.load(checkpoint_path)
    woddp_checkpoint = {}
    for key, value in checkpoint['state_dict'].items():
        woddp_checkpoint[key.replace('module.','')] = value
    model.load_state_dict(woddp_checkpoint)
model.cuda()

Model(
  (head): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (stage_0_layer_0): QDepthwiseSeparableConv(
    (body): Sequential(
      (0): QConv_Tra_Mulit(
        128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False
        (uW): ParameterList(
            (0): Parameter containing: [torch.float32 of size  (cuda:0)]
            (1): Parameter containing: [torch.float32 of size  (cuda:0)]
            (2): Parameter containing: [torch.float32 of size  (cuda:0)]
            (3): Parameter containing: [torch.float32 of size  (cuda:0)]
        )
        (lW): ParameterList(
            (0): Parameter containing: [torch.float32 of size  (cuda:0)]
            (1): Parameter containing: [torch.float32 of size  (cuda:0)]
            (2): Parameter containing: [torch.float32 of size  (cuda

In [5]:
trainer = Trainer(data, model, criterion, None, args)

In [6]:
trainer.evaluate()

{'2_loss': 1.9976301193237305, '2_Acc1': 54.5359992980957, '2_Acc5': 78.25399780273438, '4_loss': 1.2348692417144775, '4_Acc1': 70.24800109863281, '4_Acc5': 89.27799987792969, '6_loss': 1.1474791765213013, '6_Acc1': 72.08799743652344, '6_Acc5': 90.32599639892578, '8_loss': 1.2058947086334229, '8_Acc1': 70.63999938964844, '8_Acc5': 89.4280014038086, 'avg_loss': 1.396468311548233, 'avg_Acc1': 66.8779993057251, 'avg_Acc5': 86.82149887084961}


In [7]:
bit_info = edict()
bit_info['1'] = 2
bit_info['2'] = 4
bit_info['3'] = 6

sum([bit_info[f'{i}'] for i in range(1,4)])


12

In [10]:
test_weights = torch.randn(3, 5)
test1_tensor = nn.Linear(5, 3, bias=False)
test1_tensor.weight.data = test_weights
inputs = torch.randn(5)

test1_pc_tensor = nn.Linear(5, 3, bias=False)
test1_pc_tensor.weight.data = test_weights

def pc(z, tau=1e-4):
    if z.requires_grad:
        print('grad is true')
        z.register_hook(lambda grad, z = z.detach().clone(): grad + tau * z)
    return z

out = test1_tensor(inputs)
out2 = test1_pc_tensor(inputs)
out2_pc = pc(out2)


out.sum().backward(retain_graph=True)
out2_pc.sum().backward(retain_graph=True)

grad is true


In [11]:
test1_tensor.weight.grad

tensor([[-0.7156, -0.4375, -1.9236,  1.3520, -0.5358],
        [-0.7156, -0.4375, -1.9236,  1.3520, -0.5358],
        [-0.7156, -0.4375, -1.9236,  1.3520, -0.5358]])

In [12]:
test1_pc_tensor.weight.grad

tensor([[-0.7158, -0.4376, -1.9240,  1.3524, -0.5359],
        [-0.7156, -0.4375, -1.9236,  1.3521, -0.5358],
        [-0.7157, -0.4375, -1.9237,  1.3521, -0.5358]])

In [1]:
import torch

In [3]:
s = torch.randn(16, 3, 32, 32)
torch.nn.AdaptiveAvgPool2d((1,1))(s).squeeze().shape

torch.Size([16, 3])

In [17]:
g = torch.randn(4, 10, 1)
f = torch.sigmoid(g)
ff = 1 - f 
q = (f > 0.5).float() * 0.9 + (ff >= 0.5).float() * 0.1
h = (f > 0.5).float() * 0.1 + (ff >= 0.5).float() * 0.9

print(q.shape)
t = torch.cat([q, h, q, h], dim=2)
print(t.shape)
print(t)
print(f.view(4, -1))
dist = torch.distributions.Categorical(t)
dist.sample()

torch.Size([4, 10, 1])
torch.Size([4, 10, 4])
tensor([[[0.1000, 0.9000, 0.1000, 0.9000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.1000, 0.9000, 0.1000, 0.9000]],

        [[0.9000, 0.1000, 0.9000, 0.1000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.1000, 0.9000, 0.1000, 0.9000],
         [0.9000, 0.1000, 0.9000, 0.1000],
         [0.9000, 0.1000, 0.9000, 0.1000]],

        [[0.9000, 0.1000, 0.9000, 0.1000],
         [0.9000, 0.1000, 0.9000, 0.1000],
    

tensor([[1, 3, 3, 0, 3, 3, 0, 3, 2, 1],
        [2, 0, 2, 2, 3, 3, 3, 3, 2, 0],
        [1, 0, 2, 3, 1, 2, 3, 3, 1, 2],
        [2, 2, 3, 1, 2, 2, 2, 3, 1, 0]])

In [19]:
dist.probs

tensor([[[0.0500, 0.4500, 0.0500, 0.4500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.0500, 0.4500, 0.0500, 0.4500]],

        [[0.4500, 0.0500, 0.4500, 0.0500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.0500, 0.4500, 0.0500, 0.4500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.4500, 0.0500, 0.4500, 0.0500]],

        [[0.4500, 0.0500, 0.4500, 0.0500],
         [0.4500, 0.0500, 0.4500, 0.0500],
         [0.4500, 0.0500, 0.4500, 0.0500],
       

In [25]:
g = torch.randn(4, 5)
print(g)
out = torch.nn.functional.log_softmax(g, dim=1)
print(out)
print(out.exp())
out.sum(dim=-1)

tensor([[ 0.5604, -0.5688,  0.1439,  0.4462, -1.3286],
        [ 0.6300, -0.5767,  0.6332, -0.9580, -0.0075],
        [-0.2042, -1.4047, -0.9545, -0.9557, -0.0792],
        [-0.9626, -0.3373,  0.7658,  0.5411,  0.8516]])
tensor([[-1.1072, -2.2364, -1.5237, -1.2214, -2.9962],
        [-1.1103, -2.3170, -1.1071, -2.6984, -1.7479],
        [-1.2173, -2.4179, -1.9677, -1.9688, -1.0923],
        [-2.9515, -2.3262, -1.2231, -1.4479, -1.1373]])
tensor([[0.3305, 0.1068, 0.2179, 0.2948, 0.0500],
        [0.3295, 0.0986, 0.3305, 0.0673, 0.1741],
        [0.2960, 0.0891, 0.1398, 0.1396, 0.3355],
        [0.0523, 0.0977, 0.2943, 0.2351, 0.3207]])


tensor([-9.0850, -8.9807, -8.6639, -9.0860])