In [256]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
import random
from tqdm import tqdm

import matplotlib.pyplot as plt

from CNNmodel import ConvModel

In [257]:
random_seed = 0
torch.manual_seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)

In [258]:
DATA_DIR = "./data"
n_classes = 10
num_labelled = 125
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [259]:
def one_hot(y):
	one_hot_y = torch.zeros((len(y), 10))
	one_hot_y[np.arange(len(y)), y] = 1
	return one_hot_y

In [260]:
data = MNIST(DATA_DIR, train = True, download=True)
data_X = data.data.unsqueeze(1).float().to(device)
data_X /= 255
data_Y = one_hot(data.targets).to(device)

num_data = data_Y.shape[0]

In [261]:
def supervised_samples(X: Tensor, y: Tensor, n_samples, val_ratio = 0.2):
	num_data = y.shape[0]

	ix = np.random.randint(0, num_data, n_samples)

	val = ix[:int(n_samples*val_ratio)]
	sup = ix[int(n_samples*val_ratio):]

	X_sup = X[sup]
	y_sup = y[sup]
	X_val = X[val]
	y_val = y[val]

	return X_sup, y_sup, X_val, y_val

In [262]:
class CustomDataSet(Dataset):
	def __init__(self, x: Tensor, y: Tensor) -> None:
		self.x: Tensor = x
		self.y: Tensor = y
		self.n_samples = len(y)
	
	def __getitem__(self, index):
		return self.x[index], self.y[index]
	
	def __len__(self):
		return self.n_samples

In [263]:
class Generator(nn.Module):
	def __init__(self, inp_size, out_size) -> None:
		super().__init__()


		self.NN = nn.Sequential(
			nn.Linear(inp_size, 256*7*7),
			nn.LeakyReLU(negative_slope=0.2),
		)

		self.CONV = nn.Sequential(
			nn.ConvTranspose2d(256, 128, (3, 3), (2, 2)),
			nn.BatchNorm2d(128),
			nn.LeakyReLU(0.2),
			nn.ConvTranspose2d(128, 64, (3, 3), (1, 1)),
			nn.BatchNorm2d(64),
			nn.LeakyReLU(0.2),
			
		)

		self.out = nn.Sequential(
			nn.ConvTranspose2d(64, out_size[0], (3, 3), (2, 2)),
			nn.AdaptiveAvgPool2d((out_size[1], out_size[2])),
			nn.Sigmoid()
		)

		self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas=[0.5, 0.999])
		self.criterion = nn.BCELoss()
	
	def forward(self, X: Tensor):
		X = self.NN(X)
		X = X.view(-1, 256, 7, 7)
		X = self.CONV(X)

		X = self.out(X)

		return X

In [264]:
class Feature_Extractor(nn.Module):
	def __init__(self, inp_channel) -> None:
		super().__init__()


		self.CNN = ConvModel(inp_channel)

		self.dropout = nn.Sequential(
			nn.Dropout(0.4)
		)

	
	def forward(self, X: Tensor):
		X = self.CNN(X)
		X = self.dropout(X)
		return X

In [265]:
class Classify(nn.Module):
	def __init__(self, feature_extractor: nn.Module, num_classes) -> None:
		super().__init__()

		self.CNN = feature_extractor

		self.out = nn.Sequential(
			nn.Linear(512, num_classes),
			nn.Softmax(1)
		)

		self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas= [0.5, 0.999])

		self.criterion = nn.BCELoss()
	
	
	def forward(self, X: Tensor):
		X = self.CNN(X)
		X = self.out(X)

		return X

In [266]:
class Discriminator(nn.Module):
	def __init__(self, feature_extractor) -> None:
		super().__init__()
		self.CNN = feature_extractor

		self.out = nn.Sequential(
			nn.Linear(512, 1),
			nn.Sigmoid()
		)
	
		self.optimizer = optim.Adam(self.parameters(), lr=0.0002, betas=[0.5, 0.999])
		self.criterion = nn.BCELoss()

	def forward(self, X: Tensor):
		X = self.CNN(X)

		X = self.out(X)

		return X

