In [1]:
# from google.colab import drive
# drive.mount('/content/gdrive', force_remount= True)

In [23]:
!pip install torch_utils



In [24]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import gc
import math
from torch_utils import  AverageMeter
from torch.utils.tensorboard import SummaryWriter

plt.ion()   # interactive mode

In [25]:
def displayImages(images, title1="Original", title2="Augmented", labels=None, augmented_images=None):
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    fig = plt.figure(figsize=(images.shape[0], images.shape[0]))
    for i in range(images.shape[0]):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)
        if labels is not None:
            plt.xlabel(class_names[int(labels[i])])
    fig.suptitle(title1, fontsize=16)

    if augmented_images is not None:
        fig2 = plt.figure(2, figsize=(augmented_images.shape[0], augmented_images.shape[0]))
        for i in range(augmented_images.shape[0]):
            plt.subplot(5, 5, i + 1)
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(augmented_images[i], cmap=plt.cm.binary)
            if labels is not None:
                plt.xlabel(class_names[int(labels[i])])
        fig2.suptitle(title2, fontsize=16)
    plt.show()


def split_indexes(n_classes, n_labeled_per_class, n_validation, labels):
    labels = np.array(labels)
    train_labeled_indexes = []
    train_unlabeled_indexes = []
    validation_indexes = []

    for i in range(n_classes):
        indexes = np.where(labels == i)[0]
        np.random.shuffle(indexes)

        train_labeled_indexes.extend(indexes[:n_labeled_per_class])
        train_unlabeled_indexes.extend(indexes[n_labeled_per_class:-n_validation])
        validation_indexes.extend(indexes[-n_validation:])

    np.random.shuffle(train_unlabeled_indexes)
    np.random.shuffle(train_labeled_indexes)
    np.random.shuffle(validation_indexes)

    return train_labeled_indexes, train_unlabeled_indexes, validation_indexes


def to_tensor_dim(x, source='NHWC', target='NCHW'):
    return x.transpose([source.index(d) for d in target])


def normalise(X):
    mean = np.mean(X, axis=(0, 1, 2))
    std = np.std(X, axis=(0, 1, 2))
    X, mean, std = [np.array(a, np.float32) for a in (X, mean, std)]
    X -= mean
    X *= 1.0 / std
    return X


def random_flip(x):
    if np.random.rand() < 0.6:
        x = x[:, ::-1, :]
    return x.copy()


def pad(x, border=4):
    return np.pad(x, [(border, border), (border, border), (0, 0)], mode='reflect')


def pad_and_crop(x, output_size=(32, 32)):
    x = pad(x, 4)
    h, w = x.shape[:-1]
    new_h, new_w = output_size

    top = np.random.randint(0, h - new_h)
    left = np.random.randint(0, w - new_w)

    x = x[top: top + new_h, left: left + new_w, :]

    return x


def augment(X, K=1):

    # X_augmented = X
    # for k in range(K-1):
    #   X_augmented.hstack(X)
    X_augmented = np.zeros(([K]+list(X.shape)))

    # print("after hstack", X_augmented.shape)
    for k in range(K):
      for i in range(X_augmented[k].shape[0]):
            x = X[i, :]
            x = pad_and_crop(x)
            X_augmented[k, i, :] = random_flip(x)
    return X_augmented


