In [1]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt

from CNNmodel import ConvModel


In [2]:
random_seed = 0
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [3]:
DATA_DIR = "./data"
n_classes = 10
num_labelled = 100

In [4]:
def one_hot(y):
	one_hot_y = torch.zeros((len(y), 10))
	one_hot_y[np.arange(len(y)), y] = 1
	return one_hot_y

In [9]:
data = MNIST(DATA_DIR, train = True, download=True)
data_X = data.data.unsqueeze(1).float()
data_X /= 255
data_Y = one_hot(data.targets)

num_data = data_Y.numel()

In [11]:
def supervised_samples(X: Tensor, y: Tensor, n_samples, n_classes=10):
	number_per_class = n_samples//n_classes


	X_sup, y_sup = Tensor(), Tensor()


	for i in range(n_classes):
		X_with_class = X[torch.argmax(y, 1) == i]

		ix = torch.randint(0, len(X_with_class), [number_per_class])

		X_sup = torch.cat((X_sup, X_with_class[ix]))

		y_sup = torch.cat((y_sup, one_hot([i]*number_per_class)))

	return X_sup, y_sup

In [12]:
class CustomDataSet(Dataset):
	def __init__(self, x, y) -> None:
		self.x: torch.Tensor = x
		self.y: torch.Tensor = y
		self.n_samples = len(y)
	
	def __getitem__(self, index):
		return self.x[index], self.y[index]
	
	def __len__(self):
		return self.n_samples

