In [1]:
import torch
from torch import optim
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import random_split

from utils import *
import config
import random

from typing import Type

from Classify import Classifier
from Network_model import Generator, ConvModel

from tqdm import tqdm
import copy 

from Network_model import Generator, Discriminator

In [2]:
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :     MNIST
NUM_LABELLED  :       100
DEVICE        :    cuda:0
EPOCHS        :        10
BATCH_SIZE    :       128
LEARNING_RATE :    0.0004
SCHED         :     False
GAN_BATCH_SIZE:       128


In [3]:
set_random_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
name = "MarginGAN"

In [5]:
if config.USED_DATA == "CIFAR10":
	mean = [0.5]*3
	std = [0.5]*3

	train_tfm = tt.Compose([
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.Normalize(mean, std, inplace=True)
	])

if config.USED_DATA == "MNIST":
	mean = [0.5]
	std = [0.5]
	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.Normalize(mean, std, inplace=True)
	])

test_tfm = tt.Compose([
	tt.Resize(32),
	tt.Normalize(mean, std)
])

In [6]:
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

In [7]:
X_full = train_ds.x
y_full = train_ds.y

In [8]:
classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [9]:
n_classes = len(classes)
channels = train_ds[0][0].shape[0] # MNIST
n_classes, channels

(10, 1)

In [10]:
X_sup, y_sup, X_unsup, _ = supervised_samples(X_full, y_full, config.NUM_LABELLED, n_classes, get_unsup=True)

In [11]:
test_dl = CreateDataLoader(test_ds, batch_size=512, transform=test_tfm, device=config.DEVICE)

In [12]:
class Classifier(nn.Module):
    def __init__(self, in_channels, n_classes) -> None:
        super().__init__()

        self.Conv = ConvModel(in_channels)

        self.dropout = nn.Dropout(0.5)

        self.out = nn.Linear(512, n_classes)
        
    def forward(self, X: Tensor):
        out = self.Conv(X)
        out = self.dropout(out)
        out = self.out(out) 
        return out

In [13]:
class MarginGAN: 
    def __init__(self, latent_size, n_channels, n_classes, device): 
        self.latent_size = latent_size
        self.n_classes = n_classes
        self.generator = Generator(latent_size, n_channels)
        self.discriminator = Discriminator(n_channels, 1) 
        self.classifier = Classifier(n_channels, n_classes)

        self.CEloss = nn.CrossEntropyLoss() 
        self.BCEloss = nn.BCELoss()
        self.device = device 
        self.to(device)
    
    def to(self, device): 
        self.generator.to(device) 
        self.discriminator.to(device)
        self.classifier.to(device)
    
    def load_gen_state_dict(self, file):
        self.generator.load_state_dict(torch.load(file))
    
    def accuracy(self, test_dl): 
        corrected = 0
        for b in tqdm(test_dl):
            images, y = b
            outs = self.classifier.forward(images)
            outs = torch.argmax(outs[:, :-1], dim=1)
            corrected += (outs == y).sum().item()
        return corrected / test_dl.num_data()
    
    def discriminator_step(self, real_imgs:torch.Tensor, batch_size): 
        # Train discriminator to recognize real imgs as real imgs
        real_batch_size = real_imgs.shape[0]
        outs = self.discriminator(real_imgs)
        outs = F.softmax(outs, dim = 1)[:, 0]
        # outs is probability of discriminator predict fake images
        # because this is real images, we want this {outs} to be minimize
        y_hat = torch.zeros([real_batch_size], device = self.device) 
        real_loss = self.BCEloss(outs, y_hat)

        # Train discriminator to recognize fake imgs as fake imgs 
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_imgs = self.generator(z)
        fake_outs = self.discriminator(fake_imgs)
        fake_outs = F.softmax(fake_outs, dim = 1)[:, 1]
        # outs is probability of discriminator predict fake images
        # because this is fake images, we want this {outs} to be maximize
        fake_y_hat = torch.ones([batch_size], device = self.device) 
        fake_loss = self.BCEloss(fake_outs, fake_y_hat)
        return real_loss + fake_loss
    
    def classifier_step(self, sup_imgs, sup_labels, unsup_imgs, batch_size): 
        # Loss for labeled samples
        sup_outs = self.classifier(sup_imgs) 
        sup_loss = self.CEloss(sup_outs, sup_labels)

        # Loss for unlabeled samples
        # Pseudo_label:  Pick up the class which
        # has maximum predicted probability for each unlabeled
        # sample
        unsup_outs = self.classifier(unsup_imgs) 
        unsup_pseudolabels = torch.argmax(unsup_outs, dim = 1) 
        # print(unsup_pseudolabels.shape)
        unsup_loss = self.CEloss(unsup_outs, unsup_pseudolabels)

        # Loss for generated samples. Also pseudo_labelling as for 
        # unsup imgs, but now apply the inverted binary cross entropy 
        # as loss. Aim: decrease the margin of these data points
        # and make the prediction distribution flat
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_imgs = self.generator(z) 
        fake_outs = self.classifier(fake_imgs)
        fake_pseudolabels = torch.argmin(fake_outs, dim = 1) 
        fake_loss = self.CEloss(fake_outs, fake_pseudolabels) 

        return sup_loss + unsup_loss + fake_loss

    def fit(self, epochs, batch_size, batch_per_epoch, dis_lr, max_lr, sup_ds:CustomDataSet, unsup_ds:CustomDataSet, full_ds:CustomDataSet, test_dl, optim:Type[optim.Optimizer], weight_decay = 0, sched = True, PATH = ".", save_best = False, grad_clip = False): 
        optimizerD = optim(self.discriminator.parameters(), lr = dis_lr)
        optimizerC = optim(self.classifier.parameters(), lr = max_lr, weight_decay = weight_decay)

        if sched: 
            OneCycleLR = torch.optim.lr_scheduler.OneCycleLR(optimizerC, max_lr, epochs=epochs, steps_per_epoch=batch_per_epoch)

        self.discriminator.train()
        self.classifier.train() 
        for epoch in (range(epochs)):
            for i in range(batch_per_epoch): 
                sup_imgs, labels = random_split(sup_ds, [batch_size, len(sup_ds) - batch_size])[0][:]
                full_imgs = random_split(full_ds, [batch_size, len(full_ds) - batch_size])[0][:]
                unsup_imgs = random_split(unsup_ds, [batch_size, len(unsup_ds) - batch_size])[0][:]
                # train discriminator
                if (i%2 == 1):
                    D_loss = self.discriminator_step(full_imgs.to(self.device), batch_size)
                    D_loss.backward()
                    optimizerD.step()
                    tqdm.write(f'D_loss: {D_loss.detach().item()}', end = "\r")
                # train classifier
                C_loss = self.classifier_step(sup_imgs.to(self.device), labels.to(self.device), unsup_imgs.to(self.device), batch_size) 
                C_loss.backward()
                
                if grad_clip: 
                    torch.nn.utils.clip_grad_value_(self.classifier.parameters(), 0.1)
                
                optimizerC.step()
                optimizerC.zero_grad()
                if sched: 
                    OneCycleLR.step()
                
                tqdm.write(f'C_loss: {C_loss.detach().item()}', end = "\r")
                
            self.discriminator.eval() 
            self.classifier.eval()
            tqdm.write(f'accuracy: {self.accuracy(test_dl)}', end = "\r")

