In [1]:
import torch
from torch import optim
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import random_split

from utils import *
import config
import random

from typing import Type

from Network_model import Generator, ConvModel

from tqdm import tqdm, trange
import copy 

from Network_model import Generator, Discriminator

In [2]:
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :     MNIST
NUM_LABELLED  :       100
DEVICE        :    cuda:0
EPOCHS        :        10
BATCH_SIZE    :       128
LEARNING_RATE :    0.0004
SCHED         :     False
GAN_BATCH_SIZE:       128


In [3]:
set_random_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
name = "MarginGAN"

In [5]:
if config.USED_DATA == "CIFAR10":
	mean = [0.5]*3
	std = [0.5]*3

	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.Normalize(mean, std, inplace=True)
	])

if config.USED_DATA == "MNIST":
	mean = [0.5]
	std = [0.5]
	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.Normalize(mean, std, inplace=True)
	])

test_tfm = tt.Compose([
	tt.Resize(32),
	tt.Normalize(mean, std)
])

In [6]:
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

In [7]:
X_full = train_ds.x
y_full = train_ds.y

In [8]:
print(X_full[0])

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

In [9]:
classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [10]:
n_classes = len(classes)
channels = train_ds[0][0].shape[0] # MNIST
n_classes, channels

(10, 1)

In [11]:
X_sup, y_sup, X_unsup, _ = supervised_samples(X_full, y_full, config.NUM_LABELLED, n_classes, get_unsup=True)

In [12]:
test_dl = CreateDataLoader(test_ds, batch_size=512, transform=test_tfm, device=config.DEVICE)

In [13]:
class Classifier(nn.Module):
    def __init__(self, in_channels, n_classes) -> None:
        super().__init__()

        self.Conv = ConvModel(in_channels)

        self.dropout = nn.Dropout(0.5)

        self.out = nn.Linear(512, n_classes)
        
    def forward(self, X: Tensor):
        out = self.Conv(X)
        out = self.dropout(out)
        out = self.out(out) 
        return out

In [14]:
class MarginGAN: 
    def __init__(self, latent_size, n_channels, n_classes, device): 
        self.latent_size = latent_size
        self.n_classes = n_classes
        self.generator = Generator(latent_size, n_channels)
        self.discriminator = Discriminator(n_channels, 1) 
        self.classifier = Classifier(n_channels, n_classes)

        self.CEloss = nn.CrossEntropyLoss() 
        self.BCEloss = nn.BCELoss()
        self.device = device 
        self.to(device)
    
    def to(self, device): 
        self.generator.to(device) 
        self.discriminator.to(device)
        self.classifier.to(device)
    
    def load_gen_state_dict(self, file):
        self.generator.load_state_dict(torch.load(file))
    
    def accuracy(self, test_dl): 
        corrected = 0
        for b in tqdm(test_dl):
            images, y = b
            outs = self.classifier.forward(images)
            outs = torch.argmax(outs[:, :-1], dim=1)
            corrected += (outs == y).sum().item()
        return corrected / test_dl.num_data()
    
    def discriminator_step(self, real_imgs:torch.Tensor, batch_size): 
        # Train discriminator to recognize real imgs as real imgs
        real_batch_size = real_imgs.shape[0]
        outs = self.discriminator(real_imgs)
        outs = F.softmax(outs, dim = 1)[:, 0]
        # outs is probability of discriminator predict fake images
        # because this is real images, we want this {outs} to be minimize
        y_hat = torch.zeros([real_batch_size], device = self.device) 
        real_loss = self.BCEloss(outs, y_hat)

        # Train discriminator to recognize fake imgs as fake imgs 
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_imgs = self.generator(z)
        fake_outs = self.discriminator(fake_imgs)
        fake_outs = F.softmax(fake_outs, dim = 1)[:, 1]
        # outs is probability of discriminator predict fake images
        # because this is fake images, we want this {outs} to be maximize
        fake_y_hat = torch.ones([batch_size], device = self.device) 
        fake_loss = self.BCEloss(fake_outs, fake_y_hat)
        return real_loss + fake_loss
    
    def classifier_step(self, sup_imgs, sup_labels, unsup_imgs, batch_size): 
        # Loss for labeled samples
        sup_outs = self.classifier(sup_imgs) 
        sup_loss = self.CEloss(sup_outs, sup_labels)

        # Loss for unlabeled samples
        # Pseudo_label:  Pick up the class which
        # has maximum predicted probability for each unlabeled
        # sample
        unsup_outs = self.classifier(unsup_imgs) 
        unsup_pseudolabels = torch.argmax(unsup_outs, dim = 1) 
        # print(unsup_pseudolabels.shape)
        unsup_loss = self.CEloss(unsup_outs, unsup_pseudolabels)

        # Loss for generated samples. Also pseudo_labelling as for 
        # unsup imgs, but now apply the inverted binary cross entropy 
        # as loss. Aim: decrease the margin of these data points
        # and make the prediction distribution flat
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_imgs = self.generator(z) 
        fake_outs = self.classifier(fake_imgs)
        fake_pseudolabels = torch.argmin(fake_outs, dim = 1) 
        fake_loss = self.CEloss(fake_outs, fake_pseudolabels) 

        return sup_loss + unsup_loss + fake_loss

    def fit(self, epochs, batch_size, batch_per_epoch, dis_lr, max_lr, sup_ds:CustomDataSet, unsup_ds:CustomDataSet, full_ds:CustomDataSet, test_dl, optim:Type[optim.Optimizer], weight_decay = 0, sched = True, PATH = ".", save_best = False, grad_clip = False): 
        optimizerD = optim(self.discriminator.parameters(), lr = dis_lr)
        optimizerC = optim(self.classifier.parameters(), lr = max_lr, weight_decay = weight_decay)

        if sched: 
            OneCycleLR = torch.optim.lr_scheduler.OneCycleLR(optimizerC, max_lr, epochs=epochs, steps_per_epoch=batch_per_epoch)
        with open('acc.txt', 'w') as f:
            self.discriminator.train()
            self.classifier.train() 
            best_val = 0 
            for epoch in (range(epochs)):
                for i in (trange(batch_per_epoch)): 
                    sup_imgs, labels = random_split(sup_ds, [batch_size, len(sup_ds) - batch_size])[0][:]
                    full_imgs = random_split(full_ds, [batch_size, len(full_ds) - batch_size])[0][:]
                    unsup_imgs = random_split(unsup_ds, [batch_size, len(unsup_ds) - batch_size])[0][:]
                    # train discriminator
                    if (i%2 == 1):
                        D_loss = self.discriminator_step(full_imgs.to(self.device), batch_size)
                        D_loss.backward()
                        optimizerD.step()
                        if grad_clip: 
                            torch.nn.utils.clip_grad_value_(self.discriminator.parameters(), 0.1)
                    # train classifier
                    C_loss = self.classifier_step(sup_imgs.to(self.device), labels.to(self.device), unsup_imgs.to(self.device), batch_size) 
                    C_loss.backward()
                    
                    if grad_clip: 
                        torch.nn.utils.clip_grad_value_(self.classifier.parameters(), 0.1)
                    
                    optimizerC.step()
                    optimizerC.zero_grad()
                    if sched: 
                        OneCycleLR.step()
                    
                    tqdm.write(f'C_loss: {C_loss.detach().item()}', end = "\r")
                    
                self.discriminator.eval()
                self.classifier.eval()
                acc = self.accuracy(test_dl)
                if acc > best_val: 
                    best_val = acc 
                    if save_best: 
                        torch.save(self.classifier.state_dict(), PATH)
                f.write(f'accuracy: {best_val}\n')
                tqdm.write(f'accuracy: {best_val}', end = "\r")

