In [1]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import matplotlib.pyplot as plt


from CNNmodel import ConvModel
from generator import Generator
from utils import CustomDataSet, supervised_samples, load_data, print_config, ceil


import config

In [2]:
print_config(config)

RANDOM_SEED :   110404
DATA_DIR    :   ./data
USED_DATA   :  CIFAR10
LEAKY       :    False
NOSIE       :    False
NUM_LABELLED:    50000
DEVICE      :   cuda:0
EPOCHS      :      100
BATCH_SIZE  :       32


In [3]:
torch.manual_seed(config.RANDOM_SEED)
np.random.seed(config.RANDOM_SEED)

In [4]:
X_train, y_train, X_test, y_test, classes = load_data(-1, 1)

KeyboardInterrupt: 

In [None]:
channel = X_train.shape[1]
n_classes = len(classes)

In [None]:
X_sup, y_sup = supervised_samples(X_train, y_train, config.NUM_LABELLED, n_classes)

In [None]:
class Classifier(nn.Module):
	def __init__(self, CNNlayer: nn.Module, num_classes) -> None:
		super().__init__()

		self.CNN = CNNlayer

		self.out = nn.Sequential(
			nn.Linear(512, num_classes),
			nn.Softmax(1)
		)

		self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas= [0.5, 0.999])

		self.criterion = nn.BCELoss()
	
	
	def forward(self, X: Tensor):
		X = self.CNN(X)
		X = self.out(X)

		return X

In [None]:
class Discriminator(nn.Module):
	def __init__(self, CNNlayer) -> None:
		super().__init__()
		self.CNN = CNNlayer

		self.out = nn.Sequential(
			nn.Linear(512, 1),
			nn.Sigmoid()
		)
	
		self.optimizer = optim.Adam(self.parameters(), lr=0.0002, betas=[0.5, 0.999])
		self.criterion = nn.BCELoss()

	def forward(self, X: Tensor):
		X = self.CNN(X)

		X = self.out(X)

		return X

