In [1]:
# Register the `python3` found on PATH as a per-user Jupyter kernelspec named "env"
# (the output below shows where it was installed). Setup step only; the analysis does not depend on it.
!python3 -m ipykernel install --user --name=env

Installed kernelspec env in /h/u14/c2/00/saragihd/.local/share/jupyter/kernels/env


In [2]:
# Load save 
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import models
from models.quantization import quan_Conv2d, quan_Linear, quantize

In [3]:
# Raw MNIST train/test splits with no transform attached, so samples come back
# untransformed for the visual inspection cells below.
root = './data'
download = True
mnist_kwargs = dict(root=root, transform=None, download=download)
train_set = dset.MNIST(train=True, **mnist_kwargs)
test_set = dset.MNIST(train=False, **mnist_kwargs)

In [9]:
# Show data
def show(img):
    """Display a single image with matplotlib using the grayscale colormap.

    Args:
        img: anything ``np.asarray`` accepts -- a PIL image, a numpy array, or
            a CPU tensor -- shaped (H, W) or (H, W, C). A trailing singleton
            channel axis, i.e. (H, W, 1), is dropped so the array is 2-D and
            ``cmap='gray'`` is applied to it.
    """
    img = np.asarray(img)
    # (H, W, 1) -> (H, W): make single-channel images 2-D for the colormap.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    plt.imshow(img, cmap='gray')
    print("Showing")
    plt.show()

# Inspect the first raw training sample and its label.
img, label = train_set[0]
print(f'Label: {label}')
print(f'Image Size: {img.size}')

# Convert the sample to a numpy array and report its shape before displaying it.
img = np.array(img)
print(f'Numpy Shape: {img.shape}')

show(img)

Label: 5
Image Size: (28, 28)
Numpy Shape: (28, 28)
Showing


In [10]:
# Augmentation pipeline: colour jitter, random affine, random horizontal flip,
# random rotation, gaussian blur, random erasing, and normalization.
#
# Order matters:
#  * Photometric and geometric ops run on the PIL image, *before* ToTensor/Normalize.
#    On a float tensor, ColorJitter's brightness/contrast blend clamps to [0, 1],
#    which would wipe out the negative half of the normalized [-1, 1] range.
#    Likewise, the default affine/rotation fill value 0 would be mid-grey after
#    Normalize instead of the black MNIST background.
#  * RandomErasing only supports tensors, so it stays after ToTensor/Normalize.
mean = (0.5,)
std = (0.5,)
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    # NOTE(review): mirrored digits are generally not valid digits (e.g. 2, 3, 7),
    # and p=0.8 mirrors most samples -- confirm this is intended before training with it.
    transforms.RandomHorizontalFlip(p=0.8),
    transforms.RandomRotation(15),
    transforms.GaussianBlur(3),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.9, scale=(0.02, 0.1)),
])

# Run the augmentation pipeline on the first raw training sample and display it.
img, label = train_set[0]
img = transform(img)
print('Transformed Shape:', img.shape)

# Reorder the tensor from (C, H, W) to (H, W, C) and hand matplotlib a numpy array.
img = img.permute(1, 2, 0).numpy()
show(img)

Transformed Shape: torch.Size([1, 32, 32])
Showing


In [4]:
# Evaluation configuration.
DATASET = 'finetune_mnist'
data_path = './data'
ARCH = "resnet32_quan"
chk_path = './save_finetune/cifar60/model_best_path.tar'
BATCH_SIZE = 128

if DATASET == 'finetune_mnist':
    mean = [0.5]
    std = [0.5]
    # Resize to 32x32 and replicate the single grey channel 3 times, matching
    # num_channels = 3 that the network is built with.
    test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        # The 1-element mean/std broadcast over all 3 replicated channels.
        transforms.Normalize(mean, std)
    ])
    test_data = dset.MNIST(data_path,
                           train=False,
                           transform=test_transform,
                           download=True)
    num_classes = 10
    num_channels = 3
else:
    # Fail loudly here instead of with a NameError on `test_data` further down.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")

# Evaluation does not need shuffling; a fixed order keeps runs reproducible.
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=1,
                                          pin_memory=True)

In [None]:
net = models.__dict__[ARCH](num_classes, num_channels)

# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies the tensors onto the parameters' device anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(chk_path, map_location='cpu')

# The weights may live under a 'state_dict' key or be the checkpoint itself.
if 'state_dict' in checkpoint:
    checkpoint_state = checkpoint['state_dict']
else:
    checkpoint_state = checkpoint

# Keep only the entries the network knows about. Report what was skipped so that
# a mismatched checkpoint does not silently leave the model at its initialization.
model_dict = net.state_dict()
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
unused_keys = set(checkpoint_state) - set(pretrained_dict)
missing_keys = set(model_dict) - set(pretrained_dict)
print(f'Loaded {len(pretrained_dict)}/{len(model_dict)} tensors from {chk_path}; '
      f'{len(unused_keys)} checkpoint keys unused, '
      f'{len(missing_keys)} model keys left at initialization.')
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

In [None]:
# Probe the network with one real test batch to count its output branches.
# A dedicated name is needed: no `input` tensor exists in this notebook, and a
# bare `input` is Python's builtin function.
sample_inputs, sample_targets = next(iter(test_loader))
net.eval()
with torch.no_grad():
    output_branch = net(sample_inputs)
# Assumes the model returns a sequence with one output per branch -- TODO confirm;
# if it returned a single tensor, len() would give the batch size instead.
num_branch = len(output_branch)  # the number of branches

# NOTE(review): assumes the model was trained with cross-entropy -- confirm.
criterion = torch.nn.CrossEntropyLoss()
# TODO: `validate`, `log` and `args` are not defined anywhere in this notebook;
# import or define them (e.g. from the training script) before running this line.
val_acc, _, val_los = validate(test_loader, net, criterion, log, num_branch, args.ic_only)