In [34]:
import torch
from torch import nn

import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, Sampler, DataLoader, Subset, SubsetRandomSampler, BatchSampler
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn.functional as F

import matplotlib.pyplot as plt 
from helper_functions_2 import softmax_kl_loss, sigmoid_rampup, get_current_consistency_weight, linear_rampup, grouper, relabel_dataset
import time

from math import sqrt

In [35]:
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=None
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=None
)

class_names = train_data.classes
class_names

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

In [36]:
NO_LABEL = -1
BATCH_SIZE = 16

class Arguments():
    def __init__(self, momentum, weight_decay, nesterov, epochs:int, consistency, exclude_unlabeled:bool, batch_size=64, labeled_batch_size=32, consistency_type='kl', lr=0.01, initial_lr=0.001, lr_rampup = 10, ema_decay=0.999, consistency_rampup=4):
        super().__init__()

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.epochs = epochs
        self.consistency_type = consistency_type
        self.initial_lr = initial_lr
        self.lr_rampup = lr_rampup
        self.consistency = consistency
        self.ema_decay = ema_decay
        self.labeled_batch_size = labeled_batch_size
        self.exclude_unlabeled = exclude_unlabeled
        self.batch_size = batch_size
        self.consistency_rampup = consistency_rampup

args = Arguments(lr=0.01, momentum=0, weight_decay=0, nesterov=False, epochs=4, exclude_unlabeled=True, consistency=0, batch_size=64, labeled_batch_size=64)

In [37]:
""" train_loader = torch.utils.data.DataLoader(
        dataset= train_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        drop_last=False)

eval_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=args.batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False) """

' train_loader = torch.utils.data.DataLoader(\n        dataset= train_data,\n        batch_size=args.batch_size,\n        shuffle=False,\n        pin_memory=True,\n        drop_last=False)\n\neval_loader = torch.utils.data.DataLoader(\n    dataset=test_data,\n    batch_size=args.batch_size,\n    shuffle=False,\n    pin_memory=True,\n    drop_last=False) '

In [38]:
""" class CustomTrainDataset(Dataset):

    def  __init__(self, train_data, labels=None):
        self.data = train_data
        self.labels = labels

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img = self.data[idx]
        img = img.numpy().astype(np.uint8).astype(np.uint8)
        label = self.labels[idx]

        return img, label """

' class CustomTrainDataset(Dataset):\n\n    def  __init__(self, train_data, labels=None):\n        self.data = train_data\n        self.labels = labels\n\n    def __len__(self):\n        return len(self.data)\n    \n    def __getitem__(self, idx):\n        img = self.data[idx]\n        img = img.numpy().astype(np.uint8).astype(np.uint8)\n        label = self.labels[idx]\n\n        return img, label '

In [39]:
""" train_dataset = CustomTrainDataset(train_data.data, train_data.targets)
train_loader = DataLoader(train_dataset, BATCH_SIZE, False)
X, y = next(iter(train_loader))
X, y = next(iter(train_loader))
N = int(sqrt(BATCH_SIZE))

fig, axs = plt.subplots(N, N, figsize=(15,15))
for i in range(len(y)):
    img = X[i]
    label = y[i]
    ax = fig.add_subplot(N, N, i+1)
    ax.imshow(img.reshape(28, 28),cmap="gray")
    ax.set_xticks([]) #set empty label for x axis
    ax.set_yticks([]) #set empty label for y axis
    ax.set_title(f"{class_names[label]}")

plt.show() """

' train_dataset = CustomTrainDataset(train_data.data, train_data.targets)\ntrain_loader = DataLoader(train_dataset, BATCH_SIZE, False)\nX, y = next(iter(train_loader))\nX, y = next(iter(train_loader))\nN = int(sqrt(BATCH_SIZE))\n\nfig, axs = plt.subplots(N, N, figsize=(15,15))\nfor i in range(len(y)):\n    img = X[i]\n    label = y[i]\n    ax = fig.add_subplot(N, N, i+1)\n    ax.imshow(img.reshape(28, 28),cmap="gray")\n    ax.set_xticks([]) #set empty label for x axis\n    ax.set_yticks([]) #set empty label for y axis\n    ax.set_title(f"{class_names[label]}")\n\nplt.show() '