def load_and_augment_data(dataset_name, model_params):
    """
    From datasets.CIFAR10:
        dataset.data: the image as numpy array, shape: (50000, 32, 32, 3)
        dataset.targets: labels of the images as list, len: 50000
    :return:
        augmented_labeled_X: the tensor of augmented labeled images (K=1),
                             size: (n_labeled_per_class * n_classes , 32, 32, 3)
        augmented_unlabeled_X: the tensor of augmented unlabeled images (K=2),
                             size: ((N/10 - n_labeled_per_class - n_validation) * n_classes * K , 32, 32, 3)
        train_labeled_targets: the tensor of labeled targets,
                             size = n_labeled_per_class * n_classes
        train_unlabeled_targets: the tensor of unlabeled targets,
                             size = (N/10 - n_labeled_per_class - n_validation) * n_classes
    """

    # Step 1: Set the model's hyperparameters
    n_classes = model_params["n_classes"]
    n_labeled_per_class = model_params["n_labeled_per_class"]
    n_validation = model_params["n_validation"]
    K = model_params["K"]

    # Step 2: Load the dataset
    if dataset_name == 'CIFAR10':
        dataset = datasets.CIFAR10(root="./datasets/cifar10/train", train=True, download=True)
        test_set = datasets.CIFAR10(root="./datasets/cifar10/test", train=False, download=True)
    elif dataset_name == 'SLT10':
        dataset = datasets.STL10(root="./datasets/slt10/train", split='train', download=True)
        test_set = datasets.STL10(root="./datasets/slt10/test", split='test', download=True)
    else:
        raise ValueError("Invalid dataset name")

    # Step 3: Split the indexes
    train_labeled_indexes, train_unlabeled_indexes, validation_indexes = \
        split_indexes(n_classes, n_labeled_per_class, n_validation, dataset.targets)

    # Step 4: Attract the images for training, validation
    train_labeled_images = np.take(dataset.data, train_labeled_indexes, axis=0)
    train_unlabeled_images = np.take(dataset.data, train_unlabeled_indexes, axis=0)
    target_array = np.asarray(dataset.targets)
    train_labeled_targets = np.take(target_array, train_labeled_indexes, axis=0)
    train_unlabeled_targets = np.take(target_array, train_unlabeled_indexes, axis=0)
    validation_images = np.take(dataset.data, validation_indexes, axis=0)
    validation_targets = np.take(target_array, validation_indexes, axis=0)

    # Step 5: Normalise the datasets and make the labels one-hot encoded
    train_labeled_images = normalise(train_labeled_images)
    train_unlabeled_images = normalise(train_unlabeled_images)

    test_X = normalise(test_set.data)
    test_X = to_tensor_dim(test_X)
    test_X = torch.from_numpy(test_X)

    train_labeled_targets = torch.from_numpy(train_labeled_targets)
    train_unlabeled_targets = torch.from_numpy(train_unlabeled_targets)
    validation_targets = torch.from_numpy(validation_targets)
    test_targets = torch.from_numpy(np.asarray(test_set.targets))

    # Step 6: Augment training images
    print("shape",train_unlabeled_images.shape )

    augmented_labeled_X = augment(train_labeled_images, K=1)
    augmented_unlabeled_X = augment(train_unlabeled_images, K=K)
    
    print("shape after", augmented_unlabeled_X.shape )
    
    # Take a look at some of the augmented images
    # displayImages(train_labeled_images[:10], title1="Original-Labeled", title2="Augmented-Labeled",
    #               augmented_images=augmented_labeled_X[:10], labels=train_labeled_targets[:10])
    # n_unlabeled = train_unlabeled_images.shape[0]
    # displayImages(train_unlabeled_images[:10], title1="Original-Unlabeled", title2="Augmented-Unlabeled",
    #               augmented_images=augmented_unlabeled_X[:10], labels=train_unlabeled_targets[:10])
    # displayImages(augmented_unlabeled_X[:10], title1="Augmented-Unlabeled1", title2="Augmented-Unlabeled2",
    #               augmented_images=augmented_unlabeled_X[n_unlabeled:10+n_unlabeled],
    #               labels=train_unlabeled_targets[:10])

    # Step 7: Change the dimension of np.array in oder for it to work in torch
    augmented_labeled_X = to_tensor_dim(augmented_labeled_X.reshape(train_labeled_images.shape))

    augmented_unlabeled_X_zeros = np.zeros(([K]+[augmented_unlabeled_X.shape[1]] + list(augmented_labeled_X.shape[1:])))
    #print("augmented_unlabeled_X_zeros", augmented_unlabeled_X_zeros.shape)
    for k in range(K):
      unlabeled_shape = augmented_unlabeled_X[k].reshape(train_unlabeled_images.shape)
      #print("unlabeled_shape", unlabeled_shape)
      augmented_unlabeled_X_zeros[k] = to_tensor_dim(unlabeled_shape)
    validation_images = to_tensor_dim(validation_images)
    
    return torch.from_numpy(augmented_labeled_X), torch.from_numpy(augmented_unlabeled_X_zeros), \
           train_labeled_targets, train_labeled_targets, \
           torch.from_numpy(validation_images), validation_targets, test_X, test_targets