In [15]:
mGAN = MarginGAN(100, channels, n_classes, config.DEVICE)

In [16]:
mGAN.classifier

Classifier(
  (Conv): ConvModel(
    (initial): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (Conv): Sequential(
      (0): ConvBn(
        (Conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): ConvBn(
        (Conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): ConvBn(
        (Conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(512, eps=1e-05

In [17]:
i = 4
mGAN.load_gen_state_dict(f"DCGAN/{config.USED_DATA}/netG_epoch_00{i}.pth")

In [18]:
sup_ds = CustomDataSet(X_sup, y_sup, train_tfm)
full_ds = CustomDataSet(X_full, None, train_tfm)
unsup_ds = CustomDataSet(X_unsup, None, train_tfm)

In [19]:
mGAN.fit(10, 64, 50, 1e-5, 1e-4, sup_ds, unsup_ds, full_ds, test_dl, optim.RMSprop, save_best = True, grad_clip=False)

  2%|▏         | 1/50 [00:00<00:47,  1.03it/s]

C_loss: 12.800742149353027

  4%|▍         | 2/50 [00:01<00:33,  1.45it/s]

C_loss: 11.36173152923584

  6%|▌         | 3/50 [00:01<00:23,  2.03it/s]

C_loss: 10.425538063049316

  8%|▊         | 4/50 [00:02<00:22,  2.03it/s]

C_loss: 10.206046104431152

 10%|█         | 5/50 [00:02<00:18,  2.46it/s]

C_loss: 10.247130393981934

 12%|█▏        | 6/50 [00:02<00:18,  2.32it/s]

C_loss: 10.265443801879883

 14%|█▍        | 7/50 [00:03<00:16,  2.67it/s]

C_loss: 10.52142333984375

 16%|█▌        | 8/50 [00:03<00:16,  2.47it/s]

C_loss: 10.16124439239502

 18%|█▊        | 9/50 [00:03<00:14,  2.76it/s]

C_loss: 9.590926170349121

 20%|██        | 10/50 [00:04<00:15,  2.53it/s]

C_loss: 9.104801177978516

 22%|██▏       | 11/50 [00:04<00:13,  2.82it/s]

C_loss: 9.124277114868164

 24%|██▍       | 12/50 [00:05<00:14,  2.54it/s]

C_loss: 9.147794723510742

 26%|██▌       | 13/50 [00:05<00:13,  2.82it/s]

C_loss: 8.918766975402832

 28%|██▊       | 14/50 [00:05<00:14,  2.57it/s]

C_loss: 8.908673286437988

 30%|███       | 15/50 [00:06<00:12,  2.84it/s]

C_loss: 8.592220306396484

 32%|███▏      | 16/50 [00:06<00:13,  2.58it/s]

C_loss: 8.80207633972168

 34%|███▍      | 17/50 [00:06<00:11,  2.85it/s]

C_loss: 8.347112655639648

 36%|███▌      | 18/50 [00:07<00:12,  2.58it/s]

C_loss: 8.675395965576172

 38%|███▊      | 19/50 [00:07<00:10,  2.86it/s]

C_loss: 8.040486335754395

 40%|████      | 20/50 [00:08<00:11,  2.57it/s]

C_loss: 7.793705940246582

 42%|████▏     | 21/50 [00:08<00:10,  2.85it/s]

C_loss: 7.9863176345825195

 44%|████▍     | 22/50 [00:08<00:10,  2.58it/s]

C_loss: 7.956389904022217

 46%|████▌     | 23/50 [00:09<00:09,  2.86it/s]

C_loss: 7.730774879455566

 48%|████▊     | 24/50 [00:09<00:10,  2.58it/s]

C_loss: 7.785335063934326

 50%|█████     | 25/50 [00:09<00:08,  2.86it/s]

C_loss: 7.8011298179626465

 52%|█████▏    | 26/50 [00:10<00:09,  2.56it/s]

C_loss: 7.837123870849609

 54%|█████▍    | 27/50 [00:10<00:08,  2.86it/s]

C_loss: 7.473788261413574

 56%|█████▌    | 28/50 [00:11<00:08,  2.58it/s]

C_loss: 7.360495567321777

 58%|█████▊    | 29/50 [00:11<00:07,  2.86it/s]

C_loss: 7.2151570320129395

 60%|██████    | 30/50 [00:11<00:07,  2.57it/s]

C_loss: 7.305842399597168

 62%|██████▏   | 31/50 [00:12<00:06,  2.83it/s]

C_loss: 7.250226974487305

 64%|██████▍   | 32/50 [00:12<00:07,  2.55it/s]

C_loss: 7.06805419921875

 66%|██████▌   | 33/50 [00:12<00:05,  2.84it/s]

C_loss: 6.880336284637451

 68%|██████▊   | 34/50 [00:13<00:06,  2.56it/s]

C_loss: 7.107601642608643

 70%|███████   | 35/50 [00:13<00:05,  2.83it/s]

C_loss: 7.061164379119873

 72%|███████▏  | 36/50 [00:14<00:05,  2.55it/s]

C_loss: 6.8570027351379395

 74%|███████▍  | 37/50 [00:14<00:04,  2.84it/s]

C_loss: 6.865920543670654

 76%|███████▌  | 38/50 [00:14<00:04,  2.59it/s]

C_loss: 6.959708213806152

 78%|███████▊  | 39/50 [00:15<00:03,  2.88it/s]

C_loss: 6.574657440185547

 80%|████████  | 40/50 [00:15<00:03,  2.56it/s]

C_loss: 6.825980186462402

 82%|████████▏ | 41/50 [00:15<00:03,  2.83it/s]

C_loss: 6.531682014465332

 84%|████████▍ | 42/50 [00:16<00:03,  2.56it/s]

C_loss: 6.886235237121582

 86%|████████▌ | 43/50 [00:16<00:02,  2.85it/s]

C_loss: 6.399694442749023

 88%|████████▊ | 44/50 [00:17<00:02,  2.56it/s]

C_loss: 6.187121391296387

 90%|█████████ | 45/50 [00:17<00:01,  2.86it/s]

C_loss: 6.530207633972168

 92%|█████████▏| 46/50 [00:17<00:01,  2.58it/s]

C_loss: 6.027438640594482

 94%|█████████▍| 47/50 [00:17<00:01,  2.86it/s]

C_loss: 6.169809341430664

 96%|█████████▌| 48/50 [00:18<00:00,  2.57it/s]

C_loss: 6.266961097717285

 98%|█████████▊| 49/50 [00:18<00:00,  2.85it/s]

C_loss: 5.927544116973877

100%|██████████| 50/50 [00:19<00:00,  2.60it/s]


C_loss: 5.856419563293457

100%|██████████| 20/20 [00:02<00:00,  8.42it/s]


RuntimeError: [enforce fail at inline_container.cc:633] . invalid file name: .

In [None]:
corrected = 0
for b in tqdm(test_dl):
    images, y = b
    outs = mGAN.classifier.forward(images)
    outs = torch.argmax(outs, dim=1)
    corrected += (outs == y).sum().item()

100%|██████████| 20/20 [00:01<00:00, 11.04it/s]


In [None]:
print(f"{i}: {corrected / test_dl.num_data()}")

4: 0.3316
