In [1]:
import torch
from torch import optim
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import random_split

from utils import *
import config
import random

from typing import Type

from Classify import Classifier
from Network_model import Generator, ConvModel

from tqdm import tqdm
import copy 


In [2]:
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :     MNIST
NUM_LABELLED  :       100
DEVICE        :    cuda:0
EPOCHS        :        10
BATCH_SIZE    :       128
LEARNING_RATE :    0.0004
SCHED         :     False
GAN_BATCH_SIZE:       128


In [4]:
set_random_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [5]:
name = "GANSSL"

In [6]:
PATH = get_PATH(name)
PATH

'MNIST/GANSSL/_100'

In [7]:
if config.USED_DATA == "CIFAR10":
	mean = [0.5]*3
	std = [0.5]*3

	train_tfm = tt.Compose([
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.Normalize(mean, std, inplace=True)
	])

if config.USED_DATA == "MNIST":
	mean = [0.5]
	std = [0.5]
	train_tfm = tt.Compose([
		# tt.Resize(32),
		tt.Normalize(mean, std, inplace=True)
	])

test_tfm = tt.Compose([
	# tt.Resize(32),
	tt.Normalize(mean, std)
])

In [8]:
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

In [9]:
X_full = train_ds.x
y_full = train_ds.y

In [10]:
test_ds[0][0].shape

torch.Size([1, 28, 28])

In [11]:
classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [12]:
n_classes = len(classes)
channels = train_ds[0][0].shape[0] # MNIST
n_classes, channels

(10, 1)

In [13]:
X_sup, y_sup, X_unsup, _ = supervised_samples(X_full, y_full, config.NUM_LABELLED, n_classes, get_unsup=True)

In [14]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, n_classes) -> None:
        super().__init__()

        self.Conv = ConvModel(in_channels)

        self.dropout = nn.Dropout(0.5)

        self.out = nn.Linear(512, n_classes+1)
        
    def forward(self, X: Tensor):
        out = self.Conv(X)
        out = self.out(out)
        return out

In [15]:
test_dl = CreateDataLoader(test_ds, batch_size=512, transform=test_tfm, device=config.DEVICE)

In [16]:
def wasserstein_loss(y_true, y_pred): 
     return torch.mean(y_true * y_pred, dtype=torch.float)

In [17]:
class GAN:
    def __init__(self, latent_size, n_channels, n_classes, device) -> None:
        self.generator = Generator(latent_size, n_channels)
        self.discriminator = Discriminator(n_channels, n_classes)

        self.latent_size = latent_size

        self.CEloss = nn.CrossEntropyLoss()
        self.BCEloss = nn.BCELoss()

        self.n_classes = n_classes

        self.resize = tt.Resize(32)

        self.device = device

        self.to(device)

    def to(self, device):
        self.generator = self.generator.to(device)
        self.discriminator = self.discriminator.to(device)
    
    def load_gen_state_dict(self, file):
        self.generator.load_state_dict(torch.load(file))
    
    def classifier_step(self, X, y):
        outs = self.discriminator(X)
        # loss = wasserstein_loss(outs.to(torch.float), y.to(torch.float))
        loss = self.CEloss(outs, y)
        
        # loss.backward()

        # optim.step()
        # tqdm.write(f'classifier_loss: {loss.detach().item()}', end = "\r")

        return loss
    
    def discriminator_real_step(self, X):
        
        batch_size = X.shape[0]
        outs = self.discriminator(X)
        outs = F.softmax(outs, dim=1)[:, -1]   # shape: B x 1
        # outs is probability of discriminator predict fake images
        # because this is real images, we want this {outs} to be minimize
        
        y_hat = torch.zeros([batch_size], device=self.device)

        loss = self.BCEloss(outs, y_hat)
        # loss.backward()

        # optim.step()
        # tqdm.write(f'disc_real_loss: {loss.detach().item()}', end = "\r")
        
        return loss
        

    
    def discriminator_fake_step(self, batch_size):
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_images = self.generator(z)

        fake_images = self.resize(fake_images)

        outs = self.discriminator(fake_images)
        y_hat = torch.full([batch_size], self.n_classes, device=self.device)

        loss = self.CEloss(outs, y_hat)
        # loss.backward()

        # optim.step()
        # tqdm.write(f'disc_fake_loss: {loss.detach().item()}', end = "\r")

        return loss


    def accuracy(self, test_dl): 
        corrected = 0
        for b in tqdm(test_dl):
            images, y = b
            outs = self.discriminator.forward(images)
            outs = torch.argmax(outs[:, :-1], dim=1)
            corrected += (outs == y).sum().item()
        return corrected / test_dl.num_data()
                
    def fit(self, epochs, batch_size, batch_per_epoch, dis_lr, sup_ds: CustomDataSet, full_ds: CustomDataSet, optim: Type[optim.Optimizer], PATH = ".", save_best = False, grad_clip = True):
        optimizerD = optim(self.discriminator.parameters(), lr = dis_lr)
        n_sup = len(sup_ds) 
        n_data = len(full_ds)


        for epoch in (range(epochs)):
            print(f"epoch: {epoch}")
            self.discriminator.train()
            with open('acc.txt', 'w') as F: 
                for i in (range(batch_per_epoch)):
                    sup_images, labels = random_split(sup_ds, [batch_size, n_sup - batch_size])[0][:]
                    C_loss = self.classifier_step(sup_images.to(self.device), labels.to(self.device))
                    full_images = random_split(full_ds, [batch_size, n_data-batch_size])[0][:]
                
                    loss_real = self.discriminator_real_step(full_images.to(self.device))
                    loss_fake = self.discriminator_fake_step(batch_size)
                    D_loss = (loss_real + loss_fake)/2
                    

                    loss = C_loss+ D_loss/2

                    loss.backward()
                    if grad_clip: 
                        nn.utils.clip_grad_value_(self.discriminator.parameters(), 0.1)

                    optimizerD.step()
                    tqdm.write(f'loss: {loss.detach().item()}', end = "\r")
                
                self.discriminator.eval()
                # F.write(f"\n accuracy: {self.accuracy(test_dl)}")
                tqdm.write(f'accuracy: {self.accuracy(test_dl)}', end = "\r")
                    # self.classifier_step(sup_images, labels, optimizerD)
                    # self.discriminator_real_step(full_images, optimizerD)
                    # self.discriminator_fake_step(_full_batch_size, optimizerD)
        

