In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os, time, shutil
from models import *

In [None]:
# Device Selection
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Ask cuDNN to benchmark convolution algorithms and reuse the quickest one.
cudnn.benchmark = True
# NOTE(review): `fastest` does not appear among the documented torch.backends.cudnn
# flags; this assignment may be a silent no-op -- confirm before relying on it.
torch.backends.cudnn.fastest = True   # maximize RTX 3080 throughput
print(f"=> Using device: {device}")

In [None]:
# Training Parameters
# NOTE(review): the DataLoaders built further down pass literal batch sizes;
# confirm whether `batch_size` is still consumed anywhere.
batch_size = 128
epochs = 200  # number of epochs executed by the training loop
lr = 0.01     # initial learning rate handed to the SGD optimizer


In [None]:
# Best validation accuracy seen so far; the training loop raises it as results improve.
best_acc = 0
# Directory that receives the checkpoint files written by save_checkpoint().
save_dir = "result/VGG16_quant"
# Create the directory up front; exist_ok avoids an error when the cell is re-run.
os.makedirs(save_dir, exist_ok=True)

In [None]:

# Build the quantized VGG16 and move it to the selected device.
model = VGG16_quant().to(device)
if torch.cuda.device_count() > 1:
    # Replicate the network across every visible GPU.
    # NOTE(review): later analysis cells index `model.features[...]` directly, which the
    # DataParallel wrapper does not expose (the network lives under `model.module`).
    print(f"=> Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.1)
# Tie the cosine period to `epochs` so the schedule cannot drift out of sync with
# the length of the training loop when `epochs` is edited.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Loss scaling only applies to CUDA mixed precision; create the scaler disabled on CPU
# instead of relying on it to warn and switch itself off.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

# Per-channel standardisation applied to both the training and the test images.
normalize = transforms.Normalize(
    mean=[0.491, 0.482, 0.447],
    std=[0.247, 0.243, 0.262],
)

# Training pipeline: random padded crop + horizontal flip, then tensor conversion.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# CIFAR-10 train / test splits, downloaded into ./data when not already present.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=tf)
    for is_train, tf in ((True, train_transform), (False, test_transform))
)

# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
#                                            shuffle=True, num_workers=4, pin_memory=True)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
#                                           shuffle=False, num_workers=4, pin_memory=True)

# Training loader: shuffled, with worker processes kept alive between epochs.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,              # 128 -> 256 (fits in 10 GB easily)
    shuffle=True,
    # Use all CPU cores. os.cpu_count() may return None when the count cannot be
    # determined, which DataLoader rejects, so fall back to a single worker.
    num_workers=os.cpu_count() or 1,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,           # overlap data loading with compute
)
# Evaluation loader: fixed order, larger batch.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=512,              # eval can use larger batch
    shuffle=False,
    num_workers=os.cpu_count() or 1,
    pin_memory=True,
    persistent_workers=True,
)

class AverageMeter:
    """Keeps the latest value and a count-weighted running mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and sample count back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` as having been observed `n` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch of scores.

    Args:
        output: score tensor; the k best entries are taken along dimension 1.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one scalar tensor per requested k, holding the percentage
        of samples whose target appears among the k highest-scoring classes.
    """
    n_samples = target.size(0)
    _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
    # Transpose to [maxk, N] so that row r answers "is the r-th best guess correct?".
    ranked = top_idx.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
        for k in topk
    ]

def save_checkpoint(state, is_best, fdir):
    """Write `state` to `<fdir>/checkpoint.pth`, mirroring it as the best model if asked.

    Args:
        state: object serialised with torch.save.
        is_best: when true, the freshly written file is also copied to
            `<fdir>/model_best.pth.tar`.
        fdir: existing directory that receives the files.
    """
    latest_path = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = os.path.join(fdir, 'model_best.pth.tar')
    shutil.copyfile(latest_path, best_path)