# Defining Model Parameters and obtaining train and val sets

In [26]:
model_params = {
    "n_classes": 10,
    "n_labeled_per_class": 30,
    "n_validation": 500,
    "K": 3
}
augmented_labeled_X, augmented_unlabeled_X, train_labeled_targets, train_unlabeled_targets, \
    validation_images, validation_targets, test_images, test_targets = load_and_augment_data('CIFAR10', model_params)

print(f"augmented_labeled_X: {augmented_labeled_X.size()}")
print(f"augmented_unlabeled_X: {augmented_unlabeled_X.size()}")
print(f"train_labeled_targets: {train_labeled_targets.size()}")
print(f"train_unlabeled_targets: {train_unlabeled_targets.size()}")
print(f"validation_images: {validation_images.size()}")
print(f"validation_targets: {validation_targets.size()}")
print(f"test_targets: {test_targets.size()}")
print(f"test_X: {test_images.size()}")

# Halve the number of labeled data
# n_labeled = augmented_labeled_X.size()[0]
# new_n = int(n_labeled/2)
# augmented_labeled_X = augmented_labeled_X[:new_n]
# print(f"New length: {augmented_labeled_X.size()}")


Files already downloaded and verified
Files already downloaded and verified
shape (44700, 32, 32, 3)
shape after (3, 44700, 32, 32, 3)
augmented_labeled_X: torch.Size([300, 3, 32, 32])
augmented_unlabeled_X: torch.Size([3, 44700, 3, 32, 32])
train_labeled_targets: torch.Size([300])
train_unlabeled_targets: torch.Size([300])
validation_images: torch.Size([5000, 3, 32, 32])
validation_targets: torch.Size([5000])
test_targets: torch.Size([10000])
test_X: torch.Size([10000, 3, 32, 32])


Wide ResNet Model

In [27]:
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
        self.activate_before_residual = activate_before_residual
    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual == True:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual))
        return nn.Sequential(*layers)
    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