In [14]:
class Generator(nn.Module):
	def __init__(self, inp_size, out_size) -> None:
		super().__init__()


		self.NN = nn.Sequential(
			nn.Linear(inp_size, 256*7*7),
			nn.LeakyReLU(negative_slope=0.2),
		)

		self.CONV = nn.Sequential(
			nn.ConvTranspose2d(256, 128, (3, 3), (2, 2)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(0.2),
			nn.ConvTranspose2d(128, 64, (3, 3), (1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(0.2),
			
		)

		self.out = nn.Sequential(
			nn.ConvTranspose2d(64, out_size[0], (3, 3), (2, 2)),
			nn.AdaptiveAvgPool2d((out_size[1], out_size[2])),
			nn.Sigmoid()
		)

		self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas=[0.5, 0.999])
		self.criterion = nn.BCELoss()
	
	def forward(self, X: Tensor):
		X = self.NN(X)
		X = X.view(-1, 256, 7, 7)
		X = self.CONV(X)

		X = self.out(X)

		return X

In [16]:
class Feature_Extractor(nn.Module):
	def __init__(self, inp_channel) -> None:
		super().__init__()


		self.CNN = ConvModel(inp_channel)

		self.dropout = nn.Sequential(
			nn.Dropout(0.4)
		)

	
	def forward(self, X: Tensor):
		X = self.CNN(X)
		X = self.dropout(X)
		return X

In [17]:
class Classify(nn.Module):
	def __init__(self, feature_extractor: nn.Module, num_classes) -> None:
		super().__init__()

		self.CNN = feature_extractor

		self.out = nn.Sequential(
			nn.Linear(512, num_classes),
			nn.Softmax(1)
		)

		self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas= [0.5, 0.999])

		self.criterion = nn.BCELoss()
	
		return torch.argmax(self.forward(x)).item()
	
	def forward(self, X: Tensor):
		X = self.CNN(X)
		X = self.out(X)

		return X

In [679]:
class Discriminator(nn.Module):
	def __init__(self, feature_extractor) -> None:
		super().__init__()
		self.CNN = feature_extractor

		self.out = nn.Sequential(
			nn.Linear(512, 1),
			nn.Sigmoid()
		)
	
		self.optimizer = optim.Adam(self.parameters(), lr=0.0002, betas=[0.5, 0.999])
		self.criterion = nn.BCELoss()

	def forward(self, X: Tensor):
		X = self.CNN(X)

		X = self.out(X)

		return X

In [680]:
class SGAN:
	def __init__(self, image_size, num_classes, feature_extractor: nn.Module, latent_size = 100, lr=0.0002):

		self.latent_size = latent_size

		CNN = feature_extractor

		self.generator = Generator(latent_size, image_size)

		self.classify = Classify(CNN, num_classes)
		self.discriminator = Discriminator(CNN)

		self.history = {}
	
	def __call__(self, X: torch.Tensor):
		return torch.argmax(self.classify(X))
	
	def validation(self, X: Tensor, y: Tensor):
		self.classify.eval()

		num_data = y.shape[0]

		run_size = 10000

		current = 0
		correct = 0
		
		while current < num_data:
			correct += torch.count_nonzero(torch.argmax(self.classify(X[current: current + run_size]), 1) == torch.argmax(y[current: current + run_size], 1))
			current += run_size

		return (correct.float() / num_data).item()


	def fit(self, X: Tensor, y: Tensor, epochs = 100, batch_size = 64):

		X_sup, y_sup = supervised_samples(X, y, num_labelled)

		bat_per_epo = int(X.shape[0]/batch_size)

		n_step = bat_per_epo * epochs

		half_batch = batch_size//2

		train_acc = 0

		for epoch in range(epochs):
			self.classify.train()
			self.discriminator.train()
			self.generator.train()

			# classify

			if train_acc <= 0.90:

				X_sup_real, y_sup_real = real_samples(X_sup, y_sup, half_batch)
				out_sup = self.classify(X_sup_real)
				sup_loss: Tensor = self.classify.criterion(out_sup, y_sup_real)

				self.classify.optimizer.zero_grad()
				sup_loss.backward()
				self.classify.optimizer.step()

			# discriminator real
			X_real, _ = real_samples(X, y, half_batch)

			out_real = self.discriminator(X_real)
			

			real_loss: Tensor = self.discriminator.criterion(out_real, torch.ones(half_batch, 1))

			self.discriminator.optimizer.zero_grad()
			real_loss.backward()
			self.classify.optimizer.step()

			z = torch.randn(half_batch, self.latent_size)

			# discriminator fake
			X_fake = self.generator(z)
			out_fake = self.discriminator(X_fake)
			y_fake = torch.zeros((half_batch, 1))
			fake_loss: Tensor = self.discriminator.criterion(out_fake, y_fake)
			self.discriminator.optimizer.zero_grad()
			fake_loss.backward()
			self.discriminator.optimizer.step()

			# generator
			z = torch.randn(batch_size, self.latent_size)
			X_gen = self.generator(z)
			out_gen = self.discriminator(X_gen)
			y_gen = torch.ones((batch_size, 1))
			gen_loss: Tensor = self.generator.criterion(out_gen, y_gen)

			self.generator.optimizer.zero_grad()
			gen_loss.backward()
			self.generator.optimizer.step()

			train_acc = self.validation(X_sup, y_sup)

			print(f"epoch: {epoch}, train acc: {train_acc*100:.2f}%, classification_loss: {sup_loss:.2f}, discrimination_loss: {(real_loss+fake_loss)/2:.2f}, generation_loss: {gen_loss:.2f}")



In [681]:
model = SGAN([1, 28, 28], n_classes, ConvModel(1))

In [682]:
model.fit(data_X, data_Y, epochs=500, batch_size=32)

epoch: 0, train acc: 10.00%, classification_loss: 0.34, discrimination_loss: 0.75, generation_loss: 0.67
epoch: 1, train acc: 10.00%, classification_loss: 0.34, discrimination_loss: 0.73, generation_loss: 0.59


epoch: 2, train acc: 11.00%, classification_loss: 0.30, discrimination_loss: 0.73, generation_loss: 0.68
epoch: 3, train acc: 14.00%, classification_loss: 0.30, discrimination_loss: 0.66, generation_loss: 0.68
epoch: 4, train acc: 18.00%, classification_loss: 0.31, discrimination_loss: 0.60, generation_loss: 0.71
epoch: 5, train acc: 21.00%, classification_loss: 0.28, discrimination_loss: 0.61, generation_loss: 0.73
epoch: 6, train acc: 31.00%, classification_loss: 0.30, discrimination_loss: 0.65, generation_loss: 0.63
epoch: 7, train acc: 30.00%, classification_loss: 0.29, discrimination_loss: 0.66, generation_loss: 0.56
epoch: 8, train acc: 34.00%, classification_loss: 0.28, discrimination_loss: 0.65, generation_loss: 0.65
epoch: 9, train acc: 41.00%, classification_loss: 0.28, discrimination_loss: 0.64, generation_loss: 0.70
epoch: 10, train acc: 42.00%, classification_loss: 0.27, discrimination_loss: 0.67, generation_loss: 0.70
epoch: 11, train acc: 48.00%, classification_loss: 0.2

In [683]:
model.validation(data_X, data_Y)

0.6123999953269958