In [267]:
class SGAN:
	def __init__(self, image_size, num_classes, feature_extractor: nn.Module, latent_size = 100, lr=0.0002):

		self.latent_size = latent_size

		CNN = feature_extractor.to(device)

		self.generator = Generator(latent_size, image_size).to(device)

		self.classify = Classify(CNN, num_classes).to(device)
		self.discriminator = Discriminator(CNN).to(device)

		self.history = {}
	
	def __call__(self, X: torch.Tensor):
		return torch.argmax(self.classify(X))
	
	def save(self, PATH = "./"):
		torch.save(self.classify.state_dict(), PATH + "/classify.pt")
		torch.save(self.discriminator.state_dict(), PATH + "/discriminator.pt")
		torch.save(self.generator.state_dict(), PATH + "/generator.pt")
	
	def validation(self, X: Tensor, y: Tensor):
		self.classify.eval()

		num_data = y.shape[0]

		run_size = 10000

		current = 0
		correct = 0
		
		while current < num_data:
			correct += torch.count_nonzero(torch.argmax(self.classify(X[current: current + run_size]), 1) == torch.argmax(y[current: current + run_size], 1))
			current += run_size

		return (correct.float().item() / num_data)


	def training_step(self, model: nn.Module, optimizer: optim.Optimizer, criterion: nn.modules.loss._Loss, X: Tensor, y: Tensor):
		out: Tensor = model(X)

		loss: Tensor = criterion(out, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		return loss
	

	def fit(self, X: Tensor, y: Tensor, sup_samples, epochs = 100, batch_size = 64):

		X_sup, y_sup, X_val, y_val = supervised_samples(X, y, sup_samples)

		datasets = CustomDataSet(X, y)

		sup_datasets = CustomDataSet(X_sup, y_sup)

		dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True)

		sup_dataloader = DataLoader(sup_datasets, batch_size=batch_size//2, shuffle=True)

		for epoch in range(epochs):
			self.classify.train()
			self.discriminator.train()
			self.generator.train()

			print(f"epoch: {epoch}\nclassify: ")
			
			# for classify
			for inputs, labels in tqdm(sup_dataloader):
				sup_loss = self.training_step(self.classify, self.classify.optimizer, self.classify.criterion, inputs.to(device), labels.to(device))

			print(f'GAN:')
			# for discriminator and generator
			for inputs, _ in tqdm(dataloader):
				real_loss = self.training_step(self.discriminator, self.discriminator.optimizer, self.discriminator.criterion, inputs, torch.ones((inputs.shape[0], 1)).to(device))

				z = torch.randn((inputs.shape[0], self.latent_size)).to(device)
				gen_out = self.generator(z)
				fake_loss = self.training_step(self.discriminator, self.discriminator.optimizer, self.discriminator.criterion, gen_out, torch.zeros(inputs.shape[0], 1).to(device))
				gen_out = self.generator(z)
				gen_loss = self.training_step(self.discriminator, self.generator.optimizer, self.discriminator.criterion, gen_out, torch.ones((inputs.shape[0], 1)).to(device))

			train_acc = self.validation(X_sup, y_sup)

			val_acc = self.validation(X_val, y_val)
			
			print(f"train acc: {train_acc*100:.2f}%, val acc: {val_acc*100:.2f}%, classification_loss: {sup_loss:.2f}, discrimination_loss: {(real_loss+fake_loss)/2:.2f}, generation_loss: {gen_loss:.2f}")


In [268]:
model = SGAN([1, 28, 28], n_classes, ConvModel(1))

In [269]:
model.fit(data_X, data_Y, sup_samples=num_labelled, epochs=100, batch_size=64)

epoch: 0
classify: 


100%|██████████| 4/4 [00:00<00:00, 366.76it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.04it/s]


train acc: 15.00%, val acc: 4.00%, classification_loss: 0.30, discrimination_loss: 0.02, generation_loss: 5.16
epoch: 1
classify: 


100%|██████████| 4/4 [00:00<00:00, 545.48it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 68.71it/s]


train acc: 20.00%, val acc: 16.00%, classification_loss: 0.33, discrimination_loss: 0.14, generation_loss: 3.33
epoch: 2
classify: 


100%|██████████| 4/4 [00:00<00:00, 490.58it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.11it/s]


train acc: 20.00%, val acc: 8.00%, classification_loss: 0.25, discrimination_loss: 0.24, generation_loss: 3.77
epoch: 3
classify: 


100%|██████████| 4/4 [00:00<00:00, 417.78it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.45it/s]


train acc: 28.00%, val acc: 8.00%, classification_loss: 0.34, discrimination_loss: 0.16, generation_loss: 3.47
epoch: 4
classify: 


100%|██████████| 4/4 [00:00<00:00, 430.13it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 63.65it/s]


train acc: 40.00%, val acc: 8.00%, classification_loss: 0.25, discrimination_loss: 0.10, generation_loss: 4.39
epoch: 5
classify: 


100%|██████████| 4/4 [00:00<00:00, 486.04it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.88it/s]


train acc: 44.00%, val acc: 24.00%, classification_loss: 0.25, discrimination_loss: 0.06, generation_loss: 4.39
epoch: 6
classify: 


100%|██████████| 4/4 [00:00<00:00, 551.34it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 65.96it/s]


train acc: 51.00%, val acc: 28.00%, classification_loss: 0.25, discrimination_loss: 0.08, generation_loss: 4.36
epoch: 7
classify: 


100%|██████████| 4/4 [00:00<00:00, 546.86it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.63it/s]


train acc: 61.00%, val acc: 40.00%, classification_loss: 0.29, discrimination_loss: 0.04, generation_loss: 4.42
epoch: 8
classify: 


100%|██████████| 4/4 [00:00<00:00, 296.88it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.04it/s]


train acc: 61.00%, val acc: 36.00%, classification_loss: 0.28, discrimination_loss: 0.02, generation_loss: 5.47
epoch: 9
classify: 


100%|██████████| 4/4 [00:00<00:00, 376.61it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.69it/s]


train acc: 66.00%, val acc: 56.00%, classification_loss: 0.26, discrimination_loss: 0.13, generation_loss: 4.18
epoch: 10
classify: 


100%|██████████| 4/4 [00:00<00:00, 504.81it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.08it/s]


train acc: 69.00%, val acc: 40.00%, classification_loss: 0.23, discrimination_loss: 0.09, generation_loss: 4.29
epoch: 11
classify: 


100%|██████████| 4/4 [00:00<00:00, 525.52it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.57it/s]


train acc: 57.00%, val acc: 36.00%, classification_loss: 0.23, discrimination_loss: 1.56, generation_loss: 9.96
epoch: 12
classify: 


100%|██████████| 4/4 [00:00<00:00, 534.41it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.02it/s]


train acc: 55.00%, val acc: 40.00%, classification_loss: 0.26, discrimination_loss: 1.21, generation_loss: 8.18
epoch: 13
classify: 


100%|██████████| 4/4 [00:00<00:00, 362.04it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.74it/s]


train acc: 74.00%, val acc: 60.00%, classification_loss: 0.19, discrimination_loss: 0.04, generation_loss: 4.95
epoch: 14
classify: 


100%|██████████| 4/4 [00:00<00:00, 511.89it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.00it/s]


train acc: 80.00%, val acc: 52.00%, classification_loss: 0.25, discrimination_loss: 0.05, generation_loss: 4.95
epoch: 15
classify: 


100%|██████████| 4/4 [00:00<00:00, 537.09it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.92it/s]


train acc: 81.00%, val acc: 56.00%, classification_loss: 0.15, discrimination_loss: 0.13, generation_loss: 3.90
epoch: 16
classify: 


100%|██████████| 4/4 [00:00<00:00, 469.96it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.05it/s]


train acc: 77.00%, val acc: 52.00%, classification_loss: 0.17, discrimination_loss: 0.06, generation_loss: 6.67
epoch: 17
classify: 


100%|██████████| 4/4 [00:00<00:00, 524.63it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.05it/s]


train acc: 80.00%, val acc: 60.00%, classification_loss: 0.23, discrimination_loss: 0.15, generation_loss: 4.00
epoch: 18
classify: 


100%|██████████| 4/4 [00:00<00:00, 313.70it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.28it/s]


train acc: 84.00%, val acc: 64.00%, classification_loss: 0.25, discrimination_loss: 0.05, generation_loss: 4.24
epoch: 19
classify: 


100%|██████████| 4/4 [00:00<00:00, 420.43it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.70it/s]


train acc: 86.00%, val acc: 60.00%, classification_loss: 0.17, discrimination_loss: 0.02, generation_loss: 5.27
epoch: 20
classify: 


100%|██████████| 4/4 [00:00<00:00, 350.94it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.53it/s]


train acc: 86.00%, val acc: 60.00%, classification_loss: 0.19, discrimination_loss: 0.15, generation_loss: 4.26
epoch: 21
classify: 


100%|██████████| 4/4 [00:00<00:00, 436.59it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.48it/s]


train acc: 87.00%, val acc: 60.00%, classification_loss: 0.19, discrimination_loss: 0.04, generation_loss: 4.72
epoch: 22
classify: 


100%|██████████| 4/4 [00:00<00:00, 408.27it/s]


GAN:


100%|██████████| 938/938 [00:16<00:00, 57.98it/s]


train acc: 84.00%, val acc: 60.00%, classification_loss: 0.17, discrimination_loss: 0.05, generation_loss: 4.51
epoch: 23
classify: 


100%|██████████| 4/4 [00:00<00:00, 382.17it/s]


GAN:


100%|██████████| 938/938 [00:15<00:00, 59.57it/s]


train acc: 89.00%, val acc: 64.00%, classification_loss: 0.15, discrimination_loss: 1.55, generation_loss: 8.74
epoch: 24
classify: 


100%|██████████| 4/4 [00:00<00:00, 397.94it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 62.72it/s]


train acc: 86.00%, val acc: 64.00%, classification_loss: 0.19, discrimination_loss: 0.06, generation_loss: 4.35
epoch: 25
classify: 


100%|██████████| 4/4 [00:00<00:00, 399.46it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.36it/s]


train acc: 92.00%, val acc: 56.00%, classification_loss: 0.13, discrimination_loss: 0.03, generation_loss: 5.33
epoch: 26
classify: 


100%|██████████| 4/4 [00:00<00:00, 325.22it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.06it/s]


train acc: 91.00%, val acc: 64.00%, classification_loss: 0.21, discrimination_loss: 0.04, generation_loss: 4.74
epoch: 27
classify: 


100%|██████████| 4/4 [00:00<00:00, 345.79it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.37it/s]


train acc: 91.00%, val acc: 64.00%, classification_loss: 0.18, discrimination_loss: 0.10, generation_loss: 4.47
epoch: 28
classify: 


100%|██████████| 4/4 [00:00<00:00, 525.88it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.50it/s]


train acc: 93.00%, val acc: 68.00%, classification_loss: 0.17, discrimination_loss: 0.38, generation_loss: 6.86
epoch: 29
classify: 


100%|██████████| 4/4 [00:00<00:00, 425.38it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.84it/s]


train acc: 91.00%, val acc: 72.00%, classification_loss: 0.16, discrimination_loss: 0.01, generation_loss: 6.06
epoch: 30
classify: 


100%|██████████| 4/4 [00:00<00:00, 513.54it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 63.37it/s]


train acc: 94.00%, val acc: 68.00%, classification_loss: 0.20, discrimination_loss: 0.01, generation_loss: 5.55
epoch: 31
classify: 


100%|██████████| 4/4 [00:00<00:00, 411.22it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.36it/s]


train acc: 93.00%, val acc: 72.00%, classification_loss: 0.16, discrimination_loss: 0.02, generation_loss: 5.69
epoch: 32
classify: 


100%|██████████| 4/4 [00:00<00:00, 506.02it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 70.93it/s]


train acc: 93.00%, val acc: 64.00%, classification_loss: 0.16, discrimination_loss: 0.05, generation_loss: 5.65
epoch: 33
classify: 


100%|██████████| 4/4 [00:00<00:00, 537.35it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.13it/s]


train acc: 94.00%, val acc: 64.00%, classification_loss: 0.22, discrimination_loss: 0.19, generation_loss: 3.38
epoch: 34
classify: 


100%|██████████| 4/4 [00:00<00:00, 355.16it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 65.90it/s]


train acc: 94.00%, val acc: 68.00%, classification_loss: 0.10, discrimination_loss: 0.03, generation_loss: 5.14
epoch: 35
classify: 


100%|██████████| 4/4 [00:00<00:00, 486.76it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 62.59it/s]


train acc: 95.00%, val acc: 72.00%, classification_loss: 0.12, discrimination_loss: 0.07, generation_loss: 4.67
epoch: 36
classify: 


100%|██████████| 4/4 [00:00<00:00, 411.39it/s]


GAN:


100%|██████████| 938/938 [00:16<00:00, 57.40it/s]


train acc: 94.00%, val acc: 64.00%, classification_loss: 0.17, discrimination_loss: 0.10, generation_loss: 4.00
epoch: 37
classify: 


100%|██████████| 4/4 [00:00<00:00, 375.46it/s]


GAN:


100%|██████████| 938/938 [00:15<00:00, 59.85it/s]


train acc: 93.00%, val acc: 64.00%, classification_loss: 0.11, discrimination_loss: 0.14, generation_loss: 4.31
epoch: 38
classify: 


100%|██████████| 4/4 [00:00<00:00, 349.37it/s]


GAN:


100%|██████████| 938/938 [00:15<00:00, 59.32it/s]


train acc: 93.00%, val acc: 68.00%, classification_loss: 0.13, discrimination_loss: 0.14, generation_loss: 4.10
epoch: 39
classify: 


100%|██████████| 4/4 [00:00<00:00, 468.99it/s]


GAN:


100%|██████████| 938/938 [00:15<00:00, 61.12it/s]


train acc: 94.00%, val acc: 64.00%, classification_loss: 0.18, discrimination_loss: 0.37, generation_loss: 5.80
epoch: 40
classify: 


100%|██████████| 4/4 [00:00<00:00, 433.35it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 63.05it/s]


train acc: 94.00%, val acc: 68.00%, classification_loss: 0.14, discrimination_loss: 0.04, generation_loss: 4.87
epoch: 41
classify: 


100%|██████████| 4/4 [00:00<00:00, 525.85it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 63.16it/s]


train acc: 95.00%, val acc: 68.00%, classification_loss: 0.13, discrimination_loss: 0.02, generation_loss: 4.78
epoch: 42
classify: 


100%|██████████| 4/4 [00:00<00:00, 496.34it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 63.28it/s]


train acc: 97.00%, val acc: 72.00%, classification_loss: 0.23, discrimination_loss: 0.71, generation_loss: 3.34
epoch: 43
classify: 


100%|██████████| 4/4 [00:00<00:00, 493.67it/s]


GAN:


100%|██████████| 938/938 [00:15<00:00, 61.66it/s]


train acc: 98.00%, val acc: 68.00%, classification_loss: 0.15, discrimination_loss: 0.03, generation_loss: 7.57
epoch: 44
classify: 


100%|██████████| 4/4 [00:00<00:00, 338.07it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.58it/s]


train acc: 98.00%, val acc: 72.00%, classification_loss: 0.12, discrimination_loss: 0.24, generation_loss: 6.05
epoch: 45
classify: 


100%|██████████| 4/4 [00:00<00:00, 403.33it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.48it/s]


train acc: 98.00%, val acc: 72.00%, classification_loss: 0.12, discrimination_loss: 0.00, generation_loss: 6.21
epoch: 46
classify: 


100%|██████████| 4/4 [00:00<00:00, 412.77it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 64.65it/s]


train acc: 97.00%, val acc: 72.00%, classification_loss: 0.09, discrimination_loss: 0.05, generation_loss: 4.82
epoch: 47
classify: 


100%|██████████| 4/4 [00:00<00:00, 498.40it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.84it/s]


train acc: 97.00%, val acc: 68.00%, classification_loss: 0.09, discrimination_loss: 0.03, generation_loss: 4.71
epoch: 48
classify: 


100%|██████████| 4/4 [00:00<00:00, 484.86it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 68.52it/s]


train acc: 98.00%, val acc: 72.00%, classification_loss: 0.13, discrimination_loss: 0.08, generation_loss: 5.60
epoch: 49
classify: 


100%|██████████| 4/4 [00:00<00:00, 454.05it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.87it/s]


train acc: 98.00%, val acc: 72.00%, classification_loss: 0.08, discrimination_loss: 0.15, generation_loss: 5.57
epoch: 50
classify: 


100%|██████████| 4/4 [00:00<00:00, 525.60it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.19it/s]


train acc: 97.00%, val acc: 72.00%, classification_loss: 0.08, discrimination_loss: 0.50, generation_loss: 5.89
epoch: 51
classify: 


100%|██████████| 4/4 [00:00<00:00, 470.65it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.17it/s]


train acc: 98.00%, val acc: 80.00%, classification_loss: 0.13, discrimination_loss: 0.01, generation_loss: 5.30
epoch: 52
classify: 


100%|██████████| 4/4 [00:00<00:00, 499.90it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.82it/s]


train acc: 98.00%, val acc: 76.00%, classification_loss: 0.11, discrimination_loss: 0.36, generation_loss: 5.07
epoch: 53
classify: 


100%|██████████| 4/4 [00:00<00:00, 494.47it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.87it/s]


train acc: 98.00%, val acc: 80.00%, classification_loss: 0.12, discrimination_loss: 0.01, generation_loss: 6.00
epoch: 54
classify: 


100%|██████████| 4/4 [00:00<00:00, 489.86it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 62.99it/s]


train acc: 99.00%, val acc: 80.00%, classification_loss: 0.16, discrimination_loss: 0.02, generation_loss: 4.81
epoch: 55
classify: 


100%|██████████| 4/4 [00:00<00:00, 199.82it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.27it/s]


train acc: 100.00%, val acc: 76.00%, classification_loss: 0.07, discrimination_loss: 0.54, generation_loss: 8.74
epoch: 56
classify: 


100%|██████████| 4/4 [00:00<00:00, 551.21it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.58it/s]


train acc: 99.00%, val acc: 80.00%, classification_loss: 0.09, discrimination_loss: 0.30, generation_loss: 4.29
epoch: 57
classify: 


100%|██████████| 4/4 [00:00<00:00, 446.90it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.81it/s]


train acc: 100.00%, val acc: 76.00%, classification_loss: 0.06, discrimination_loss: 0.10, generation_loss: 5.54
epoch: 58
classify: 


100%|██████████| 4/4 [00:00<00:00, 508.11it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 68.67it/s]


train acc: 100.00%, val acc: 76.00%, classification_loss: 0.08, discrimination_loss: 0.04, generation_loss: 4.75
epoch: 59
classify: 


100%|██████████| 4/4 [00:00<00:00, 275.89it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.29it/s]


train acc: 99.00%, val acc: 76.00%, classification_loss: 0.10, discrimination_loss: 0.16, generation_loss: 3.63
epoch: 60
classify: 


100%|██████████| 4/4 [00:00<00:00, 422.32it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 65.68it/s]


train acc: 100.00%, val acc: 72.00%, classification_loss: 0.09, discrimination_loss: 0.80, generation_loss: 9.16
epoch: 61
classify: 


100%|██████████| 4/4 [00:00<00:00, 443.52it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 67.73it/s]


train acc: 100.00%, val acc: 80.00%, classification_loss: 0.13, discrimination_loss: 0.21, generation_loss: 4.18
epoch: 62
classify: 


100%|██████████| 4/4 [00:00<00:00, 465.62it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 65.86it/s]


train acc: 100.00%, val acc: 72.00%, classification_loss: 0.12, discrimination_loss: 0.04, generation_loss: 5.18
epoch: 63
classify: 


100%|██████████| 4/4 [00:00<00:00, 501.94it/s]


GAN:


100%|██████████| 938/938 [00:13<00:00, 69.57it/s]


train acc: 100.00%, val acc: 76.00%, classification_loss: 0.12, discrimination_loss: 0.12, generation_loss: 4.04
epoch: 64
classify: 


100%|██████████| 4/4 [00:00<00:00, 525.22it/s]


GAN:


100%|██████████| 938/938 [00:14<00:00, 66.57it/s]


train acc: 100.00%, val acc: 80.00%, classification_loss: 0.07, discrimination_loss: 0.03, generation_loss: 5.39
epoch: 65
classify: 


100%|██████████| 4/4 [00:00<00:00, 327.67it/s]


GAN:


 20%|█▉        | 187/938 [00:03<00:12, 59.83it/s]


KeyboardInterrupt: 

In [270]:
model.validation(data_X, data_Y)

0.7902