# def train(train_loader, model, criterion, optimizer, epoch):
#     model.train()
#     losses, top1 = AverageMeter(), AverageMeter()
#     start = time.time()

#     for i, (inputs, targets) in enumerate(train_loader):
#         inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

#         outputs = model(inputs)
#         loss = criterion(outputs, targets)

#         prec1 = accuracy(outputs, targets)[0]
#         losses.update(loss.item(), inputs.size(0))
#         top1.update(prec1.item(), inputs.size(0))

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if i % 100 == 0:
#             print(f"Epoch [{epoch}] [{i}/{len(train_loader)}] "
#                   f"Loss {losses.val:.4f} ({losses.avg:.4f})  "
#                   f"Acc {top1.val:.2f}% ({top1.avg:.2f}%)")

#     print(f" Epoch {epoch} done in {time.time()-start:.1f}s | Train Acc: {top1.avg:.2f}% | Loss: {losses.avg:.4f}")

def train(train_loader, model, criterion, optimizer, epoch):
    """Train `model` for one epoch under autocast with gradient scaling.

    Uses the notebook-level `device`, `scaler`, `accuracy` and `AverageMeter`.
    Prints running loss / top-1 accuracy every 100 batches and one summary
    line when the epoch ends. Returns nothing.

    Args:
        train_loader: iterable yielding (inputs, targets) batches.
        model: network to optimise; switched to training mode here.
        criterion: loss callable applied to (outputs, targets).
        optimizer: optimizer stepped through the gradient scaler.
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    start = time.time()

    # Read once at the start of the epoch; the value is only used for logging.
    current_lr = optimizer.param_groups[0]['lr']

    for i, batch in enumerate(train_loader):
        inputs, targets = batch
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_n = inputs.size(0)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass and loss are evaluated inside the autocast region.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Book-keeping happens before the backward pass, on the unscaled loss.
        prec1 = accuracy(outputs, targets)[0]
        losses.update(loss.item(), batch_n)
        top1.update(prec1.item(), batch_n)

        # Scaled backward pass, optimizer step via the scaler, then scale update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if i % 100 != 0:
            continue
        print(
            f"Epoch [{epoch}] [{i}/{len(train_loader)}] "
            f"LR {current_lr:.5e}  "
            f"Loss {losses.val:.4f} ({losses.avg:.4f})  "
            f"Acc {top1.val:.2f}% ({top1.avg:.2f}%)"
        )

    print(
        f" Epoch {epoch} done in {time.time()-start:.1f}s | "
        f"LR: {current_lr:.5e} | Train Acc: {top1.avg:.2f}% | Loss: {losses.avg:.4f}"
    )

def validate(val_loader, model, criterion, epoch):
    """Evaluate `model` on `val_loader` without gradients and report the result.

    Uses the notebook-level `device`, `accuracy` and `AverageMeter`.

    Args:
        val_loader: iterable yielding (inputs, targets) batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable applied to (outputs, targets).
        epoch: epoch number, used only in the printed message.

    Returns:
        The sample-weighted mean top-1 accuracy (percent) over the loader.
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_n = inputs.size(0)

            outputs = model(inputs)
            losses.update(criterion(outputs, targets).item(), batch_n)
            top1.update(accuracy(outputs, targets)[0].item(), batch_n)

    print(f"Validation Epoch {epoch}: Acc {top1.avg:.2f}% | Loss {losses.avg:.4f}")
    return top1.avg


In [None]:
# Training Loop
# Each epoch: train, evaluate on the test loader, advance the LR schedule once,
# then write a checkpoint (also copied as the best model when accuracy improved).
for epoch in range(1, epochs+1):
    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(test_loader, model, criterion, epoch)
    # The scheduler is stepped once per epoch, after evaluation.
    scheduler.step()

    # Strict comparison: a tie with the previous best does not count as a new best.
    is_best = val_acc > best_acc
    best_acc = max(val_acc, best_acc)

    # The checkpoint holds only these four entries; scheduler and grad-scaler
    # state are not part of it.
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, save_dir)

    print(f"Epoch {epoch} complete | Best Acc: {best_acc:.2f}%\n")

