In [None]:
import time, os
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as ans 
from tqdm import tqdm 
import shutil 
import tensorflow as tf 
from tensorflow import keras 
from tensorflow.keras import Model
from tensorflow.keras import layers 
from tensorflow.keras.layers import *
import tensorflow_datasets as tfds
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import *
from datetime import datetime
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.activations import sigmoid
from tensorflow.keras.layers import Dense, Input, UpSampling2D, Conv2DTranspose, Conv2D, add, Add,\
                    Lambda, Concatenate, AveragePooling2D, BatchNormalization, GlobalAveragePooling2D, \
                    Add, LayerNormalization, Activation, LeakyReLU, SeparableConv2D, Softmax, MaxPooling2D

In [None]:
class DataLoader:
    """
        Builds tf.data pipelines for MixMatch-style semi-supervised training and for
        downstream evaluation.

        The downloaded training split is divided into a small labelled subset and an
        unlabelled subset; the downloaded test split is divided into a validation
        subset and a test subset.

        Methods:
            get_dataset(scope: public)             : tf.data.Dataset for one subset.
            __download_data(scope: private)        : fetches the raw numpy arrays.
            __normalize(scope: private)            : scales pixels to float32 in [0, 1].
            __get_unlabeled_data(scope: private)   : labelled / unlabelled split of the train data.
            __get_valdata(scope: private)          : validation / test split of the test data.
            __preprocess_tf_dataset(scope: private): shuffle, batch, augment / encode labels.
            __augment(scope: private)              : random flip + reflect-pad-and-crop.
            __val_test_preprocess(scope: private)  : float conversion + one-hot labels.

        Constructor args:
            dname(dtype: str)              : dataset name ("cifar10", "cifar100" or "svhn").
            n_val(dtype: int)              : number of validation samples taken from the test split.
            normalize(dtype: bool)         : scale pixels to float32 in [0, 1] before splitting.
            n_labelled_samples(dtype: int) : number of training samples kept as labelled.
    """

    def __init__(self, dname="cifar10", n_val=5000, normalize=True, n_labelled_samples=100):
        assert dname in ["cifar10", 'cifar100', "svhn"], "supported datasets are cifar10, cifar100,svhn"
        assert n_val <= 10_000, "ValueError: nval value should be <= 10_000"

        self.__n_labelled_samples = n_labelled_samples
        train_data, test_data = self.__download_data(dname)
        self.__train_X, self.__train_y = train_data
        self.__dtest_X, self.__dtest_y = test_data

        # Normalize BEFORE splitting: the split helpers copy rows out of
        # __train_X / __dtest_X, so normalizing afterwards would leave every
        # subset (and therefore every dataset handed out) un-normalized.
        if normalize:
            self.__normalize()
        self.__get_unlabeled_data()
        self.__get_valdata(n_val)

    def __len__(self):
        return self.__train_X.shape[0] + self.__dtest_X.shape[0]

    def __repr__(self):
        return f"Training Samples: {self.__train_X.shape[0]}, Testing Samples: {self.__dtest_X.shape[0]}"

    def __download_data(self, dname):
        """
            Downloads the requested dataset and records the number of classes.
            Params:
                dname(type: str): "cifar10", "cifar100" or "svhn".
            Return(type: ((np.ndarray, np.ndarray), (np.ndarray, np.ndarray)))
                ((train_X, train_y), (test_X, test_y)) as numpy arrays.
        """
        if dname == "cifar10":
            train_data, test_data = tf.keras.datasets.cifar10.load_data()
        elif dname == "cifar100":
            train_data, test_data = tf.keras.datasets.cifar100.load_data()
        else:  # "svhn"
            # batch_size=-1 returns each split as a single batch of tensors and
            # as_supervised=True makes it an (image, label) tuple, so after
            # as_numpy it has the same (X, y) layout as the keras loaders.
            train_split, test_split = tfds.load(
                name="svhn_cropped",
                split=["train", "test"],
                batch_size=-1,
                as_supervised=True,
            )
            train_data = tfds.as_numpy(train_split)
            test_data = tfds.as_numpy(test_split)

        self.__n_labels = len(np.unique(test_data[1]))
        return train_data, test_data

    def __normalize(self):
        """
            Scales pixel values to [0, 1] as float32 (half the memory of float64 and the
            dtype the pipelines emit anyway).
        """
        self.__train_X = self.__train_X.astype(np.float32) / 255.0
        self.__dtest_X = self.__dtest_X.astype(np.float32) / 255.0

    def __val_test_preprocess(self, x, label):
        """
            Converts an image batch to float32 and its integer labels to one-hot rows.
        """
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        label = tf.one_hot(label, self.__n_labels)
        # Labels may arrive as (batch, 1) (keras loaders) or (batch,) (tfds);
        # flatten either layout to (batch, n_labels).
        label = tf.reshape(label, (-1, self.__n_labels))
        return x, label

    def __preprocess_tf_dataset(self, tf_ds, batch_size, transform=False, subset="unlabelled", k_augmentation=1):
        """
            Shuffles and batches `tf_ds`, then converts every batch to model-ready form.
            Params:
                tf_ds(type: tf.data.Dataset) : unbatched images (unlabelled) or (image, label) pairs.
                batch_size(dtype: int)       : batch size; the last partial batch is dropped.
                transform(dtype: bool)       : apply random flip + reflect-pad-and-crop augmentation.
                subset(dtype: str)           : "unlabelled", "labelled", "val" or "test".
                k_augmentation(dtype: int)   : unlabelled only - number of views per element.
            Return(type: tf.data.Dataset)
                unlabelled : k_augmentation-tuples of float32 image batches (views of the same images).
                otherwise  : (float32 image batch, one-hot label batch).
        """
        autotune = tf.data.experimental.AUTOTUNE
        tf_ds = tf_ds.shuffle(1024, seed=42)
        # drop_remainder keeps every batch the same static size.
        tf_ds = tf_ds.batch(batch_size, drop_remainder=True)

        if subset == "unlabelled":
            def make_views(x):
                # All views are produced from the same batch inside one map call, so
                # they always show the same images; with transform=True each view
                # gets its own random augmentation.
                if transform:
                    return tuple(self.__augment(x, is_label=False) for _ in range(k_augmentation))
                view = tf.image.convert_image_dtype(x, dtype=tf.float32)
                return tuple(view for _ in range(k_augmentation))

            tf_ds = tf_ds.map(make_views, num_parallel_calls=autotune)
        elif transform:
            tf_ds = tf_ds.map(lambda x, y: self.__augment(x, y), num_parallel_calls=autotune)
        else:
            tf_ds = tf_ds.map(lambda x, y: self.__val_test_preprocess(x, y), num_parallel_calls=autotune)

        return tf_ds.prefetch(autotune)

    def get_dataset(self, batch_size, subset="unlabelled",
                                        transform=False, k_augmentation=1):
        """
            Returns a batched, prefetched tf.data.Dataset for one subset.
            Params:
                batch_size(dtype: int)     : batch size (the last partial batch is dropped).
                subset(dtype: str)         : "unlabelled", "labelled", "val" or "test".
                transform(dtype: bool)     : apply random flip + reflect-pad-and-crop augmentation.
                k_augmentation(dtype: int) : unlabelled only - number of views per element.
            Return(type: tf.data.Dataset)
                "unlabelled": k_augmentation-tuples of float32 image batches, all views of
                              the same images.
                otherwise   : (float32 image batch, one-hot label batch).
                Pixels end up in [0, 1] either way: convert_image_dtype rescales uint8 input
                when the loader was built with normalize=False.
            Raises:
                ValueError: for an unknown `subset`.
        """
        sources = {
            "unlabelled": (self.__unlabelled_X,),
            "labelled": (self.__labelled_X, self.__labelled_y),
            "val": (self.__val_X, self.__val_y),
            "test": (self.__test_X, self.__test_y),
        }
        if subset not in sources:
            raise ValueError(f"unknown subset {subset!r}; expected one of {list(sources)}")

        arrays = sources[subset]
        tf_ds = tf.data.Dataset.from_tensor_slices(arrays[0] if len(arrays) == 1 else arrays)
        return self.__preprocess_tf_dataset(
            tf_ds=tf_ds,
            batch_size=batch_size,
            transform=transform,
            subset=subset,
            k_augmentation=k_augmentation,
        )

    def __get_valdata(self, nval):
        """
            Randomly moves `nval` rows of the downloaded test split into a validation set;
            the remaining rows (in their original order) form the test set.
            Params:
                nval(dtype: int): number of validation samples.
        """
        n_test = self.__dtest_X.shape[0]
        val_inds = np.random.choice(n_test, nval, replace=False)
        # Boolean mask instead of `i not in val_inds` per index (O(n * nval)).
        is_test = np.ones(n_test, dtype=bool)
        is_test[val_inds] = False

        self.__test_X, self.__test_y = self.__dtest_X[is_test], self.__dtest_y[is_test]
        self.__val_X, self.__val_y = self.__dtest_X[val_inds], self.__dtest_y[val_inds]

    def __get_unlabeled_data(self):
        """
            Randomly picks `n_labelled_samples` training rows as the labelled subset; every
            other training row becomes unlabelled (the whole training split is used).
        """
        n_train = self.__train_X.shape[0]
        labelled_inds = np.random.choice(n_train, self.__n_labelled_samples, replace=False)
        is_unlabelled = np.ones(n_train, dtype=bool)
        is_unlabelled[labelled_inds] = False

        self.__labelled_X = self.__train_X[labelled_inds]
        self.__labelled_y = self.__train_y[labelled_inds]

        self.__unlabelled_X = self.__train_X[is_unlabelled]
        self.__unlabelled_y = self.__train_y[is_unlabelled]

    @tf.function
    def __augment(self, x, label=None, is_label=True):
        """
            Random horizontal flip plus 4-pixel reflect padding and random 32x32 crop,
            applied per image of the batch `x`.
            Returns the augmented batch, or (batch, one-hot labels) when is_label is True.
        """
        x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        x = tf.image.random_flip_left_right(x)
        x = tf.pad(x, paddings=[(0, 0), (4, 4), (4, 4), (0, 0)], mode='REFLECT')
        # Every supported dataset is 32x32 RGB, so the crop restores the input size.
        x = tf.map_fn(lambda img: tf.image.random_crop(img, size=(32, 32, 3)), x)
        if not is_label:
            return x
        label = tf.one_hot(label, self.__n_labels)
        label = tf.reshape(label, (-1, self.__n_labels))
        return x, label