In [28]:
class WeightEMA(object):
    def __init__(self, model, ema_model,lr, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * lr

        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            if ema_param.dtype==torch.float32:
                ema_param.mul_(self.alpha)
                ema_param.add_(param * one_minus_alpha)
                # customized weight decay
                param.mul_(1 - self.wd)

In [29]:
batch_size_val=10
validation_images=list(torch.split(validation_images, batch_size_val))
validation_targets=list(torch.split(validation_targets, batch_size_val))

In [30]:
batch_size_test=10
test_images=list(torch.split(test_images, batch_size_val))
test_targets=list(torch.split(test_targets, batch_size_val))

#splitting into required number of batches

# Splitting data into Batches

In [31]:
# batch_size_labeled = int(augmented_labeled_X.size()[0]/num_batches)
# batch_size_unlabeled = int(augmented_unlabeled_X.size()[1]/num_batches)

In [32]:
batch_size_labeled = 30
batch_size_unlabeled = 30

In [33]:
batch_size_labeled

30

## Splitting labelled data


In [34]:
permutation = torch.randperm(augmented_labeled_X.size()[0])
labeled_batches = []
for i in range(0,augmented_labeled_X.size()[0], batch_size_labeled):
        indices = permutation[i:i+batch_size_labeled]
        batch_x, batch_y = augmented_labeled_X[indices], train_labeled_targets[indices]
        labeled_batches.append((batch_x, batch_y))

In [35]:
del augmented_labeled_X,train_labeled_targets
gc.collect()

481

# Splitting unlabelled data

In [36]:
#Splitting unlabelled into batches
k_unlabeled_batches = []
for k in range(augmented_unlabeled_X.size()[0]):
  permutation = torch.randperm(augmented_unlabeled_X.size()[1])
  unlabeled_batches = []
  for i in range(0, augmented_unlabeled_X.size()[1], batch_size_unlabeled):
          indices = permutation[i:i+batch_size_unlabeled]
          batch = augmented_unlabeled_X[k][indices]
          unlabeled_batches.append(batch)
  k_unlabeled_batches.append(unlabeled_batches)


In [37]:
del augmented_unlabeled_X,unlabeled_batches
gc.collect()

44

In [38]:
k_unlabeled_batches[0][0].size()


torch.Size([30, 3, 32, 32])

K_unlabeled_batches is a tuple of shape [k, num_batches, batch_size, 3, 32, 32]
where k is the number of augmentations.

# Defining Model

In [39]:

def label_guessing(model, data):
   return model(data)

# Example use of label_guessing function

In [40]:
# guesses = label_guessing(model_ft, k_unlabeled_batches[0][0].float().to(device))
# print(f"shape: {guesses.shape}")
# print(guesses)



In [79]:

def guess_and_sharpen(augmented_unlabeled_X, model, T=0.5):
    """
    Assume:
      label_guesses: a list of K guesses
    """
    num_classes = 10
    K = len(augmented_unlabeled_X)
    softmax_guess_array = []
    for i in range(K):
      guess = label_guessing(model, augmented_unlabeled_X[i].float().to(device))
      softmax_guess = torch.softmax(guess, dim=1)
      softmax_guess_array.append(softmax_guess)
    
    p = softmax_guess_array[0]
    for i in range(1,K):
      p += softmax_guess_array[i]


    # Sharpen
    pt = p**(1/T)
    targets_u = pt/pt.sum(dim=1, keepdim=True)
    targets_u = targets_u.detach()

    return targets_u


def mixup(X_hat, data_U, targets_U, alpha):
    data_X = X_hat[0].float().to(device)
    targets_X = X_hat[1]

    targets_X = torch.nn.functional.one_hot(targets_X, num_classes=10).float().to(device)

    batch_size = data_X.size(0)

    # Form W
    all_data = data_X
    all_targets = targets_X
    for i in range(len(data_U)):
      all_data = torch.cat([all_data, data_U[i].float().to(device)], dim=0)
      all_targets = torch.cat([all_targets, targets_U], dim=0)
    # all_data = torch.cat([data_X, data_U1.float().to(device), data_U2.float().to(device)], dim=0)
    # all_targets = torch.cat([targets_X, targets_U, targets_U], dim=0)

    idx = torch.randperm(all_data.size(0))
    W_data = all_data[idx]
    W_targets = all_targets[idx]

    # Mix it up
    lamda = np.random.beta(alpha, alpha)
    lamda = max(lamda, 1 - lamda)

    data_prime = lamda * all_data + (1-lamda) * W_data
    targets_prime = lamda * all_targets + (1-lamda) * W_targets

    return data_prime, targets_prime


In [67]:
def linear_rampup(current, rampup_length=0):
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)

class SemiLoss(object):
  def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch,num_epochs):
      probs_u = torch.softmax(outputs_u, dim=1)

      Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
      Lu = torch.mean((probs_u - targets_u)**2)
      lambda_u=75
      return Lx, Lu, lambda_u * linear_rampup(epoch,num_epochs)
# def SemiLoss(outputs_x, targets_x, outputs_u, targets_u, epoch,num_epoch):
#   probs_u = torch.softmax(outputs_u, dim=1)