print("Training completed. Best accuracy: {:.2f}%".format(best_acc))


In [None]:
# Reload the best checkpoint written during training and re-measure test accuracy.
PATH = "result/VGG16_quant/model_best.pth.tar"
# map_location keeps the load working when the checkpoint was saved from a GPU
# that is not available (or not selected) in the current session.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['state_dict'])

# Reuse the device chosen at the top of the notebook instead of forcing "cuda",
# so this cell no longer overrides the earlier selection.
model.to(device)
model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # criterion averages over the batch; weight by batch size so the final
        # division by the dataset size yields the mean per-sample loss.
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
class SaveOutput:
    """Callable recorder used as a forward pre-hook.

    Despite the name, it stores the *inputs* a hooked module receives: every
    call appends the `module_in` argument to `self.outputs`.
    """

    def __init__(self):
        self.clear()

    def __call__(self, module, module_in):
        # `module` is accepted to satisfy the hook signature but is not used.
        self.outputs.append(module_in)

    def clear(self):
        """Drop everything recorded so far by rebinding to a fresh list."""
        self.outputs = []
        
######### Save inputs from selected layer ##########
save_output = SaveOutput()

# Walk every sub-module (1-based position in model.modules()) and attach the
# recorder as a forward pre-hook to each QuantConv2d.
for i, layer in enumerate(model.modules(), start=1):
    if not isinstance(layer, QuantConv2d):
        continue
    print(i,"-th layer prehooked")
    layer.register_forward_pre_hook(save_output)
####################################################

# A single forward pass over the first test batch fills save_output.outputs.
dataiter = iter(test_loader)
images, labels = next(dataiter)
images = images.to(device)
out = model(images)

In [None]:
# Quantized weights and weight clipping scale read from the layer at features[3].
weight_q = model.features[3].weight_q
w_alpha = model.features[3].weight_quant.wgt_alpha
# NOTE(review): bit width is entered by hand; confirm it matches the layer's configuration.
w_bit = 4

# Divide by the quantization step alpha / (2**(w_bit-1) - 1) to express the
# weights in units of one step (signed range, hence w_bit-1).
weight_int = weight_q / (w_alpha / (2**(w_bit-1)-1))
print(weight_int)

In [None]:
# Second recorded hook call, first positional input of that call.
# NOTE(review): presumably this is the input of model.features[3] -- confirm the hook order.
act = save_output.outputs[1][0]
act_alpha  = model.features[3].act_alpha
# NOTE(review): bit width is entered by hand; confirm it matches the layer's configuration.
act_bit = 4
# act_quantization is not defined in this notebook; presumably provided by `from models import *`.
act_quant_fn = act_quantization(act_bit)

act_q = act_quant_fn(act, act_alpha)

# Divide by the step alpha / (2**act_bit - 1) to express activations in units of
# one step (full 2**act_bit - 1 range, unlike the weights above).
act_int = act_q / (act_alpha / (2**act_bit-1))
print(act_int)

In [None]:
# Integer-domain convolution: channel counts and kernel size are taken from
# weight_int ([output_ch, input_ch, ki, kj]) so the cell follows the layer's
# real dimensions instead of a hand-edited constant.
conv_int = torch.nn.Conv2d(in_channels=weight_int.size(1),
                           out_channels=weight_int.size(0),
                           kernel_size=tuple(weight_int.shape[2:]),
                           padding=1)
conv_int.weight = torch.nn.parameter.Parameter(weight_int)
# NOTE(review): if this bias is not None it is added in the integer domain and
# then rescaled below together with the products -- confirm that is intended.
conv_int.bias = model.features[3].bias
output_int = conv_int(act_int)
# Multiply by both quantization steps to return to the real-valued domain.
output_recovered = output_int * (act_alpha / (2**act_bit-1)) * (w_alpha / (2**(w_bit-1)-1))
print(output_recovered)

