In [None]:
import os
import datetime
import argparse
import copy
import wandb
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed
import torch.utils.data

import utils.common as utils
from data import cifar10
from models.cifar10.vgg import vgg_16_bn
from models.cifar10.resnet import resnet_56
from decomposition.decomposition import decompose

In [None]:
# Experiment configuration (kept as named constants instead of inline literals).
DATA_DIR = '../data'
BATCH_SIZE = 256

train_loader, val_loader = cifar10.load_data(DATA_DIR, BATCH_SIZE)

# nn.Module.cuda() moves the module's state to the GPU and returns the module.
criterion = nn.CrossEntropyLoss().cuda()

# cuDNN benchmark mode autotunes convolution algorithms for the input sizes seen.
cudnn.benchmark = True
cudnn.enabled = True

# Compression-rate spec passed to the model constructor in the next cell.
# NOTE(review): all zeros presumably means "no pruning"; utils.get_cpr parses
# this string — confirm its expected format/length for the chosen architecture.
compress_rate = utils.get_cpr('[0.]*100')

In [None]:
# Select the architecture from an explicit registry instead of eval()-ing a
# string. The checkpoint file is expected to be named after the architecture.
# NOTE(review): assumes every registered constructor accepts `compress_rate`.
ARCH = 'vgg_16_bn'
MODEL_REGISTRY = {'vgg_16_bn': vgg_16_bn, 'resnet_56': resnet_56}
CKPT_PATH = f'checkpoint/cifar10/{ARCH}.pt'

model = MODEL_REGISTRY[ARCH](compress_rate=compress_rate).cuda()
# torch.load unpickles the file: only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location='cuda:0')
model.load_state_dict(ckpt['state_dict'])

In [None]:
# Decompose a deep copy so `model` is guaranteed to remain the untouched
# baseline: later cells compare `model` with `dcp_model` and validate `model`,
# which is only meaningful if `decompose` cannot have modified it in place.
# NOTE(review): the meaning of the positional arguments (1, 100, 1) is defined
# by decomposition.decompose — name them here once confirmed.
dcp_model = decompose(copy.deepcopy(model), 1, 100, 1)

In [None]:
# Random probe input with shape (N=1, C=3, H=32, W=32), i.e. one CIFAR-10-sized
# image, used only to inspect layer output shapes below.
dummy = torch.randn(size=(1, 3, 32, 32)).cuda()

In [None]:
model.features.conv0(dummy).size()  # output shape of the baseline model's first conv layer

In [None]:
dcp_model.features.conv0(dummy).size()  # same probe through the decomposed model's conv0

In [None]:
def validate(val_loader, model, criterion):
    """Score ``model`` on every batch of ``val_loader`` with gradients disabled.

    Each ``(images, target)`` batch is copied to the GPU via ``.cuda()``, run
    through the model, and scored with ``criterion`` and ``utils.accuracy``
    (top-1 and top-5). Running averages are weighted by ``images.size(0)``.
    The final top-1/top-5 averages are printed once the loop finishes.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: module to evaluate; switched to eval mode and left there.
        criterion: loss function called as ``criterion(logits, target)``.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)`` read from the meters' ``avg``.
    """
    loss_meter = utils.AverageMeter('Loss', ':.4e')
    acc1_meter = utils.AverageMeter('Acc@1', ':6.2f')
    acc5_meter = utils.AverageMeter('Acc@5', ':6.2f')

    # Inference behaviour for layers such as dropout and batch norm.
    model.eval()
    with torch.no_grad():
        for images, target in val_loader:
            images, target = images.cuda(), target.cuda()

            logits = model(images)
            batch_loss = criterion(logits, target)
            acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))

            # Weight each batch's contribution by its number of samples.
            batch_size = images.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            acc1_meter.update(acc1[0], batch_size)
            acc5_meter.update(acc5[0], batch_size)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=acc1_meter, top5=acc5_meter))

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg


In [None]:
# Display the module tree (repr) of the baseline model loaded from the checkpoint.
model

In [None]:
# Evaluate both the baseline and the decomposed model on the validation set so
# the accuracy cost of decomposition is visible side by side.
# Each entry is validate()'s (loss_avg, top1_avg, top5_avg) tuple.
baseline_metrics = validate(val_loader, model, criterion)
decomposed_metrics = validate(val_loader, dcp_model, criterion)
baseline_metrics, decomposed_metrics