In [3]:
import torch
import numpy as np
import torch.nn as nn
import math
from utils import *
from mobilenetv2 import *
import torch.optim as optim
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
savepath = os.path.join('checkpoints', 'MobileNetV2', 'sr')
if not os.path.exists(savepath):
    os.makedirs(savepath)
kwargs = {'num_workers': 4, 'pin_memory': True}

model = mobilenetv2_w1(pretrained=True).cuda()

#load dataset
test_loader = getTestData('imagenet',
                        batch_size=16,
                        path='F:\\imagenet\\',
                        for_inception=False)

calibration_loader = getSelfBuiltCalibrationData('imagenet',
                        batch_size=3,
                        path='F:\\imagenet\\',
                        for_inception=False)

train_loader = getTrainingData('imagenet',  
                        batch_size=3,
                        path='F:\\imagenet\\',
                        for_inception=False)

best_prec1 = -1
weight_decay=1e-4
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=weight_decay)

resume=0
epochs = 10
start_epoch=0
if resume:
    if os.path.isfile(resume):
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = torch.load(resume)
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(resume))
        
def updateBN():
    s = 0.5e-5 #scale sparse rate (default: 0.0001)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))
            
def train():
    sr = 1 #training with scale sparsity
    model.train()
    avg_loss = 0.
    train_acc = 0.
    #for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
    for batch_idx, (data, target) in tqdm(enumerate(calibration_loader), total=len(calibration_loader)):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        if sr:
            updateBN()
        optimizer.step()
        
def test(epoch,test_width=1.0,recal=False):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader, total=len(test_loader)):
        data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nEpoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(epoch,
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct.item() / float(len(test_loader.dataset))

best_prec1 = 0.
scheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=epochs,eta_min=0)
start_epoch = 0
for epoch in range(start_epoch, epochs):
    train()
    prec1 = test(epoch=epoch)
    scheduler.step(epoch)
    lr_current = optimizer.param_groups[0]['lr']
    print("currnt lr:{}".format(lr_current))
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    if is_best:
        ckptfile = os.path.join(savepath, 'model_best.pth.tar')
    else:
        ckptfile = os.path.join(savepath, 'checkpoint.pth.tar')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, ckptfile)

100%|██████████| 1000/1000 [00:24<00:00, 40.30it/s]
100%|██████████| 3125/3125 [01:32<00:00, 33.61it/s]



Epoch: 0 Test set: Average loss: 6.9240, Accuracy: 90/50000 (0.2%)

currnt lr:0.001


100%|██████████| 1000/1000 [00:24<00:00, 40.80it/s]
100%|██████████| 3125/3125 [01:32<00:00, 33.88it/s]



Epoch: 1 Test set: Average loss: 6.9283, Accuracy: 86/50000 (0.2%)

currnt lr:0.0009755282581475768


100%|██████████| 1000/1000 [00:24<00:00, 40.68it/s]
100%|██████████| 3125/3125 [01:31<00:00, 34.07it/s]



Epoch: 2 Test set: Average loss: 6.9089, Accuracy: 65/50000 (0.1%)

currnt lr:0.0009045084971874737


100%|██████████| 1000/1000 [00:24<00:00, 41.07it/s]
 56%|█████▌    | 1753/3125 [00:52<00:41, 33.10it/s]


KeyboardInterrupt: 