In [40]:
split = 1 - args.labeled_batch_size/args.batch_size
NO_LABEL = -1
BATCH_SIZE = args.batch_size

""" # IS this the problem 
if(args.exclude_unlabeled == False):
    labelled_train_data, unlabelled_train_data, labels_train, unlabels_train = train_test_split(train_data.data, train_data.targets, stratify=train_data.targets, test_size=split)
else: # This is not the problem
    labelled_train_data = train_data.data
    labels_train = train_data.targets """


' # IS this the problem \nif(args.exclude_unlabeled == False):\n    labelled_train_data, unlabelled_train_data, labels_train, unlabels_train = train_test_split(train_data.data, train_data.targets, stratify=train_data.targets, test_size=split)\nelse: # This is not the problem\n    labelled_train_data = train_data.data\n    labels_train = train_data.targets '

In [41]:
# TEST CELL

train_set_size = int(len(train_data) * 0.8)
valid_set_size = len(train_data) - train_set_size
labelled_data, unlabelled_data = torch.utils.data.random_split(train_data, [train_set_size, valid_set_size])

data_0 = []
labels_0 = []
data_1 = []

data_test = []
labels_test = []

for (X, y) in labelled_data:
    data_0.append(X)
    labels_0.append(y)

for (X, y) in test_data:
    data_test.append(X)
    labels_test.append(y)

for (X, y) in unlabelled_data:
    data_1.append(X)

In [42]:
class CustomTrainDataset(Dataset):
    def __init__(self, train_data, labels=None):
        self.base_data = train_data
        if(labels == None):
            self.labels = torch.tensor([-1 for i in range(len(self.base_data))],dtype=torch.int64)
        else:
            self.labels = labels
        super().__init__()

    def __len__(self):
        return len(self.base_data)

    def __getitem__(self, idx):
        img = self.base_data[idx]
        # img = img.unsqueeze(dim=0)
        label = self.labels[idx]
        return img, label

labelled_dataset = CustomTrainDataset(train_data=data_0, labels=labels_0)
if(args.exclude_unlabeled == False):
    unlabelled_dataset = CustomTrainDataset(train_data=data_1)
    train_dataset = torch.utils.data.ConcatDataset([labelled_dataset, unlabelled_dataset])
else:
    train_dataset = labelled_dataset

test_dataset = CustomTrainDataset(train_data=data_test, labels=labels_test)

In [43]:
""" # TEST CELL

labelled_train_data, unlabelled_train_data, labels_train, unlabels_train = train_test_split(train_data.data, train_data.targets, stratify=train_data.targets, test_size=0.2)
labelled_dataset = CustomTrainDataset(train_data=labelled_train_data, labels=labels_train)
unlabelled_dataset = CustomTrainDataset(train_data=unlabelled_train_data)
train_dataset = torch.utils.data.ConcatDataset([labelled_dataset, unlabelled_dataset]) """

' # TEST CELL\n\nlabelled_train_data, unlabelled_train_data, labels_train, unlabels_train = train_test_split(train_data.data, train_data.targets, stratify=train_data.targets, test_size=0.2)\nlabelled_dataset = CustomTrainDataset(train_data=labelled_train_data, labels=labels_train)\nunlabelled_dataset = CustomTrainDataset(train_data=unlabelled_train_data)\ntrain_dataset = torch.utils.data.ConcatDataset([labelled_dataset, unlabelled_dataset]) '

In [44]:
""" # TEST CELL

# train_loader = DataLoader(unlabelled_dataset, BATCH_SIZE, False)
X, y = next(iter(train_loader))
# X, y = next(iter(eval_loader))
N = int(sqrt(BATCH_SIZE))

fig, axs = plt.subplots(N, N, figsize=(17,17))
for i in range(len(y)):
    img = X[i]
    label = y[i]
    ax = fig.add_subplot(N, N, i+1)
    ax.imshow(img.reshape(28, 28),cmap="gray")
    ax.set_xticks([]) #set empty label for x axis
    ax.set_yticks([]) #set empty label for y axis
    if label != -1:
        title = f"{class_names[label]}"
    else:
        title = "NO_LABEL"
    ax.set_title(title)

plt.show() """