In [None]:
# Reference convolution using the layer's quantized weights directly; channel
# counts and kernel size follow the weight tensor rather than a fixed constant.
conv_ref = torch.nn.Conv2d(in_channels=model.features[3].weight_q.size(1),
                           out_channels=model.features[3].weight_q.size(0),
                           kernel_size=tuple(model.features[3].weight_q.shape[2:]),
                           padding=1)
conv_ref.weight = model.features[3].weight_q
conv_ref.bias = model.features[3].bias
# NOTE(review): this feeds the un-quantized `act`; if the result is meant to be
# compared with output_recovered, `act_q` may be the intended input -- confirm.
output_ref = conv_ref(act)
print(output_ref)

In [57]:
# act_int is laid out as [batch_size, input_ch, ni, nj]
# (the shape print in this notebook shows [512, 8, 32, 32] for the recorded run).
a_int = act_int[0,:,:,:]  # pick only one input out of batch -> [input_ch, ni, nj]

# weight_int is laid out as [output_ch, input_ch, ki, kj] ([8, 8, 3, 3] in the recorded run).
w_int = torch.reshape(weight_int, (weight_int.size(0), weight_int.size(1), -1))  # merge ki, kj index to kij
# w_int is now [output_ch, input_ch, ki*kj]

padding = 1
stride = 1
# Row and column number of the simulated array, taken from the weight tensor
# rather than a fixed 64 so it follows the actual layer width.
array_size = w_int.size(0)

nig = range(a_int.size(1))  ## ni group
njg = range(a_int.size(2))  ## nj group

icg = range(int(w_int.size(1)))  ## input channel
ocg = range(int(w_int.size(0)))  ## output channel


kijg = range(w_int.size(2)) # [0, .. ki*kj-1]
# Kernel's 1 dim size, read straight from the weight shape. This avoids needing
# `math` here, which the notebook only imports in a later cell.
ki_dim = weight_int.size(2)

######## Padding before Convolution #######
# Allocate on the same device as the activations instead of assuming CUDA.
a_pad = torch.zeros(len(icg), len(nig)+padding*2, len(njg)+padding*2, device=a_int.device)
# a_pad is [input_ch, ni+2*pad, nj+2*pad]
a_pad[ :, padding:padding+len(nig), padding:padding+len(njg)] = a_int
a_pad = torch.reshape(a_pad, (a_pad.size(0), -1))  ## merging ni and nj index into nij
# a_pad is now [input_ch, (ni+2*pad)*(nj+2*pad)]

In [58]:
# Sanity check of the tensor layouts used by the tiling cells.
print(act_int.shape)
print(weight_int.shape)

torch.Size([512, 8, 32, 32])
torch.Size([8, 8, 3, 3])


In [59]:
###########################################

p_nijg = range(a_pad.size(1)) ## padded activation's nij group

# Partial sums indexed as [array row (output channel), padded nij, kij].
psum = torch.zeros(array_size, len(p_nijg), len(kijg), device=a_pad.device)

for kij in kijg:
    # The weight slice only depends on kij, so build the layer once per kernel
    # position instead of re-creating (and randomly initialising) it for every pixel.
    m = nn.Linear(array_size, array_size, bias=False)
    m.weight = torch.nn.Parameter(w_int[:,:,kij])
    for nij in p_nijg:     # time domain, sequentially given input
        psum[:, nij, kij] = m(a_pad[:,nij])
    # If the per-time-step loop is not needed, the whole inner loop is
    # mathematically the matrix product w_int[:, :, kij] @ a_pad.
 

In [60]:
import math

# Side length of the padded activation map, recovered from the flattened nij axis
# (32 + 2*pad = 34 for the recorded run).
a_pad_ni_dim = int(math.sqrt(a_pad.size(1)))