#   Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
#   Lu = torch.mean((probs_u - targets_u)**2)
#   lambda_u=75
#   return Lx, Lu,lambda_u * linear_rampup(epoch,num_epoch)

In [68]:
class iter_count:
  def __init__(self,max_len):
    self.max_len=max_len
    self.idx=None
  def next(self):
    if self.idx== None:
      self.idx=0
      return self.idx
    elif self.idx == (self.max_len-1):
      self.idx=0
      return self.idx
    else:
      self.idx+=1
      return self.idx

In [69]:
def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]


In [80]:
def train(labeled_batches, k_unlabeled_batches,epoch,n_train_iter,model, optimizer, ema_optimizer,ema_model,num_epochs):
    # Step 1: Input of the function (Raghav)
      # labeled_batches: List, size: n_batches, each elem: tuple of X(tensor) and targets(tensor)
      # k_unlabeled_batches: [k, num_batches, batch_size, 3, 32, 32] where k is the number of augmentations.
    # n_batches = len(labeled_batches)
    loss = AverageMeter(name='loss_meter',length=n_train_iter)
    loss_x = AverageMeter(name='loss_meter_x',length=n_train_iter)
    loss_u = AverageMeter(name='loss_meter_u',length=n_train_iter)
    wl = AverageMeter(name='w',length=n_train_iter)
    acc = AverageMeter(name='acc',length=n_train_iter)
    alpha = 0.75
    labeled_iter=iter_count(len(labeled_batches))
    unlabeled_iter=iter_count(len(k_unlabeled_batches[0]))
    lc=SemiLoss()
    # Step 2: For loop
    for i in range(n_train_iter):
        labeled_idx=labeled_iter.next()
        unlabeled_idx=unlabeled_iter.next()
        # print(labeled_idx,unlabeled_idx)
        augmented_unlabeled_X = []
        for k in range(len(k_unlabeled_batches)):
          augmented_unlabeled_X.append(k_unlabeled_batches[k][unlabeled_idx])

        targets_u = guess_and_sharpen(augmented_unlabeled_X, model)
        data_prime, targets_prime = mixup(labeled_batches[labeled_idx], augmented_unlabeled_X, targets_u, alpha) 
        batch_size=labeled_batches[labeled_idx][0].size(0)
        # print(f"data_prime size: {data_prime.size()}")
        # print(f"targets_prime size: {targets_prime.size()}")
        data_prime = list(torch.split(data_prime, batch_size))
        data_prime = interleave(data_prime, batch_size)
        log_probs = [model(data_prime[0])]
        for input in data_prime[1:]:
            log_probs.append(model(input))

        # put interleaved samples back
        #logits = interleave(logits, batch_size)
        log_probs = interleave(log_probs, batch_size)
        log_prob_x = log_probs[0]
        log_prob_u = torch.cat(log_probs[1:], dim=0)

        Lx, Lu, w = lc(log_prob_x, targets_prime[:batch_size], log_prob_u, targets_prime[batch_size:], epoch+i/n_train_iter,num_epochs)

        L = Lx + w * Lu
        loss.update(val=L.item())
        loss_x.update(val=Lx.item())
        loss_u.update(val=Lu.item())
        wl.update(val=w)
        #print(type(L))
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        ema_optimizer.step()
        inputs1, targets1 = labeled_batches[labeled_idx][0].float().to(device), labeled_batches[labeled_idx][1].to(device,non_blocking=True)
        outputs1=ema_model(inputs1)
        prec1, prec5 = accuracy(outputs1, targets1, topk=(1, 5))
        acc.update(val=prec1)
    return loss.avg, loss_x.avg, loss_u.avg, acc.avg
    



In [81]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