In [14]:
mGAN = MarginGAN(100, channels, n_classes, config.DEVICE)

In [15]:
mGAN.classifier

Classifier(
  (Conv): ConvModel(
    (initial): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (Conv): Sequential(
      (0): ConvBn(
        (Conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): ConvBn(
        (Conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): ConvBn(
        (Conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(512, eps=1e-05

In [16]:
i = 4
mGAN.load_gen_state_dict(f"DCGAN/{config.USED_DATA}/netG_epoch_00{i}.pth")

In [17]:
sup_ds = CustomDataSet(X_sup, y_sup, train_tfm)
full_ds = CustomDataSet(X_full, None, train_tfm)
unsup_ds = CustomDataSet(X_unsup, None, train_tfm)

In [18]:
mGAN.fit(10, 64, 50, 1e-5, 2*1e-4, sup_ds, unsup_ds, full_ds, test_dl, optim.RMSprop, grad_clip=True)

C_loss: 5.4620990753173837821

100%|██████████| 20/20 [00:01<00:00, 12.49it/s]


C_loss: 2.65606307983398440997

100%|██████████| 20/20 [00:01<00:00, 14.15it/s]


C_loss: 2.6090240478515625e-12

100%|██████████| 20/20 [00:01<00:00, 14.14it/s]


C_loss: 2.5120520591735845e-18

100%|██████████| 20/20 [00:01<00:00, 13.81it/s]


C_loss: 2.4998936653137207e-24

100%|██████████| 20/20 [00:01<00:00, 14.06it/s]


C_loss: 2.4353010654449463-293

100%|██████████| 20/20 [00:01<00:00, 14.05it/s]


C_loss: 2.4172558784484863-301

100%|██████████| 20/20 [00:01<00:00, 13.91it/s]


C_loss: 2.3798391819000244-453

100%|██████████| 20/20 [00:01<00:00, 13.93it/s]


C_loss: 2.35726571083068854323

100%|██████████| 20/20 [00:01<00:00, 13.86it/s]


C_loss: 2.3383316993713386

100%|██████████| 20/20 [00:01<00:00, 13.65it/s]

accuracy: 0.7381




In [19]:
corrected = 0
for b in tqdm(test_dl):
    images, y = b
    outs = mGAN.classifier.forward(images)
    outs = torch.argmax(outs, dim=1)
    corrected += (outs == y).sum().item()

100%|██████████| 20/20 [00:01<00:00, 13.52it/s]


In [20]:
print(f"{i}: {corrected / test_dl.num_data()}")

4: 0.815
