In [2]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.backends.cudnn as cudnn

import sys

from google.colab import drive

# Mount Google Drive into the Colab filesystem; force_remount=True re-mounts
# even if a previous cell already mounted it.
drive.mount("/content/drive", force_remount=True)

drive_path = "drive/My Drive" # do not modify this line
relative_path = "/practice"   # set to your relative path

# Project folder on Drive. NOTE(review): this is a relative path, so it is
# resolved against the current working directory -- presumably /content on
# Colab; confirm if the notebook changes directory.
base_path = f"{drive_path}{relative_path}"

# Put the project folder on the import path so the project-local modules
# imported further down (e.g. train_test) can be found -- presumably they
# live in this folder.
sys.path.append(base_path)

from train_test import train, test

# Fine-tuning hyperparameters: initial learning rate for SGD and the number
# of fine-tuning epochs (also used as the cosine schedule length).
LR = 0.01
EPOCH = 5

# Data
print('==> Preparing data')
from dataset import cifar10_dataset
# cifar10_dataset returns the (train, test) loaders; the argument is the
# directory holding the dataset files.
trainloader, testloader = cifar10_dataset(base_path + "/data")

# Model
print('==> Building model')
from resnet_quant import ResNet18
model = ResNet18()
# map_location="cpu" makes the checkpoint loadable on a runtime without CUDA
# even if it was saved from GPU tensors; load_state_dict copies the values
# into the model, which is moved to the GPU just below when one is available.
model.load_state_dict(
    torch.load(base_path + "/train_best.pth", map_location="cpu")
)

if torch.cuda.is_available():
    model.cuda()
    # After this wrap the underlying network is reachable as model.module.
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
criterion = torch.nn.CrossEntropyLoss()

# Baseline evaluation before any quantization settings are applied.
print('==> Full-precision model accuracy')
from quant_op import Q_ReLU, Q_Conv2d, Q_Linear
test(model, testloader, criterion)

from quant_op import Q_ReLU, Q_Conv2d, Q_Linear

# Attribute values applied to each kind of quantized layer.
# NOTE(review): n_lv presumably is the number of quantization levels, bound
# the activation clipping bound and ratio a weight-range ratio -- confirm
# against quant_op.
activation_settings = {"n_lv": 8, "bound": 1}
weight_settings = {"n_lv": 8, "ratio": 0.5}

# Walk every sub-module of the model and configure the quantized ones.
for module in model.modules():
    if isinstance(module, Q_ReLU):
        for attr, value in activation_settings.items():
            setattr(module, attr, value)

    if isinstance(module, (Q_Conv2d, Q_Linear)):
        for attr, value in weight_settings.items():
            setattr(module, attr, value)

# Evaluate again with the quantization settings in place, before fine-tuning.
print('==> Quantized model accuracy')
test(model, testloader, criterion)

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
# Cosine schedule over EPOCH epochs.
# NOTE(review): resuming with start_epoch > 0 (last_epoch >= 0) requires the
# optimizer param groups to carry 'initial_lr' -- confirm before using a
# non-zero start_epoch.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCH, last_epoch=start_epoch-1)

for epoch in range(start_epoch, start_epoch + EPOCH):
    train(model, trainloader, criterion, optimizer, epoch)
    # Advance the schedule only after this epoch's optimizer updates.
    # Stepping first would skip the initial learning rate and leave the last
    # epoch training at the annealed minimum (0), i.e. doing nothing.
    scheduler.step()
    acc = test(model, testloader, criterion)

    if acc > best_acc:
        best_acc = acc
        # The model is wrapped in DataParallel only when CUDA is available;
        # unwrap it when present so the saved keys have no "module." prefix,
        # and fall back to the bare model on CPU.
        if isinstance(model, torch.nn.DataParallel):
            net = model.module
        else:
            net = model
        # "/" separator so the checkpoint lands inside base_path, next to
        # train_best.pth.
        torch.save(net.state_dict(), base_path + "/quant_best.pth")

# Final evaluation of the model as it stands after the last epoch (not
# necessarily the best checkpoint saved above).
print('==> Fine-tuned model accuracy')
test(model, testloader, criterion)

Mounted at /content/drive
==> Preparing data
Files already downloaded and verified
Files already downloaded and verified
==> Building model
==> Full-precision model accuracy
Test: [10/100]	Time 0.063 (0.111)	Loss 0.1693 (0.2120)	Prec@1 95.000 (94.300)
Test: [20/100]	Time 0.060 (0.085)	Loss 0.3840 (0.2106)	Prec@1 93.000 (94.650)
Test: [30/100]	Time 0.060 (0.077)	Loss 0.1171 (0.2190)	Prec@1 96.000 (94.667)
Test: [40/100]	Time 0.060 (0.073)	Loss 0.1930 (0.2174)	Prec@1 97.000 (94.750)
Test: [50/100]	Time 0.061 (0.070)	Loss 0.3337 (0.2224)	Prec@1 94.000 (94.660)
Test: [60/100]	Time 0.064 (0.069)	Loss 0.1739 (0.2162)	Prec@1 96.000 (94.817)
Test: [70/100]	Time 0.060 (0.068)	Loss 0.1721 (0.2082)	Prec@1 96.000 (95.014)
Test: [80/100]	Time 0.060 (0.067)	Loss 0.0803 (0.2097)	Prec@1 96.000 (95.013)
Test: [90/100]	Time 0.059 (0.066)	Loss 0.2229 (0.2079)	Prec@1 94.000 (94.989)
Test: [100/100]	Time 0.060 (0.065)	Loss 0.0745 (0.2060)	Prec@1 97.000 (94.980)
 * Prec@1 94.980
==> Quantized model accuracy



Epoch 0: [10/196]	Time 0.491 (0.532)	Data 0.001 (0.067)	Loss 0.0612 (0.0557)	Prec@1 98.047 (98.398)
Epoch 0: [20/196]	Time 0.489 (0.508)	Data 0.002 (0.034)	Loss 0.0799 (0.0504)	Prec@1 97.656 (98.359)
Epoch 0: [30/196]	Time 0.488 (0.501)	Data 0.001 (0.023)	Loss 0.1054 (0.0543)	Prec@1 96.484 (98.333)
Epoch 0: [40/196]	Time 0.491 (0.497)	Data 0.002 (0.018)	Loss 0.0430 (0.0508)	Prec@1 98.438 (98.428)
Epoch 0: [50/196]	Time 0.492 (0.495)	Data 0.002 (0.015)	Loss 0.0709 (0.0502)	Prec@1 98.047 (98.438)
Epoch 0: [60/196]	Time 0.485 (0.493)	Data 0.002 (0.012)	Loss 0.0361 (0.0506)	Prec@1 99.219 (98.405)
Epoch 0: [70/196]	Time 0.486 (0.492)	Data 0.002 (0.011)	Loss 0.0346 (0.0492)	Prec@1 99.219 (98.449)
Epoch 0: [80/196]	Time 0.489 (0.491)	Data 0.001 (0.010)	Loss 0.0596 (0.0487)	Prec@1 98.047 (98.447)
Epoch 0: [90/196]	Time 0.489 (0.491)	Data 0.001 (0.009)	Loss 0.0604 (0.0489)	Prec@1 97.656 (98.416)
Epoch 0: [100/196]	Time 0.489 (0.490)	Data 0.002 (0.008)	Loss 0.0721 (0.0503)	Prec@1 97.656 (98.359)

93.64