' # TEST CELL\n\n# train_loader = DataLoader(unlabelled_dataset, BATCH_SIZE, False)\nX, y = next(iter(train_loader))\n# X, y = next(iter(eval_loader))\nN = int(sqrt(BATCH_SIZE))\n\nfig, axs = plt.subplots(N, N, figsize=(17,17))\nfor i in range(len(y)):\n    img = X[i]\n    label = y[i]\n    ax = fig.add_subplot(N, N, i+1)\n    ax.imshow(img.reshape(28, 28),cmap="gray")\n    ax.set_xticks([]) #set empty label for x axis\n    ax.set_yticks([]) #set empty label for y axis\n    if label != -1:\n        title = f"{class_names[label]}"\n    else:\n        title = "NO_LABEL"\n    ax.set_title(title)\n\nplt.show() '

In [45]:
class TwoStreamBatchSampler(Sampler):
     
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
        
    def __iter__(self):
            primary_iter = np.random.permutation(self.primary_indices)
            secondary_iter = np.random.permutation(self.secondary_indices)
            return (
                primary_batch + secondary_batch
                for (primary_batch, secondary_batch)
                in  zip(grouper(primary_iter, self.primary_batch_size),
                        grouper(secondary_iter, self.secondary_batch_size)))

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size


def create_data_loaders(train_dataset, test_dataset, args):
    
    labeled_idxs, unlabeled_idxs = relabel_dataset(dataset=train_dataset)

    if args.exclude_unlabeled:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
        
    else: 
        batch_sampler = TwoStreamBatchSampler(unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)


    train_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=batch_sampler,
                                                pin_memory=True)

    eval_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        drop_last=False)
    
    return train_loader, eval_loader

train_loader, eval_loader = create_data_loaders(labelled_dataset, test_dataset, args)
    
# labelled_dataset = CustomTrainDataset(train_data=labelled_train_data, labels=labels_train)
# train_loader_custom = torch.utils.data.DataLoader(
#         dataset= labelled_dataset,
#         batch_size=args.batch_size,
#         shuffle=False,
#         pin_memory=True,
#         drop_last=False) 


In [46]:
""" # TEST CELL

train_loader = torch.utils.data.DataLoader(
        dataset= train_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        drop_last=False)

eval_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=args.batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False)

 """

' # TEST CELL\n\ntrain_loader = torch.utils.data.DataLoader(\n        dataset= train_data,\n        batch_size=args.batch_size,\n        shuffle=False,\n        pin_memory=True,\n        drop_last=False)\n\neval_loader = torch.utils.data.DataLoader(\n    dataset=test_data,\n    batch_size=args.batch_size,\n    shuffle=False,\n    pin_memory=True,\n    drop_last=False)\n\n '

In [47]:
class FashionMNSITModel_V2(nn.Module):
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()

        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=(3,3),
                      padding=1,
                      stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=(3,3),
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2))
        )

        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=(3,3),
                      padding=1,
                      stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=(3,3),
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2))
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units*7*7,
                      out_features=output_shape)

        )

    def forward(self, x):
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x
    
torch.manual_seed(42)

model_2 = FashionMNSITModel_V2(
    input_shape=1,
    hidden_units=10,
    output_shape=len(class_names)
)

def create_models(input_shape: int, hidden_units: int, output_shape:int, ema=False):

    model = FashionMNSITModel_V2(
    input_shape=input_shape,
    hidden_units=hidden_units,
    output_shape=output_shape)

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

In [48]:
model_student = create_models(input_shape=1, hidden_units=10, output_shape=len(class_names))
model_teacher = create_models(input_shape=1, hidden_units=10, output_shape=len(class_names), ema=True)

optimizer = torch.optim.SGD(model_student.parameters(), args.lr)
optimizer_2 = torch.optim.SGD(params=model_2.parameters(), lr=0.01)