In [None]:
# Fix NumPy's global RNG so the labelled/unlabelled and validation/test index
# draws inside DataLoader (np.random.choice) are reproducible on a fresh kernel.
# TensorFlow's global seed is deliberately left unset: with a global seed,
# separately traced tf.data map functions can replay the same random sequence,
# which could make the "independent" augmentations identical.
SEED = 42
np.random.seed(SEED)
mixmatch_dataloader = DataLoader()

In [None]:
# Unlabelled pipeline: each element is a 2-tuple of augmented image batches;
# k_augmentation has to agree with the K passed to train().
unlabelled_ds = mixmatch_dataloader.get_dataset(
    batch_size=32,
    subset="unlabelled",
    transform=True,
    k_augmentation=2,
)

In [None]:
# Labelled pipeline: augmented image batches paired with one-hot labels.
labelled_ds = mixmatch_dataloader.get_dataset(
    batch_size=32,
    subset="labelled",
    transform=True,
    k_augmentation=1,
)

In [None]:
# Validation pipeline: no augmentation; images converted to float32, labels one-hot.
val_ds = mixmatch_dataloader.get_dataset(
    batch_size=32,
    subset="val",
    transform=False,
    k_augmentation=1,
)

In [None]:
# Display the validation dataset's repr to sanity-check its element structure
# (batch shapes and dtypes) before training.
val_ds