# Output side length: (padded side - kernel side) / stride + 1.
o_ni_dim = int((a_pad_ni_dim - ki_dim) / stride + 1)
o_nijg = range(o_ni_dim * o_ni_dim)  # flattened output pixel index

out = torch.zeros(len(ocg), len(o_nijg)).cuda()


### SFP accumulation ###
for o_nij in o_nijg:
    # Split the flat output index into (row, column).
    o_row, o_col = divmod(o_nij, o_ni_dim)
    for kij in kijg:
        # Split the flat kernel index into (row, column).
        k_row, k_col = divmod(kij, ki_dim)
        # Flat position in the padded map that this kernel tap reads for this output pixel.
        p_nij = (o_row + k_row) * a_pad_ni_dim + (o_col + k_col)
        out[:,o_nij] = out[:,o_nij] + psum[:, p_nij, kij]

In [61]:
# Unflatten nij back into (ni, nj) and measure the total absolute deviation
# from the first sample of the integer convolution output.
out_2D = out.reshape(out.size(0), o_ni_dim, -1)
difference = out_2D - output_int[0]
print(difference.abs().sum())

tensor(0.0031, device='cuda:0', grad_fn=<SumBackward0>)


In [62]:
# Display the integer-domain convolution output for the first sample of the batch.
output_int[0,:,:,:]

tensor([[[   0.,   -8.,    5.,  ...,   43.,   66.,   79.],
         [  26.,   -2.,    6.,  ...,    4.,   -7.,  -17.],
         [   2.,  -16.,  -13.,  ...,   -5.,  -17.,  -10.],
         ...,
         [ -34.,  -48.,  -39.,  ...,  -14.,  -32.,  -28.],
         [ -23.,  -26.,  -21.,  ...,  -12.,   -2.,   46.],
         [  13.,   25.,   37.,  ...,   71.,   41.,  -15.]],

        [[  20.,  -15.,    1.,  ...,    2.,    8.,   36.],
         [   4.,  -68.,  -54.,  ...,  -30.,  -25.,   23.],
         [  42.,  -21.,  -32.,  ...,  -16.,  -16.,   24.],
         ...,
         [  60.,  198.,  155.,  ...,  207.,  228.,  100.],
         [  80.,  214.,  197.,  ...,  237.,  234.,  147.],
         [  17.,   81.,   78.,  ...,  105.,   89.,   70.]],

        [[ -26.,   17.,  -16.,  ...,  -10.,   -4.,  142.],
         [ -52.,   -1.,  -11.,  ...,  -25.,  -44.,  166.],
         [ -31.,    5.,   -4.,  ...,  -14.,  -55.,  165.],
         ...,
         [ -99.,  -44.,   38.,  ..., -119.,    2.,  165.],
         [

In [None]:
######## Easier 2D version ########

import math

# This cell previously referred to nijg, ic_tileg and oc_tileg, none of which are
# defined in this notebook, and indexed psum with five indices although it is
# built as [array row, padded nij, kij]. It now uses the quantities the earlier
# cells actually define (ki_dim, o_ni_dim, a_pad_ni_dim, ocg, stride, psum).

# Kernel row / column groups (square kernel of side ki_dim).
kig = range(ki_dim)
kjg = range(ki_dim)

# Output row / column groups (square output of side o_ni_dim).
o_nig = range(o_ni_dim)
o_njg = range(o_ni_dim)


out = torch.zeros(len(ocg), len(o_nig), len(o_njg), device=psum.device)


### SFP accumulation ###
for ni in o_nig:
    for nj in o_njg:
        for ki in kig:
            for kj in kjg:
                # Flat index into the padded map for output pixel (ni, nj) and
                # kernel tap (ki, kj); with stride 1 this matches the flat version.
                p_nij = a_pad_ni_dim*(ni*stride + ki) + (nj*stride + kj)
                out[:, ni, nj] = out[:, ni, nj] + psum[:, p_nij, len(kig)*ki + kj]