In [49]:
def accuracy_fn(output, target, args, train=False):

    y_preds = torch.argmax(output, dim=1)

    if(train):
        y_preds = y_preds[:args.labeled_batch_size]
        target = target[:args.labeled_batch_size]

    res = sum(torch.eq(y_preds,target)).item() / len(output)
    
    return res

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)

def adjust_learning_rate(optimizer, epoch, batch_num, batches_in_epoch, args):
    lr = args.lr
    epoch = epoch + batch_num / batches_in_epoch

    lr = linear_rampup(epoch, args.lr_rampup) * (args.lr - args.initial_lr) + args.initial_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [50]:
def train(train_loader, model_student, model_teacher, optimizer, epoch, args):
    global global_step

    # Is this the problem? maybe. consistency = 0. so we dont know
    class_criterion = nn.CrossEntropyLoss(reduction="sum",ignore_index=NO_LABEL)
    if args.consistency_type == 'kl':
         consistency_criterion = softmax_kl_loss

    model_student.train()
    model_teacher.train() # is this the problem - No

    start = time.time()

    for batch, (X, y) in enumerate(train_loader):

        # Adjust learning rate for minibatach - read more on this and minibatch/batch training
        # Is this the problem? maybe. consistency = 0. so we dont know
        if(args.exclude_unlabeled == False):
            adjust_learning_rate(optimizer, epoch, batch, len(train_loader), args)

        # Add noise to student and teacher inputs: ideally data should be augmented in the data loader. look more into it
        student_input_var = X + 0.01*torch.randn(size=X.shape) # Not the problem
        #student_input_var = student_input_var.unsqueeze(dim=1)
        teacher_input_var = X + 0.01*torch.randn(size=X.shape) # Not the problem
        #teacher_input_var = teacher_input_var.unsqueeze(dim=1)


        minibatch_size = len(y)

        # Forward Pass
        student_out = model_student(student_input_var)
        teacher_out = model_teacher(teacher_input_var) # is this the problem - No
        
        teacher_logit = teacher_out.detach().data # is this the problem - No

        ## Evaluate the Loss
        classification_loss = class_criterion(student_out, y) / minibatch_size
        teacher_classification_loss = class_criterion(teacher_logit, y) / minibatch_size # is this the problem - No
        
        # is this the problem - No (but consistency right now is 0, we will get back to it)
        if args.consistency:
            consistency_weight = get_current_consistency_weight(epoch, args)
            consistency_loss = consistency_weight * consistency_criterion(student_out, teacher_logit) / minibatch_size # is this the problem - No
        else:
            consistency_loss = 0

        loss = classification_loss + consistency_loss # is this the problem - No

        #print("+++++++++++++++++++++++++")
        #print("Loss")
        #print(loss)
        #print("consistency Loss")
        #print(consistency_loss)
        #print("+++++++++++++++++++++++++")
        #print(xxx)
        ## The usual
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        
        update_ema_variables(model_student, model_teacher, args.ema_decay, global_step) # is this the problem

        end = time.time()

        num = int(len(train_loader)/2)
        if batch % num == 0:
            print(f"Looked at {batch * len(X)}/{len(train_loader.dataset)} samples")

    print(f"Training Loss = {loss}, Consistency Loss = {consistency_loss}")



In [51]:
def validate(eval_loader, model, args):

    test_acc = 0
    class_loss = 0
    class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)
    model.eval()

    with torch.inference_mode():
        for i, (input, target) in enumerate(eval_loader):

            minibatch_size = len(target)
            output = model(input)
            softmax1 = F.softmax(output, dim=1)
            class_loss += class_criterion(output, target) / minibatch_size

            test_acc += accuracy_fn(output, target, args)
        
        test_acc /= len(eval_loader)
        class_loss /= len(eval_loader)
        
    return test_acc, class_loss
        

In [52]:
from tqdm.auto import tqdm

student_accuracy = []
teacher_accuracy = []