In [82]:
def validate(validation_images, validation_targets, model, use_cuda):
    val_loss=nn.CrossEntropyLoss()
    num_batches=len(validation_images)
    loss = AverageMeter(name="loss",length=num_batches)
    top1 = AverageMeter(name="top1",length=num_batches)
    top5 = AverageMeter(name="top5",length=num_batches)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for batch_idx in range(num_batches):
          inputs=validation_images[batch_idx]
          targets=validation_targets[batch_idx]
          # targets = torch.nn.functional.one_hot(targets, num_classes=10)
          if use_cuda:
              inputs, targets = inputs.float().to(device), targets.to(device,non_blocking=True)
          # compute output
          outputs = model(inputs)
          #print(outputs)
          loss_eval = val_loss(outputs, targets)
          # measure accuracy and record loss
          prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
          loss.update(val=loss_eval.item())
          top1.update(val=prec1.item())
          top5.update(val=prec5.item())

            
    return (loss.avg, top1.avg)

In [83]:
import gc
gc.collect()

3843

In [84]:
def create_model(device,ema=False):
  model = WideResNet(num_classes=10)
  model = model.to(device)

  if ema:
      for param in model.parameters():
          param.detach_()

  return model

In [85]:
def save_model(state,checkpoint, filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)

In [86]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount= True)

Mounted at /content/gdrive


In [87]:
lr=0.002
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = create_model(device)
ema_model = create_model(device,ema=True)
optimizer = optim.Adam(model.parameters(), lr)
ema_optimizer= WeightEMA(model, ema_model,lr, alpha=.999)
resume=False
resume_dir="/content/gdrive/MyDrive/DLDS_project/model/checkpoint.pth.tar"
save_dir="path to the directory to save model"
num_epochs=1
n_train_iter=1500
if resume:
  print('==> Resuming from checkpoint..')
  #assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
  # args.out = os.path.dirname(args.resume)
  checkpoint = torch.load(resume_dir)
  best_acc = checkpoint['best_acc']
  start_epoch = checkpoint['epoch']
  model.load_state_dict(checkpoint['state_dict'])
  ema_model.load_state_dict(checkpoint['ema_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
else:
  start_epoch=0
  best_acc = 0.0

since = time.time()

best_model_wts = copy.deepcopy(model.state_dict())


for epoch in range(start_epoch,num_epochs):
    print('Epoch {}/{}'.format(epoch+1, num_epochs))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train',"val"]:
        if phase == 'train':
            train_loss, train_loss_x, train_loss_u,train_acc=train(labeled_batches, k_unlabeled_batches, epoch, n_train_iter, model, optimizer, ema_optimizer,ema_model,num_epochs)  # Set model to training mode
            epoch_loss=train_loss
            epoch_acc=train_acc

        else:
            loss, top1=validate(validation_images, validation_targets, ema_model, True)
            epoch_loss=loss
            epoch_acc=top1

        
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            phase, epoch_loss, epoch_acc))

        # deep copy the model
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            best_ema_model_wts=copy.deepcopy(ema_model.state_dict())
            best_optimizer_wts=copy.deepcopy(optimizer.state_dict())


time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))
best_state_dict={'epoch': epoch + 1,
                'state_dict': best_model_wts,
                'ema_state_dict':best_ema_model_wts,
                'best_acc': best_acc,
                'optimizer' : best_optimizer_wts}
# load best model weights
model.load_state_dict(best_model_wts)
ema_model.load_state_dict(best_ema_model_wts)
test_loss, test_acc=validate(test_images, test_targets, ema_model, True)
print("Test Loss: {:4f}".format(test_loss))
print("Test Accuracy: {:4f}".format(test_acc))



Epoch 1/1
----------
train Loss: 0.9612 Acc: 80.1334
val Loss: 39.4411 Acc: 10.4200
Training complete in 3m 25s
Best val Acc: 10.420000
Test Loss: 3.319487
Test Accuracy: 9.990000


In [None]:
len(test_images)

1000

In [None]:
best_state_dict["epoch"]

2

In [None]:
save_model(best_state_dict,checkpoint="/content/gdrive/MyDrive/DLDS_project/model/", filename='checkpoint.pth.tar')