In [None]:
class TripleGAN(nn.Module):
	def __init__(self, image_size, num_classes, CNNlayer: nn.Module, latent_size = 100, lr=0.0002):

		super().__init__()

		self.latent_size = latent_size

		self.n_classes = num_classes

		self.CNN = CNNlayer

		self.generator = Generator(latent_size, image_size)

		self.classify = Classifier(CNNlayer, num_classes)
		self.discriminator = Discriminator(CNNlayer)

		self.name = "TripleGAN"

		self.history = {}

	def forward(self, X: Tensor):
		return self.classify(X)
	
	def save(self, PATH = "./"):
		torch.save(self.state_dict(), PATH)

	def load(self, PATH):
		self.load_state_dict(torch.load(PATH))
	
	def validation(self, X: Tensor, y: Tensor):
		self.classify.eval()

		num_data = y.shape[0]

		run_size = 10000

		current = 0
		correct = 0
		
		while current < num_data:
			correct += torch.count_nonzero(torch.argmax(self.classify(X[current: current + run_size]), 1) == torch.argmax(y[current: current + run_size], 1))
			current += run_size

		return (float(correct.item()) / float(num_data))
	
	def plot(self):
		if len(self.history) == 0:
			return
		n = self.history['epochs']

		x = np.arange(n)

		for name, value in self.history.items():
			if name == 'epochs':
				continue
			plt.plot(x, value, label = name)

		
		plt.xlabel('epoch')
		plt.ylabel('acc')

		plt.legend()
		plt.show()



	def training_step(self, model: nn.Module, optimizer: optim.Optimizer, criterion: nn.modules.loss._Loss, X: Tensor, y: Tensor):
		out: Tensor = model(X)

		loss: Tensor = criterion(out, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		return loss.item()
	

	def fit(self, X: Tensor, y: Tensor, X_sup, y_sup, epochs = 100, batch_size = 64, save_best = True, PATH = "./", validation_data: list[torch.Tensor] = None):
		self.history['epochs'] = epochs
		self.history['train'] = []

		if validation_data:
			X_val, y_val = validation_data
			self.history['validation'] = []

		datasets = CustomDataSet(X, y)

		num_data = len(datasets)

		sup_datasets = CustomDataSet(X_sup, y_sup)

		dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True)

		sup_dataloader = DataLoader(sup_datasets, batch_size=batch_size//2, shuffle=True)

		dataloader.batch_size

		best_acc = 0

		for epoch in range(epochs):
			self.classify.train()
			self.discriminator.train()
			self.generator.train()

			print(f"epoch: {epoch}\nclassify: ")
			
			# for classify
			sup_loss = 0
			for inputs, labels in tqdm(sup_dataloader):
				sup_loss += self.training_step(self.classify, self.classify.optimizer, self.classify.criterion, inputs.to(config.DEVICE), labels.to(config.DEVICE))

			
			sup_loss /= ceil(num_data/(batch_size//2))

			print(f'GAN:')
			# for discriminator and generator
			real_loss = 0
			fake_loss = 0
			gen_loss = 0
			for inputs, _ in tqdm(dataloader):
				real_loss += self.training_step(self.discriminator, self.discriminator.optimizer, self.discriminator.criterion, inputs, torch.ones((inputs.shape[0], 1)).to(config.DEVICE))
				z = torch.randn((inputs.shape[0], self.latent_size)).to(config.DEVICE)
				
				gen_out = self.generator(z)
				fake_loss += self.training_step(self.discriminator, self.discriminator.optimizer, self.discriminator.criterion, gen_out, torch.zeros(inputs.shape[0], 1).to(config.DEVICE))

				gen_out = self.generator(z)
				gen_loss += self.training_step(self.discriminator, self.generator.optimizer, self.discriminator.criterion, gen_out, torch.ones((inputs.shape[0], 1)).to(config.DEVICE))

			real_loss /= ceil(num_data/batch_size)
			fake_loss /= ceil(num_data/batch_size)
			gen_loss /= ceil(num_data/batch_size)

			train_acc = self.validation(X_sup, y_sup)

			self.history['train'].append(train_acc)

			print(f"train acc: {train_acc*100:.2f}%", end = "")

			if validation_data:
				val_acc = self.validation(X_val, y_val)

				self.history['validation'].append(val_acc)

				print(f', val acc: {val_acc*100:.2f}%', end = "")

				if val_acc >= best_acc:
					best_acc = val_acc

					if save_best:
						self.save(PATH + self.name + ".pt")
			
			print(f", classification_loss: {sup_loss:.2f}, discrimination_loss: {(real_loss+fake_loss)/2:.2f}, generation_loss: {gen_loss:.2f}")


In [None]:
model = TripleGAN(X_train.shape[1:], n_classes, ConvModel(channel, config.LEAKY)).to(config.DEVICE)

In [None]:
model.fit(X_train, y_train, X_sup, y_sup, epochs=config.EPOCHS, batch_size=config.BATCH_SIZE, PATH=config.USED_DATA + "/", validation_data=(X_test, y_test))

epoch: 0
classify: 


100%|██████████| 32/32 [00:00<00:00, 84.39it/s]


GAN:


100%|██████████| 782/782 [00:12<00:00, 61.27it/s]


train acc: 13.10%, val acc: 13.31%, classification_loss: 0.01, discrimination_loss: 0.60, generation_loss: 0.98
epoch: 1
classify: 


100%|██████████| 32/32 [00:00<00:00, 439.41it/s]


GAN:


100%|██████████| 782/782 [00:12<00:00, 61.33it/s]


train acc: 16.80%, val acc: 15.60%, classification_loss: 0.01, discrimination_loss: 0.59, generation_loss: 1.00
epoch: 2
classify: 


100%|██████████| 32/32 [00:00<00:00, 446.02it/s]


GAN:


  6%|▋         | 50/782 [00:00<00:12, 59.42it/s]


KeyboardInterrupt: 

In [None]:
model.plot()

In [None]:
model.load(config.USED_DATA + "/TripleGAN.pt")

In [None]:
model.validation(X_test, y_test)

0.3404