In [1]:
import torch
from torch import optim
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torchvision.utils import make_grid
import torchvision.transforms as tt

from utils import CustomDataSet, load_data, print_config, DeviceDataLoader, supervised_samples, calc_mean_std, plotting, get_PATH, set_random_seed, CreateDataLoader
import config

from Classifier import Classifier
from Generator import Generator
from GANSSL import GANSSL, Discriminator

In [2]:
# Echo the active experiment configuration so the run is self-describing
# (the captured output is shown directly below this cell).
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :    DOODLE
NUM_LABELLED  :        -1
DEVICE        :    cuda:0
GAN_BATCH_SIZE:       128


In [3]:
# Seed with the configured RANDOM_SEED before any data or model is created.
# NOTE(review): which RNGs get seeded is decided inside utils.set_random_seed — confirm there.
set_random_seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
# Build the train/test transform pipelines for the configured dataset.
#
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 on every channel.
#
# Normalize is deliberately used out-of-place (the torchvision default) in all
# pipelines. It returns exactly the same values as inplace=True, but never
# writes into the tensor it was given.
# NOTE(review): the dataset appears to keep its samples as an in-memory tensor
# (`train_ds.x`). When the step before Normalize hands its input through
# untouched (Resize may do so for images that are already 32 px), an in-place
# Normalize would overwrite the stored sample on every access — confirm
# against CustomDataSet.
if config.USED_DATA == "CIFAR10":
	# Three channels; pad-and-crop plus random horizontal flips as augmentation.
	mean = [0.5] * 3
	std = [0.5] * 3
	train_tfm = tt.Compose([
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.Normalize(mean, std),
	])
else:
	# Every other dataset (e.g. "MNIST" or "DOODLE"): one channel, resized to 32,
	# no augmentation.
	mean = [0.5]
	std = [0.5]
	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.Normalize(mean, std),
	])

# Evaluation pipeline: resize + normalize only, sharing the statistics chosen above.
test_tfm = tt.Compose([
	tt.Resize(32),
	tt.Normalize(mean, std),
])

In [5]:
# Load the train and test datasets, handing each its own transform pipeline;
# load_data also returns the class list used below to size the models.
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

In [6]:
# Model dimensions come straight from the loaded data: the number of classes,
# and axis 1 of the stored training tensor as the channel count.
n_classes, channels = len(classes), train_ds.x.shape[1]
n_classes, channels

(10, 1)

In [7]:
# Loader over the test set (batch size 512) bound to the configured device;
# every checkpoint below is evaluated on this same loader.
test_dl = CreateDataLoader(test_ds, batch_size=512, device=config.DEVICE)

In [8]:
# Labelled-sample budgets to compare; EMNIST uses its own list of budgets.
N_Labelled = (
	[100, 200, 1000, 'full']
	if config.USED_DATA == 'EMNIST'
	else [50, 100, 500, 'full']
)

# res maps each budget to its evaluation results, in order:
# [CNN] for 'full', [CNN, GANSSL] for every other budget.
res = {}
for n in N_Labelled:
	results = res[n] = []

	# Supervised baseline: restore the saved classifier and evaluate it.
	name = "CNN"
	PATH = f'{config.USED_DATA}/{name}/_{n}'
	model = Classifier(channels, n_classes).to(config.DEVICE, non_blocking=True)
	model.load(PATH + ".pt")
	results.append(model.evaluate(test_dl))

	# Semi-supervised model: same path layout under the GANSSL folder.
	name = "GANSSL"
	PATH = f'{config.USED_DATA}/{name}/_{n}'

	# The 'full' budget is only evaluated for the baseline.
	if n == 'full':
		continue

	# Only the discriminator weights are restored before evaluating.
	# NOTE(review): both 100s are presumably the latent size — confirm against Generator/GANSSL.
	generator = Generator(100, channels)
	discriminator = Discriminator(channels, n_classes)
	model = GANSSL(generator, discriminator, 100, config.DEVICE)
	model.load_dis_state_dict(PATH + ".pt")
	results.append(model.evaluate(test_dl))

In [9]:
# One line per labelled-sample budget, formatted as "<budget>: <list of results>".
for n in res:
	r = res[n]
	print(f'{n}: {r}')

50: [0.3603000044822693, 0.6583]
100: [0.6449000239372253, 0.7211]
500: [0.8313999772071838, 0.8473]
full: [0.9455000162124634]