global_step = 0
for epoch in tqdm(range(5)):
    train(train_loader, model_student, model_teacher, optimizer, epoch, args)
    s_acc, s_loss = validate(eval_loader, model_student, args)
    student_accuracy.append(s_acc)
    t_acc, t_loss = validate(eval_loader, model_teacher, args)
    teacher_accuracy.append(t_acc)
    
    print(f"Student Accuracy = {s_acc*100}, Teacher accuracy = {t_acc*100}")

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:1050.)


Looked at 0/48000 samples
Looked at 24000/48000 samples
Training Loss = 0.9125328660011292, Consistency Loss = 0


 20%|██        | 1/5 [00:27<01:49, 27.42s/it]

Student Accuracy = 65.406050955414, Teacher accuracy = 38.51512738853503
Looked at 0/48000 samples
Looked at 24000/48000 samples
Training Loss = 0.6578415632247925, Consistency Loss = 0


 40%|████      | 2/5 [00:54<01:21, 27.33s/it]

Student Accuracy = 74.20382165605095, Teacher accuracy = 72.73089171974523
Looked at 0/48000 samples
Looked at 24000/48000 samples
Training Loss = 0.5941431522369385, Consistency Loss = 0


 60%|██████    | 3/5 [01:22<00:55, 27.56s/it]

Student Accuracy = 76.93073248407643, Teacher accuracy = 77.1297770700637
Looked at 0/48000 samples
Looked at 24000/48000 samples
Training Loss = 0.3464721739292145, Consistency Loss = 0


 80%|████████  | 4/5 [01:50<00:27, 27.62s/it]

Student Accuracy = 81.09076433121018, Teacher accuracy = 79.04060509554141
Looked at 0/48000 samples
Looked at 24000/48000 samples
Training Loss = 0.42500215768814087, Consistency Loss = 0


100%|██████████| 5/5 [02:18<00:00, 27.62s/it]

Student Accuracy = 81.67794585987261, Teacher accuracy = 81.031050955414





In [40]:
def train_normal(train_loader, model, optimizer, epoch, args):
    class_criterion = nn.CrossEntropyLoss()
    
    model.train()

    for batch, (X,y) in enumerate(train_loader):
        
        output = model(X)
        loss = class_criterion(output, y)
        
        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        num = int(len(train_loader)/2)
        if batch % num == 0:
            print(f"Looked at {batch * len(X)}/{len(train_loader.dataset)} samples")

    print(f"Training Loss = {loss}")

In [41]:
from tqdm.auto import tqdm

student_accuracy = []
teacher_accuracy = []

global_step = 0
for epoch in tqdm(range(args.epochs)):
    train_normal(train_loader, model_2, optimizer_2, epoch, args)
    s_acc, s_loss = validate(eval_loader, model_2, args)
    student_accuracy.append(s_acc)
    t_acc, t_loss = validate(eval_loader, model_2, args)
    teacher_accuracy.append(t_acc)
    
    print(f"Student Accuracy = {s_acc*100}, Teacher accuracy = {t_acc*100}")

  0%|          | 0/4 [00:00<?, ?it/s]

Looked at 0/60000 samples
Looked at 30016/60000 samples
Training Loss = 0.8393805623054504


 25%|██▌       | 1/4 [00:36<01:48, 36.11s/it]

Student Accuracy = 65.23686305732484, Teacher accuracy = 65.23686305732484
Looked at 0/60000 samples
Looked at 30016/60000 samples
Training Loss = 0.7108218669891357


 50%|█████     | 2/4 [01:12<01:12, 36.14s/it]

Student Accuracy = 73.218550955414, Teacher accuracy = 73.218550955414
Looked at 0/60000 samples
Looked at 30016/60000 samples
Training Loss = 0.625800609588623


 75%|███████▌  | 3/4 [01:48<00:36, 36.29s/it]

Student Accuracy = 76.26393312101911, Teacher accuracy = 76.26393312101911
Looked at 0/60000 samples
Looked at 30016/60000 samples
Training Loss = 0.5647347569465637


100%|██████████| 4/4 [02:24<00:00, 36.12s/it]

Student Accuracy = 78.6624203821656, Teacher accuracy = 78.6624203821656





In [84]:
args.exclude_unlabeled == False

False