In [20]:
import torch
from torch import optim
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torchvision.utils import make_grid
import torchvision.transforms as tt
import torch.nn as nn

from utils import *
import config
import random

from typing import Type

from Classify import Classifier

from PIL import Image

from tqdm import tqdm

In [2]:
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :   CIFAR10
NUM_LABELLED  :      1000
DEVICE        :    cuda:0
EPOCHS        :        50
BATCH_SIZE    :       512
LEARNING_RATE :      0.01
SCHED         :      True
GAN_BATCH_SIZE:       128


In [3]:
set_random_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
name = "CNN"

In [5]:
PATH = get_PATH(name)
PATH

'CIFAR10/CNN/_1000'

In [6]:
class ToPILImage():
    def __init__(self, mode=None):
        self.tt = tt.ToPILImage(mode)

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        if type(pic) == Image.Image:
            return pic
        else:
            return self.tt(pic)

    def __repr__(self) -> str:
        return self.tt.__repr__()

In [7]:
mean = [0.5]*3
std = [0.5]*3
if config.USED_DATA == "CIFAR10":

	train_tfm = tt.Compose([
		ToPILImage(),
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.ToTensor(),
		tt.Normalize(mean, std, inplace=True)
	])

if config.USED_DATA == "MNIST":
	train_tfm = tt.Compose([
		tt.Normalize(mean, std, inplace=True)
	])

test_tfm = tt.Compose([
	ToPILImage(),
	tt.ToTensor(),
	tt.Normalize(mean, std)
])

In [8]:
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
X_full = train_ds.data
y_full = train_ds.targets

In [10]:
classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [11]:
n_classes = len(classes)
channels = train_ds.data[0].shape[2]
n_classes, channels

(10, 3)

In [12]:
# test dataloader
test_dl = CreateDataLoader(test_ds.data, test_ds.targets, 512, test_tfm, device=config.DEVICE)

In [13]:
X_sup, y_sup, X_unsup, _ = supervised_samples(Tensor(X_full), Tensor(y_full), config.NUM_LABELLED, n_classes, get_unsup=True)

In [14]:
X_sup = X_sup.numpy()
y_sup = y_sup.numpy()
X_unsup = X_unsup.numpy()

In [15]:
model = Classifier(channels, n_classes).to(config.DEVICE, non_blocking=True)
model