In [None]:
import tensorflow as tf

class Residual3x3Unit(tf.keras.layers.Layer):
    """
    Pre-activation residual unit: (BatchNorm -> LeakyReLU -> 3x3 conv) twice, plus a shortcut.

    Args:
        channels_in, channels_out : input / output channels. When they differ the unit
                                    "downsamples" and the shortcut is a strided 1x1 conv;
                                    otherwise the shortcut is the identity.
        stride                    : stride of the first 3x3 conv and of the 1x1 shortcut.
        droprate                  : dropout rate between the two convs (0 disables it).
        activate_before_residual  : only used when downsampling - if True the input is
                                    pre-activated (BN + LeakyReLU) before BOTH branches.
    """

    def __init__(self, channels_in, channels_out, stride, droprate=0., activate_before_residual=False):
        super(Residual3x3Unit, self).__init__()
        self.bn_0 = BatchNormalization(momentum=0.999)
        self.relu_0 = LeakyReLU(alpha=0.1)
        self.conv_0 = Conv2D(channels_out, kernel_size=3, strides=stride, padding='same', use_bias=False)
        self.bn_1 = BatchNormalization(momentum=0.999)
        self.relu_1 = LeakyReLU(alpha=0.1)
        self.conv_1 = Conv2D(channels_out, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.downsample = channels_in != channels_out
        self.shortcut = Conv2D(channels_out, kernel_size=1, strides=stride, use_bias=False)
        self.activate_before_residual = activate_before_residual
        self.dropout = Dropout(rate=droprate)
        self.droprate = droprate

    @tf.function
    def call(self, x, training=True):
        if self.downsample and self.activate_before_residual:
            # Pre-activate the input itself: both branches below see the activated x.
            x = self.relu_0(self.bn_0(x, training=training))
        elif not self.downsample:
            out = self.relu_0(self.bn_0(x, training=training))
        # Downsampling units start the residual branch from x (activated or raw, see
        # above); identity units start from the pre-activated `out`.
        out = self.relu_1(self.bn_1(self.conv_0(x if self.downsample else out), training=training))
        if self.droprate > 0.:
            # Pass `training` explicitly so dropout is off at inference (training=False).
            out = self.dropout(out, training=training)
        out = self.conv_1(out)
        return out + (self.shortcut(x) if self.downsample else x)


class ResidualBlock(tf.keras.layers.Layer):
    """
    A stack of `n_units` residual units applied in sequence.

    Only the first unit maps channels_in -> channels_out and uses `stride`; every later
    unit maps channels_out -> channels_out with stride 1.
    """

    def __init__(self, n_units, channels_in, channels_out, unit, stride, droprate=0., activate_before_residual=False):
        super().__init__()
        self.units = self._build_unit(n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual)

    def _build_unit(self, n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual):
        # Index 0 is the transition unit; the rest keep the resolution and width.
        return [
            unit(
                channels_out if idx else channels_in,
                channels_out,
                1 if idx else stride,
                droprate,
                activate_before_residual,
            )
            for idx in range(n_units)
        ]

    @tf.function
    def call(self, x, training=True):
        hidden = x
        for layer in self.units:
            hidden = layer(hidden, training=training)
        return hidden


class WideResNet(tf.keras.Model):
    """
    Wide ResNet (WRN-depth-width) classifier that returns unnormalised logits.

    Args:
        num_classes : size of the final Dense layer.
        depth       : total depth; must satisfy (depth - 4) % 6 == 0 (e.g. 28), which gives
                      (depth - 4) // 6 residual units per block.
        width       : channel multiplier of the three residual blocks.
        droprate    : dropout rate inside every residual unit.
        input_shape : kept for backward compatibility only. It used to be forwarded
                      positionally to keras.Model.__init__, where tf.keras 2.x ignores it
                      (newer Keras may reject it). As with any subclassed model, weights
                      are created on the first call.
        **kwargs    : forwarded to keras.Model (e.g. name).
    """

    def __init__(self, num_classes, depth=28, width=2, droprate=0., input_shape=(None, 32, 32, 3), **kwargs):
        super(WideResNet, self).__init__(**kwargs)
        assert (depth - 4) % 6 == 0, "depth must satisfy (depth - 4) % 6 == 0, e.g. 28"
        N = (depth - 4) // 6
        channels = [16, 16 * width, 32 * width, 64 * width]

        self.conv_0 = tf.keras.layers.Conv2D(channels[0], kernel_size=3, strides=1, padding='same', use_bias=False)
        self.block_0 = ResidualBlock(N, channels[0], channels[1], Residual3x3Unit, 1, droprate, True)
        self.block_1 = ResidualBlock(N, channels[1], channels[2], Residual3x3Unit, 2, droprate)
        self.block_2 = ResidualBlock(N, channels[2], channels[3], Residual3x3Unit, 2, droprate)
        self.bn_0 = BatchNormalization(momentum=0.999)
        self.relu_0 = LeakyReLU(alpha=0.1)
        # For 32x32 inputs the two stride-2 blocks leave an 8x8 map, so this pools it globally.
        self.avg_pool = AveragePooling2D((8, 8), (1, 1))
        self.flatten = Flatten()
        self.dense = Dense(num_classes)

    @tf.function
    def call(self, inputs, training=True):
        x = self.conv_0(inputs)
        x = self.block_0(x, training=training)
        x = self.block_1(x, training=training)
        x = self.block_2(x, training=training)
        x = self.relu_0(self.bn_0(x, training=training))
        x = self.avg_pool(x)
        x = self.flatten(x)
        return self.dense(x)

In [None]:
def guess_labels(u_aug, model, k):
    """
    MixMatch label guessing.

    Averages the model's softmax predictions (over axis 1) across the first `k`
    augmented views in `u_aug` and blocks gradients through the result.
    """
    first_view_probs = tf.nn.softmax(model(u_aug[0]), axis=1)
    other_view_probs = [tf.nn.softmax(model(u_aug[view]), axis=1) for view in range(1, k)]
    # sum(..., start) adds the views left to right, starting from the first one.
    summed_probs = sum(other_view_probs, first_view_probs)
    return tf.stop_gradient(summed_probs / k)

In [None]:
@tf.function
def sharpen(p, T):
    """
    Temperature sharpening: raises every probability to 1/T and renormalises each row
    (axis 1) to sum to one. T < 1 lowers the entropy of the distribution.
    """
    tempered = tf.pow(p, 1 / T)
    row_totals = tf.reduce_sum(tempered, axis=1, keepdims=True)
    return tempered / row_totals

In [None]:
@tf.function
def mixup(x1, x2, y1, y2, beta):
    """
    MixUp two batches with weight lam = max(beta, 1 - beta), so the result always stays
    closer to the first batch (x1, y1).

    Returns:
        (lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2)
    """
    lam = tf.maximum(beta, 1 - beta)
    mixed_inputs = lam * x1 + (1 - lam) * x2
    mixed_targets = lam * y1 + (1 - lam) * y2
    return mixed_inputs, mixed_targets

In [None]:
def mixmatch(model, X, y, U, T, K, beta):
    """
    Builds one MixMatch training batch.

    Params:
        model : network used to guess labels for the unlabelled views.
        X, y  : labelled image batch and its one-hot labels.
        U     : sequence of K augmented views of one unlabelled batch; each is expected to
                have X's batch size (the K + 1 equal split below relies on it).
        T     : sharpening temperature.
        K     : number of unlabelled views.
        beta  : Beta-distribution parameter (the paper's alpha); the MixUp weight is
                sampled once per call as lambda ~ Beta(beta, beta).
    Returns:
        (XU, XUy): XU is a list of K + 1 interleaved input chunks; XUy holds the mixed
        targets in the un-interleaved order (labelled rows first).
    """
    batch_size = X.shape[0]
    # Mean class probabilities over the K augmentations of each unlabelled image.
    mean_probs = guess_labels(U, model, K)
    # Temperature sharpening (lowers the entropy of the guessed distribution).
    qb = sharpen(mean_probs, T)
    # Every augmented view of an unlabelled image shares the same guessed label.
    qb = tf.concat([qb] * K, axis=0)
    # Stack labelled and unlabelled inputs / targets into one batch.
    U = tf.concat(list(U), axis=0)
    XU = tf.concat([X, U], axis=0)
    XUy = tf.concat([y, qb], axis=0)
    # Shuffled copy of the combined batch to mix against.
    indices = tf.random.shuffle(tf.range(XU.shape[0]))
    W = tf.gather(XU, indices)
    Wy = tf.gather(XUy, indices)
    # Sample the MixUp weight instead of using `beta` itself as a fixed weight. It is
    # passed as a tensor so the @tf.function `mixup` is not retraced for every new value.
    lam = tf.constant(np.random.beta(beta, beta), dtype=tf.float32)
    XU, XUy = mixup(XU, W, XUy, Wy, beta=lam)
    XU = tf.split(XU, K + 1, axis=0)
    XU = interleave(XU, batch_size)
    return XU, XUy
    

In [None]:
def interleave_offsets(batch, nu):
    """
    Splits `batch` rows into nu + 1 near-equal contiguous groups and returns their
    boundaries, e.g. interleave_offsets(32, 2) -> [0, 10, 21, 32].

    Leftover rows (batch % (nu + 1)) are given one each to the LAST groups.
    """
    n_groups = nu + 1
    base, remainder = divmod(batch, n_groups)
    first_bigger = n_groups - remainder
    sizes = [base + (1 if g >= first_bigger else 0) for g in range(n_groups)]

    offsets = [0]
    running = 0
    for size in sizes:
        running += size
        offsets.append(running)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    """
    Swaps slice i of chunk 0 with slice i of chunk i (for every i >= 1), so each chunk
    contains rows from every original chunk. Applying it again undoes the swap.

    Params:
        xy    : list of nu + 1 tensors, each with `batch` rows.
        batch : number of rows per tensor.
    Returns:
        list of nu + 1 tensors of the same sizes.
    """
    n_parts = len(xy)
    bounds = interleave_offsets(batch, n_parts - 1)

    pieces = []
    for tensor in xy:
        pieces.append([tensor[bounds[p]:bounds[p + 1]] for p in range(n_parts)])

    for j in range(1, n_parts):
        pieces[0][j], pieces[j][j] = pieces[j][j], pieces[0][j]

    return [tf.concat(row, axis=0) for row in pieces]

In [None]:
def ema_weight_update(model, ema_model, ema_decay):
    """
    Moves every weight of `ema_model` toward the matching weight of `model`:

        ema <- (1 - ema_decay) * model + ema_decay * ema

    Both objects must expose Keras-style get_weights()/set_weights() with weights in the
    same order. If either has no weights yet (e.g. a subclassed Keras model that has never
    been called), nothing is updated.

    Raises:
        ValueError: if both models have weights but a different number of them.
    """
    model_vars = model.get_weights()
    ema_vars = ema_model.get_weights()

    if not model_vars or not ema_vars:
        return
    if len(model_vars) != len(ema_vars):
        raise ValueError(
            f"model has {len(model_vars)} weight arrays but ema_model has {len(ema_vars)}"
        )

    ema_model.set_weights([
        (1 - ema_decay) * weight + ema_decay * ema_weight
        for weight, ema_weight in zip(model_vars, ema_vars)
    ])

In [None]:
def weight_decay(model, weight_decay):
    """
    Multiplicative weight decay: scales every TRAINABLE weight of `model` by
    (1 - weight_decay), in place.

    Non-trainable weights (e.g. BatchNorm moving mean/variance) are left alone. Decaying
    those statistics every step would drag them toward zero and corrupt inference-mode
    (training=False) outputs.
    """
    factor = 1 - weight_decay
    for var in model.trainable_weights:
        var.assign(var * factor)

In [None]:
def semi_loss(labels_x, logits_x, labels_u, logits_u):
    """
    MixMatch losses.

    Returns:
        (loss_xe, loss_l2u): categorical cross-entropy of the labelled logits against
        `labels_x`, and the squared error between `labels_u` and softmax(logits_u),
        averaged over every element (batch and classes).
    """
    cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    supervised = cross_entropy(labels_x, logits_x)

    probs_u = tf.nn.softmax(logits_u)
    unsupervised = tf.reduce_mean(tf.square(labels_u - probs_u))
    return supervised, unsupervised

In [None]:
def train_step(labelled_batch, unlabelled_batch, model, ema_model, optimizer,
                   xe_loss_tracker, l2_loss_tracker, total_loss_tracker, metric_tracker, **kwargs):
    """
    Runs one MixMatch optimisation step and updates the running trackers.

    Params:
        labelled_batch   : (images, one-hot labels) tuple.
        unlabelled_batch : tuple of K augmented views of one unlabelled image batch.
        model            : network being optimised.
        ema_model        : moving-average copy of `model`, updated after the step.
        optimizer        : applied to model.trainable_weights.
        *_tracker        : Mean-style metrics updated in place.
        **kwargs         : T, K, beta, ema_decay_rate, weight_decay_rate, lambda_u.
    Returns:
        dict with the trackers' running 'accuracy', 'xe_loss', 'l2_loss', 'total_loss'.
    """
    T = kwargs.get('T')
    K = kwargs.get('K')
    beta = kwargs.get('beta')
    ema_decay_rate = kwargs.get('ema_decay_rate')
    weight_decay_rate = kwargs.get('weight_decay_rate')
    lambda_u = kwargs.get('lambda_u')

    train_X, train_y = labelled_batch
    train_U = unlabelled_batch
    batch_size = train_X.shape[0]

    with tf.GradientTape() as tape:
        # Mixed labelled + unlabelled batch, already split into interleaved chunks.
        XU, XUy = mixmatch(model, train_X, train_y, train_U, T, K, beta)
        # One forward pass per interleaved chunk, so each pass (and its BatchNorm batch
        # statistics) covers both labelled and unlabelled rows; then undo the interleave.
        logits = [model(chunk) for chunk in XU]
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = tf.concat(logits[1:], axis=0)

        xe_loss, l2u_loss = semi_loss(XUy[: batch_size], logits_x, XUy[batch_size: ], logits_u)
        total_loss = xe_loss + lambda_u * l2u_loss

    model_params = model.trainable_weights
    grads = tape.gradient(total_loss, model_params)
    optimizer.apply_gradients(zip(grads, model_params))

    # EMA update first, then weight decay of the online model.
    ema_weight_update(model, ema_model, ema_decay_rate)
    weight_decay(model, weight_decay_rate)

    # Monitoring pass in inference mode: no dropout and no BatchNorm statistics update.
    metric_obj_func = tf.keras.metrics.CategoricalAccuracy()
    acc = metric_obj_func(train_y, model(train_X, training=False))
    xe_loss_tracker.update_state(xe_loss)
    l2_loss_tracker.update_state(l2u_loss)
    total_loss_tracker.update_state(total_loss)
    metric_tracker.update_state(acc)

    return {
        "accuracy": metric_tracker.result(),
        'xe_loss': xe_loss_tracker.result(),
        "l2_loss": l2_loss_tracker.result(),
        "total_loss": total_loss_tracker.result()
    }
    

In [None]:
def test_step(val_batch, model, metric_tracker, loss_tracker):
    """
    Evaluates `model` in inference mode on one (images, one-hot labels) batch and folds
    the batch's cross-entropy and accuracy into the running trackers.

    Returns:
        dict with the trackers' running 'accuracy' and 'loss'.
    """
    X, y = val_batch
    logits = model(X, training=False)

    loss_obj_function = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    loss_tracker.update_state(loss_obj_function(y, logits))

    acc_obj_function = tf.keras.metrics.CategoricalAccuracy()
    metric_tracker.update_state(acc_obj_function(y, logits))

    return {
        'accuracy': metric_tracker.result(),
        'loss': loss_tracker.result()
    }

In [None]:
def train(labelled_ds, unlabelled_ds, val_ds, epochs, **kwargs):
    """
    Trains a WideResNet-28-2 with MixMatch and evaluates it on `val_ds` after every epoch.

    Params:
        labelled_ds   : dataset of (images, one-hot labels) batches; it is repeated and one
                        batch is drawn per unlabelled batch.
        unlabelled_ds : dataset whose elements are K-tuples of augmented image batches.
        val_ds        : dataset of (images, one-hot labels) batches.
        epochs        : number of passes over `unlabelled_ds`.
        **kwargs      : K, T, beta, lambda_u, learning_rate, n_classes, weight_decay_rate,
                        ema_decay_rate (alias: ema_weight_decay), ckpt_dir, log_path.

    Metrics are reset at the start of every epoch, so logged values are per-epoch means.
    Checkpoints are written every second epoch and after the final epoch.
    """
    def _reset_metric(metric):
        # Newer Keras names this reset_state(); older releases only have reset_states().
        reset = getattr(metric, "reset_state", None) or metric.reset_states
        reset()

    # loss and metrics trackers
    xe_loss_tracker = tf.keras.metrics.Mean()
    l2_loss_tracker = tf.keras.metrics.Mean()
    total_loss_tracker = tf.keras.metrics.Mean()
    train_acc_tracker = tf.keras.metrics.Mean()
    val_loss_tracker = tf.keras.metrics.Mean()
    val_acc_tracker = tf.keras.metrics.Mean()
    trackers = [xe_loss_tracker, l2_loss_tracker, total_loss_tracker,
                train_acc_tracker, val_loss_tracker, val_acc_tracker]

    # arguments
    K = kwargs.get("K")
    beta = kwargs.get('beta')
    T = kwargs.get("T")
    # Also accept the `ema_weight_decay` spelling so that keyword is no longer silently dropped.
    ema_decay_rate = kwargs.get('ema_decay_rate', kwargs.get('ema_weight_decay'))
    weight_decay_rate = kwargs.get('weight_decay_rate')
    learning_rate = kwargs.get("learning_rate")
    lambda_u = kwargs.get("lambda_u")
    n_classes = kwargs.get("n_classes")
    ckpt_dir = kwargs.get('ckpt_dir')
    log_path = kwargs.get('log_path')

    # models and optimizer
    model = WideResNet(n_classes, depth=28, width=2)
    ema_model = WideResNet(n_classes, depth=28, width=2)
    # NOTE(review): subclassed Keras models own no weights until they are first called,
    # so this copy transfers nothing. ema_model is never called in this notebook, so it
    # stays unbuilt and the per-step EMA update has nothing to average. To enable the
    # EMA, build both models (e.g. call each on one batch) before this line, after
    # confirming ema_weight_update runs cleanly on built models.
    ema_model.set_weights(model.get_weights())
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # checkpoints
    model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)
    # for ema model
    ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
    ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

    # summary writers
    train_writer = tf.summary.create_file_writer(f'{log_path}/train')
    val_writer = tf.summary.create_file_writer(f'{log_path}/validation')

    model_ckpt.restore(manager.latest_checkpoint)
    ema_ckpt.restore(ema_manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # One persistent iterator over the small labelled set, cycled indefinitely,
    # instead of rebuilding a fresh iterator for every unlabelled batch.
    labelled_iter = iter(labelled_ds.repeat())

    for epoch in range(epochs):
        print(f'Epoch; {epoch}')
        for tracker in trackers:
            _reset_metric(tracker)

        for unlabelled_batch in tqdm(unlabelled_ds, total=len(unlabelled_ds)):
            model_ckpt.step.assign_add(1)
            ema_ckpt.step.assign_add(1)
            train_step(next(labelled_iter),
                       unlabelled_batch,
                       model,
                       ema_model,
                       optimizer,
                       xe_loss_tracker,
                       l2_loss_tracker,
                       total_loss_tracker,
                       train_acc_tracker,
                       K=K,
                       T=T,
                       beta=beta,
                       ema_decay_rate=ema_decay_rate,
                       weight_decay_rate=weight_decay_rate,
                       lambda_u=lambda_u
                       )

        for val_batch in val_ds:
            test_step(val_batch, model, val_acc_tracker, val_loss_tracker)

        # Read the epoch results from the trackers; they are defined even if a dataset was empty.
        xe_loss = xe_loss_tracker.result()
        l2_loss = l2_loss_tracker.result()
        total_loss = total_loss_tracker.result()
        accuracy = train_acc_tracker.result()
        val_loss = val_loss_tracker.result()
        val_accuracy = val_acc_tracker.result()

        with train_writer.as_default():
            tf.summary.scalar('xe_loss', xe_loss, step=epoch)
            tf.summary.scalar('l2u_loss', l2_loss, step=epoch)
            tf.summary.scalar('total_loss', total_loss, step=epoch)
            tf.summary.scalar('accuracy', accuracy, step=epoch)

        with val_writer.as_default():
            tf.summary.scalar('xe_loss', val_loss, step=epoch)
            tf.summary.scalar('val_accuracy', val_accuracy, step=epoch)

        # Save every second epoch and always after the last one.
        if epoch % 2 == 0 or epoch == epochs - 1:
            model_save_path = manager.save(checkpoint_number=int(model_ckpt.step))
            ema_save_path = ema_manager.save(checkpoint_number=int(ema_ckpt.step))
            print(f'Saved model checkpoint for step {int(model_ckpt.step)} @ {model_save_path}')
            print(f'Saved ema checkpoint for step {int(ema_ckpt.step)} @ {ema_save_path}')

        print(f"train_loss: {float(total_loss)}, xe_loss: {float(xe_loss)}, l2_loss: {float(l2_loss)}, train_accuracy: {float(accuracy)}")
        print(f"val_loss: {float(val_loss)}, val_accuracy: {float(val_accuracy)}")

    for writer in [train_writer, val_writer]:
        writer.flush()

In [None]:
train(labelled_ds=labelled_ds,
      unlabelled_ds=unlabelled_ds,
      val_ds=val_ds,
      epochs=20,
      learning_rate=0.01,
      n_classes=10,
      beta=0.75,
      # K must equal k_augmentation used to build unlabelled_ds.
      K=2,
      T=0.5,
      # train() reads the EMA rate under the `ema_decay_rate` key.
      ema_decay_rate=0.999,
      weight_decay_rate=0.02,
      lambda_u=0.01,
      log_path='tensorboard_logs',
      ckpt_dir="saved"
)

In [None]:
# Load the TensorBoard notebook extension and open it on the directory that train()
# writes its train/ and validation/ summaries into (log_path='tensorboard_logs').
%load_ext tensorboard
%tensorboard --logdir tensorboard_logs