In [18]:
GANSSL = GAN(100, channels, 10, config.DEVICE)

In [19]:
GANSSL.load_gen_state_dict(f"DCGAN/{config.USED_DATA}/netG_epoch_024.pth")

In [20]:
sup_ds = CustomDataSet(X_sup, y_sup, train_tfm)
full_ds = CustomDataSet(X_full, None, train_tfm)

In [None]:
GANSSL.fit(config.EPOCHS, 64, 50, 1e-5, sup_ds, full_ds, optim.RMSprop, grad_clip=False)

In [None]:
corrected = 0
for b in tqdm(test_dl):
    images, y = b
    outs = GANSSL.discriminator.forward(images)
    outs = torch.argmax(outs[:, :-1], dim=1)
    corrected += (outs == y).sum().item()

100%|██████████| 20/20 [00:01<00:00, 13.98it/s]


In [None]:
corrected / test_dl.num_data()

0.7086

In [21]:
class SelfTraining: 
    def __init__(self, model: GAN, X_sup: Tensor, y_sup: Tensor, X: Tensor, X_unsup: Tensor, test_dl: DeviceDataLoader, transform, num_rounds): 
        '''
            Input of self-training model:
            model: classifier
            num_rounds: number of self_training rounds
            sup_samples: number of supervised samples
        '''
        self.model = model 
        self.X_sup = X_sup
        self.y_sup = y_sup
        self.X = X
        self.X_unsup = X_unsup
        self.transform = transform
        self.test_dataloader = test_dl
        
        self.num_rounds = num_rounds
    

    def CalDisagreement(self, h1: GAN, h2: GAN, dataset: CustomDataSet): 
        '''
            Calculate disagreement between teacher model and student model
            h1: Teacher model 
            h2: Student model
        '''
        disagreement = 0
        for x, _ in dataset: 
            disagreement += (torch.argmax(h1.discriminator(x.unsqueeze(0))) == torch.argmax(h2.discriminator(x.unsqueeze(0))))
        
        return disagreement/len(dataset)

    def random_sampling(self, sample_fraction: float, dataset: CustomDataSet, n: int): 
        dataset_set: list[CustomDataSet] = []
        for _ in range(n): 
            
            idx = random.sample(range(0, len(dataset)), int(len(dataset)*sample_fraction))
            data_X = dataset.x[idx]
            data_y = dataset.y[idx]

            dataset_set.append(CustomDataSet(data_X, data_y))
        return dataset_set
    
    def random_sampling(self, idx, sample_fraction, n):
        subsets_idx = []
        for _ in range(n):
            subset_idx = random.sample(idx, int(len(idx) * sample_fraction))
            subsets_idx.append(subset_idx)
        
        return subsets_idx

    def selfTraining(self, epochs, lr, batch_size: int, batch_per_epoch,  sample_fraction: float, n: int, opt_func: Type[optim.Optimizer] = optim.Adam, sched = True, PATH = ".", save_best = False, device = 'cpu'): 
        teacher_model = copy.deepcopy(self.model)
        full_ds = CustomDataSet(self.X, None)
        
        for _ in range(self.num_rounds): 
            sup_ds = CustomDataSet(self.X_sup, self.y_sup)
            
            student_model = copy.deepcopy(teacher_model) 
            student_model.fit(epochs, batch_size, batch_per_epoch, lr, sup_ds, full_ds, opt_func)

            print("start")
           
            unsup_ds = CustomDataSet(self.X_unsup, None, self.transform)
            unsup_dl = CreateDataLoader(unsup_ds, batch_size=1, device = device)
            confidence = Tensor()
            for unsup_imgs in tqdm(unsup_dl):
            # for i in tqdm(range(0, len(unsup_ds), batch_size)):  
                # unsup_img = unsup_ds[i:min(len(unsup_ds), i+batch_size)]
                prob = F.softmax(student_model.discriminator(unsup_imgs).cpu()[:, :-1], dim=1)
                # add 0.1 into label in order to prevent collapsing at label 0
                labels = Tensor([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 5.1, 7.1, 8.1, 9.1])
                a = torch.matmul(prob, labels)
                confidence = torch.cat((confidence, a))
            
            print(confidence.shape)
            print("threshold")
            break
            threshold = np.median(np.array(list(d.values())))
            threshold_idx = []

            for i in range(len(d)):
                if d[i] > threshold:
                    threshold_idx.append(i)
            
            print("sampling")
            # randomly sample sample_fraction of threshold_ds
            dataset_idx = self.random_sampling(threshold_idx, sample_fraction=sample_fraction, n=n)

            max = 0

            # for subset_idx in dataset_idx:
            #     print(':)')
            #     model = model = Classifier(channels, n_classes).to(config.DEVICE)
            #     model.train()
            #     # calculate U\U[i]
            #     unlabel = self.X_unsup
            #     unlabel_i = self.X_unsup[subset_idx]

            #     counterpart_idx = []

            #     for i in range(len(self.X_unsup)):
            #         if i not in subset_idx:
            #             counterpart_idx.append(i)
                
            #     counterpart = self.X_unsup[counterpart_idx]

            #     y_counterpart = torch.argmax(teacher_model(counterpart), dim=1)  # shape: len(counterpart) x 1
                
            #     print(y_counterpart)
            # break
            # '''
            for I in range(len(dataset_idx)): 
                model = Classifier(channels, n_classes).to(config.DEVICE)
                model.train()
                # calculate U\U[i]
                unlabel = self.X_unsup
                unlabel_i = dataset_set[I].x

                counterpart = Tensor().type_as(unlabel)
                
                # '''debugging'''
                # testing = True
                # print(threshold_X.shape) 
                # for i in range(60): 
                #     for j in range(i, 61): 
                #         if torch.equal(threshold_X[i] ,threshold_X[j]): 
                #             testing = False
                
                # print(testing)
                # break

                for i in range(unlabel.shape[0]): 
                    check = True
                    for j in range(unlabel_i.shape[0]): 
                       if torch.equal(unlabel[i], unlabel_i[j]):
                           check = False
                           break
                    if check:
                       counterpart = torch.cat((counterpart, unlabel[i].unsqueeze(0)))


                # generate label of data in U\U[i] by teacher_model classifier 
                y_counterpart = Tensor().type_as(unlabel)
                for x in counterpart: 
                    y_counterpart = torch.cat((y_counterpart, teacher_model(x.unsqueeze(0)).unsqueeze(0)))

                X_data = torch.cat((self.X_sup, unlabel_i, counterpart))
                y_data = torch.cat((self.y_sup, dataset_set[I].y ,y_counterpart)).to(dtype=torch.int)
                dl = CreateDataLoader(X_data, y_data, config.BATCH_SIZE, train_tfm, config.DEVICE)
               
                model.fit(config.EPOCHS, config.LEARNING_RATE, dl, test_dl, opt_func=optim.Adam, save_best=False)
                if self.CalDisagreement(student_model.classify, model, unlabeled_dataset) > max: 
                    max = self.CalDisagreement(student_model, model, unlabeled_dataset)
                    best = dataset_set[I]
            
            labeled_dataset.x = torch.cat((labeled_dataset.x, best.x))
            labeled_dataset.y = torch.cat((labeled_dataset.y, best.y))
            # remove sample from best dataset from unlabled dataset 
            for i in range(self.X_unsup.shape[0]): 
                check = True
                for j in range(best.x.shape[0]): 
                    if torch.equal(self.X_unsup[i], best.x[j]):
                        check = False
                        break
                    if not check:
                        self.X_unsup = torch.cat((self.X_unsup[:i], self.X_unsup[:i+1]))
                        unlabeled_dataset.y = torch.cat((unlabeled_dataset.y[:i], unlabeled_dataset.y[:i+1]))
            # reassign teacher model 
            teacher_model = student_model
        # return best model 
        self.model = teacher_model

In [22]:
selftraining = SelfTraining(GANSSL, X_sup, y_sup.to(dtype = torch.uint8), X_full, X_unsup, test_dl, transform=train_tfm, num_rounds=3)

In [23]:
dl = selftraining.selfTraining(1, 1e-5, 64, 50, 0.1, 3, opt_func=optim.RMSprop, device=config.DEVICE)

epoch: 0
loss: 0.05876884236931801

100%|██████████| 20/20 [00:01<00:00, 15.17it/s]


startacy: 0.2291


  8%|▊         | 4994/59900 [00:14<02:41, 340.37it/s]


RuntimeError: CUDA driver error: out of memory

In [24]:
message = torch.cuda.memory_summary(device=None, abbreviated=True)
with open('check.txt', 'w') as f: 
    f.write(message)