Classifier(
  (conv): ConvModel(
    (initial): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (Conv): Sequential(
      (0): ConvBn(
        (Conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): ConvBn(
        (Conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): ConvBn(
        (Conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(512, eps=1e-05

In [16]:
import copy 
from torch import ByteTensor

In [21]:
class SelfTraining: 
    def __init__(self, model: Classifier, X_sup: Tensor, y_sup: Tensor, X_unsup: Tensor, test_dl: DeviceDataLoader, transform, num_rounds): 
        '''
            Input of self-training model:
            model: classifier
            num_rounds: number of self_training rounds
            sup_samples: number of supervised samples
        '''
        self.model = model 
        self.X_sup = X_sup
        self.y_sup = y_sup
        self.X_unsup = X_unsup
        self.transform = transform
        self.test_dataloader = test_dl
        
        self.num_rounds = num_rounds
    

    def CalDisagreement(self, h1: Classifier, h2: Classifier, dataset: CustomDataSet): 
        '''
            Calculate disagreement between teacher model and student model
            h1: Teacher model 
            h2: Student model
        '''
        disagreement = 0
        for x, _ in dataset: 
            disagreement += (torch.argmax(h1(x.unsqueeze(0))) == torch.argmax(h2(x.unsqueeze(0))))
        
        return disagreement/len(dataset)

    def random_sampling(self, sample_fraction: float, dataset: CustomDataSet, n: int): 
        dataset_set: list[CustomDataSet] = []
        for _ in range(n): 
            
            idx = random.sample(range(0, len(dataset)), int(len(dataset)*sample_fraction))
            data_X = dataset.x[idx]
            data_y = dataset.y[idx]

            dataset_set.append(CustomDataSet(data_X, data_y))
        return dataset_set
    
    def random_sampling(self, idx, sample_fraction, n):
        subsets_idx = []
        for _ in range(n):
            subset_idx = random.sample(idx, int(len(idx) * sample_fraction))
            subsets_idx.append(subset_idx)
        
        return subsets_idx

    def selfTraining(self, epochs, lr, batch_size: int, sample_fraction: float, n: int, opt_func: Type[optim.Optimizer] = optim.Adam, sched = True, PATH = ".", save_best = False, device = 'cpu'): 
        teacher_model = copy.deepcopy(self.model)
        for _ in range(self.num_rounds): 
            sup_dl = CreateDataLoader(self.X_sup, self.y_sup, batch_size=batch_size, transform=self.transform, device=device)
            student_model = copy.deepcopy(teacher_model) 
            student_model.fit(epochs, lr, sup_dl, self.test_dataloader, opt_func=opt_func, sched=sched, PATH=PATH, save_best=save_best)
            d=dict()
            labels = []

            print("start")
            X = Tensor(X_unsup).permute(0, 2, 1, 3).to(config.DEVICE)
            c_labels = student_model(X)
            labels = torch.argmax(c_labels, dim=1).numpy()
            a = 
            
            for i, x in tqdm(enumerate(self.X_unsup)): 
                # print(x.get_device())
                x = Tensor(x).permute(2, 1, 0).to(config.DEVICE)
                c_labels = student_model(x.unsqueeze(0)) # (1 x 10)
                # print(c_labels[0])
                labels.append(torch.argmax(c_labels[0]))
                a = torch.sum(Tensor([j*c_labels[0][j] for j in range(len(c_labels[0]))])).item()
                
                d[i] = a
            
            print("threshold")
            
            threshold = np.median(np.array(list(d.values())))
            threshold_idx = []

            for i in range(len(d)):
                if d[i] > threshold:
                    threshold_idx.append(i)
            
            print("sampling")
            # randomly sample sample_fraction of threshold_ds
            dataset_idx = self.random_sampling(threshold_idx, sample_fraction=sample_fraction, n=n)

            max = 0

            for subset_idx in dataset_idx:
                print(':)')
                model = model = Classifier(channels, n_classes).to(config.DEVICE)
                model.train()
                # calculate U\U[i]
                unlabel = self.X_unsup
                unlabel_i = self.X_unsup[subset_idx]

                counterpart_idx = []

                for i in range(len(self.X_unsup)):
                    if i not in subset_idx:
                        counterpart_idx.append(i)
                
                counterpart = self.X_unsup[counterpart_idx]

                y_counterpart = torch.argmax(teacher_model(counterpart), dim=1)  # shape: len(counterpart) x 1
                
                print(y_counterpart)
            break
            # '''
            for I in range(len(dataset_idx)): 
                model = Classifier(channels, n_classes).to(config.DEVICE)
                model.train()
                # calculate U\U[i]
                unlabel = self.X_unsup
                unlabel_i = dataset_set[I].x

                counterpart = Tensor().type_as(unlabel)
                
                # '''debugging'''
                # testing = True
                # print(threshold_X.shape) 
                # for i in range(60): 
                #     for j in range(i, 61): 
                #         if torch.equal(threshold_X[i] ,threshold_X[j]): 
                #             testing = False
                
                # print(testing)
                # break

                for i in range(unlabel.shape[0]): 
                    check = True
                    for j in range(unlabel_i.shape[0]): 
                       if torch.equal(unlabel[i], unlabel_i[j]):
                           check = False
                           break
                    if check:
                       counterpart = torch.cat((counterpart, unlabel[i].unsqueeze(0)))


                # generate label of data in U\U[i] by teacher_model classifier 
                y_counterpart = Tensor().type_as(unlabel)
                for x in counterpart: 
                    y_counterpart = torch.cat((y_counterpart, teacher_model(x.unsqueeze(0)).unsqueeze(0)))

                X_data = torch.cat((self.X_sup, unlabel_i, counterpart))
                y_data = torch.cat((self.y_sup, dataset_set[I].y ,y_counterpart)).to(dtype=torch.int)
                dl = CreateDataLoader(X_data, y_data, config.BATCH_SIZE, train_tfm, config.DEVICE)
               
                model.fit(config.EPOCHS, config.LEARNING_RATE, dl, test_dl, opt_func=optim.Adam, save_best=False)
                if self.CalDisagreement(student_model.classify, model, unlabeled_dataset) > max: 
                    max = self.CalDisagreement(student_model, model, unlabeled_dataset)
                    best = dataset_set[I]
            
            labeled_dataset.x = torch.cat((labeled_dataset.x, best.x))
            labeled_dataset.y = torch.cat((labeled_dataset.y, best.y))
            # remove sample from best dataset from unlabled dataset 
            for i in range(self.X_unsup.shape[0]): 
                check = True
                for j in range(best.x.shape[0]): 
                    if torch.equal(self.X_unsup[i], best.x[j]):
                        check = False
                        break
                    if not check:
                        self.X_unsup = torch.cat((self.X_unsup[:i], self.X_unsup[:i+1]))
                        unlabeled_dataset.y = torch.cat((unlabeled_dataset.y[:i], unlabeled_dataset.y[:i+1]))
            # reassign teacher model 
            teacher_model = student_model
        # return best model 
        self.model = teacher_model

In [22]:
selftraining = SelfTraining(model, X_sup, np.array(y_sup, dtype=np.int_), X_unsup, test_dl, transform=train_tfm, num_rounds=3)

In [23]:
dl = selftraining.selfTraining(1, 0.0001, 64, 0.4, 10, device=config.DEVICE)

Epoch [0]


100%|██████████| 16/16 [00:01<00:00, 14.50it/s]


train_loss: 2.3998, val_loss: 2.4314, train_acc: 0.1730, val_acc: 0.1000, lrs: 0.0000->0.0000
start


36717it [00:52, 694.53it/s]


KeyboardInterrupt: 