In [None]:
try:
    from imutils import paths 
except:
    !pip install imutils 
    from imutils import paths 

In [1]:
from tensorflow import keras 
import tensorflow as tf 
from tensorflow.keras import models 
import os 
#from imutils import paths 
import matplotlib.pyplot as plt 
import numpy as np 

import shutil
import cv2
import random
from dataclasses import dataclass 
from tqdm import tqdm
import tempfile
from tensorflow.keras.layers import Dense, Input, Conv2D, MaxPooling2D, GlobalMaxPooling2D, \
                                                GlobalAveragePooling2D, BatchNormalization, Flatten, ReLU
import tensorflow_addons as tfa  


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
BATCH_SIZE: int = 128  # number of samples per training batch

In [None]:
class DataLoader: 
    """
        Builds the Barlow Twins / BYOL pre-training dataset (pairs of augmented views) and the
        labelled datasets for a downstream task such as classification.
        Methods:
            __download_data(scope: private)
            __normalize(scope: private)
            __preprocess_img(scope: private)
            __make_dataset(scope: private)
            __reshape_downstream_img(scope: private)
            __get_valdata(scope: private)
            get_byol_dataset(scope: public)
            get_downstream_data(scope: public)
            get_downstream_tf_dataset(scope: public)
        
        Property:
            dname(dtype: str)           : dataset name (supports cifar10, cifar100).
            byol_augmentor(type         : object exposing ``augment(image, resize_shape)``, e.g. a
                                          BarlowTwinAugmentor): creates the two augmented views.
            nval(type: int)             : Number of validation samples, carved out of the testing data.
            resize_shape(dtype: int)    : Side length images are resized to (pretrained models might
                                          need a different shape).
            normalize(dtype: bool)      : whether to divide the pixel values by 255.
            downstream_data(dtype: bool): allows byol_augmentor=None when only downstream data is needed.
    """

    # CIFAR-10 label names, indexed by integer class id.
    CIFAR10_CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self, dname="cifar10", byol_augmentor=None, nval=5000,
                                             resize_shape=32, normalize=True, downstream_data=False): 
        assert (byol_augmentor is not None or downstream_data), 'Need a BYOL Augment object'
        assert dname in ["cifar10", 'cifar100'], "dname should be either cifar10 or cifar100"
        assert nval <= 10_000, "ValueError: nval value should be <= 10_000"
        
        train_data, test_data = self.__download_data(dname)
        self.__train_X, self.__train_y = train_data
        self.__dtest_X, self.__dtest_y = test_data
        # NOTE(review): these are the CIFAR-10 names; they do not describe cifar100 labels.
        self.class_name = list(self.CIFAR10_CLASSES)
        self.byol_augmentor = byol_augmentor
        self.resize_shape = resize_shape

        # Normalize BEFORE splitting so the test/val splits are scaled exactly like the train split.
        if normalize:
            self.__normalize()
        self.__get_valdata(nval)

        self.min_obj_cov_value = 0.7
        self.color_jitter_value = 0.1
        
    def __len__(self): 
        """Total number of samples: training set plus the full (pre-split) testing set."""
        return self.__train_X.shape[0] + self.__dtest_X.shape[0]
    
    def __repr__(self): 
        return f"Training Samples: {self.__train_X.shape[0]}, Testing Samples: {self.__dtest_X.shape[0]}"
    
    def __download_data(self, dname):
        """
            Downloads the data using the tf.keras.datasets.<name>.load_data() method.
            Params:
                dname(type: Str): dataset name, either cifar10 or cifar100.
            Return(type: (tuple, tuple))
                returns the (X, y) training data and the (X, y) testing data.
            Raises:
                ValueError: if dname is not a supported dataset.
        """
        loaders = {
            "cifar10": tf.keras.datasets.cifar10.load_data,
            "cifar100": tf.keras.datasets.cifar100.load_data,
        }
        if dname not in loaders:
            raise ValueError(f"Unsupported dataset name: {dname!r}")
        train_data, test_data = loaders[dname]()
        return train_data, test_data
    
    def __normalize(self): 
        """
            Scales the pixel values of the training and (pre-split) testing images by 1/255.
        """
        self.__train_X = self.__train_X / 255.0
        self.__dtest_X = self.__dtest_X / 255.0
    
    def __preprocess_img(self, image): 
        """
            Mapped over the tf.data pipeline by get_byol_dataset: converts one image to float32,
            resizes it and produces two independently augmented views of it.
            Params:
                image(type: tf.Tensor): a single image.
            Returns(type: (tf.Tensor, tf.Tensor))
                the two augmented views of the same image.
        """
        # convert_image_dtype rescales integer images to [0, 1]; float inputs are only cast.
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, (self.resize_shape, self.resize_shape))
        view1 = self.byol_augmentor.augment(image, self.resize_shape)
        view2 = self.byol_augmentor.augment(image, self.resize_shape)
        return (view1, view2)

    def __make_dataset(self, data, map_fn, batch_size):
        """
            Wraps `data` in a tf.data pipeline: map `map_fn`, shuffle (buffer 1024), batch
            (dropping the final partial batch) and prefetch.
        """
        return (
            tf.data.Dataset.from_tensor_slices(data)
            .map(map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .shuffle(1024)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )
    
    def get_byol_dataset(self, batch_size, dataset_type="train"):
        """
            Builds the pre-training dataset of augmented view pairs.
            Params:
                batch_size(dtype: int)    : Batch Size.
                dataset_type(dtype: str)  : which split is needed (train, test or val).
                
            return(type: tf.data.Dataset)
                a dataset yielding (view1, view2) batches for the requested split.
        """
        splits = {"train": self.__train_X, "test": self.__test_X, "val": self.__val_X}
        assert dataset_type in splits, "Given dataset type is not valid"
        return self.__make_dataset(splits[dataset_type], self.__preprocess_img, batch_size)
    
    def get_downstream_data(self): 
        """
            Returns the (train_X, train_y) numpy arrays for the downstream task.
        """
        return (self.__train_X, self.__train_y)
    
    def __get_valdata(self, nval):
        """
            Creates the validation split by randomly sampling (without replacement) `nval` rows of
            the testing data; the remaining rows, in their original order, form the test split.
            Params:
                nval(dtype: Int): Number of validation samples; test_X.shape[0] - nval samples
                                  remain as testing data.
            Sets:
                self.__test_X, self.__test_y, self.__val_X, self.__val_y
        """
        ind_arr = np.arange(self.__dtest_X.shape[0])
        val_inds = np.random.choice(ind_arr, nval, replace=False)
        # setdiff1d returns the sorted complement, matching the original index order.
        test_inds = np.setdiff1d(ind_arr, val_inds)

        self.__test_X, self.__test_y = self.__dtest_X[test_inds], self.__dtest_y[test_inds]
        self.__val_X, self.__val_y = self.__dtest_X[val_inds], self.__dtest_y[val_inds]
            
    def __reshape_downstream_img(self, img, y):
        """
            Resizes one image for the downstream dataset (used by get_downstream_tf_dataset).
            Params:
                img(type: tf.Tensor): Image Tensor.
                y: Corresponding label of the image.
            Return(type: tf.Tensor, label)
                the resized image and its unchanged label.
        """
        img = tf.image.resize(img, (self.resize_shape, self.resize_shape))
        return img, y
        
    def get_downstream_tf_dataset(self, batch_size, dataset_type="train"): 
        """
            Builds the labelled downstream dataset.
            Params:
                batch_size(dtype: int)    : Batch Size.
                dataset_type(dtype: str)  : which split is needed (train, test or val).
                
            return(type: tf.data.Dataset)
                a dataset yielding (image, label) batches for the requested split.
        """
        assert dataset_type in ["train", "test", "val"], "Given dataset type is not valid"
        splits = {
            "train": (self.__train_X, self.__train_y),
            "test": (self.__test_X, self.__test_y),
            "val": (self.__val_X, self.__val_y),
        }
        return self.__make_dataset(splits[dataset_type], self.__reshape_downstream_img, batch_size)

In [None]:
class BarlowTwinAugmentor: 
    """
        Data augmentation pipeline for the Barlow Twins / BYOL pre-training.
        The pixel-value augmentations (solarization threshold, clipping) assume float images with
        values in [0, 1] and 3 channels; DataLoader feeds float32 images in that range.
        Methods: 
            __random_crop_resize(scope: private)
            __random_flip(scope: private)
            __random_color_distortion(scope: private)
            __random_gaussian_blur(scope: private)
            __random_grayscale(scope: private)
            __random_solarization(scope: private)
            augment(scope: public)
    """
    def __init__(self): 
        pass
    
    @tf.function
    def __random_crop_resize(self, image, resize_shape):
        """
            Takes a random square crop whose side is sampled in [0.75 * resize_shape, resize_shape)
            and resizes it back to (resize_shape, resize_shape).
            Params:
                image(type: tf.Tensor)   : image data of type tensor (H, W, 3).
                resize_shape(type: int)  : Size of the image.
            Return(type: tf.Tensor)
                the cropped and resized image.
        """
        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * resize_shape),
            maxval=1 * resize_shape,
            dtype=tf.int32)

        crop = tf.image.random_crop(image, (rand_size, rand_size, 3))
        return tf.image.resize(crop, (resize_shape, resize_shape))
    
    @tf.function
    def __random_flip(self, image):
        """
            With probability 0.8, applies tf.image.random_flip_left_right (which itself flips
            with probability 0.5), i.e. an overall horizontal-flip probability of 0.4.
            Params:
                image(type: tf.Tensor)   : image data of type tensor.
            Return(type: tf.Tensor)
                the possibly flipped image.
        """
        random_val = tf.random.uniform(shape=[])
        if random_val < 0.8:
            image = tf.image.random_flip_left_right(image)
        return image
    
    @tf.function
    def __random_color_distortion(self, image):
        """
            With probability 0.8, applies random brightness, contrast, saturation and hue jitter,
            then clips the result to [0, 1].
            Params:
                image(type: tf.Tensor)   : image data of type tensor.
            Return(type: tf.Tensor)
                the possibly color-distorted image.
        """
        color_jitter = tf.random.uniform(shape=[])
        if color_jitter < 0.8:
            image = tf.image.random_brightness(image, max_delta=0.8)
            image = tf.image.random_contrast(image, lower=0.4, upper=1.6)
            # Saturation factor range mirrors the contrast range (0.4-1.6).
            image = tf.image.random_saturation(image, lower=0.4, upper=1.6)
            image = tf.image.random_hue(image, max_delta=0.2)
            image = tf.clip_by_value(image, 0, 1)
        return image
    
    @tf.function
    def __random_grayscale(self, image): 
        """
            With probability 0.2, converts the image to grayscale (replicated to 3 channels so the
            shape is unchanged).
            Params:
                image(type: tf.Tensor)   : image data of type tensor.
            Return(type: tf.Tensor)
                the possibly grayscale image.
        """
        color_drop = tf.random.uniform(shape=[])
        if color_drop < 0.2:
            image = tf.image.rgb_to_grayscale(image)
            image = tf.tile(image, [1, 1, 3])
        return image
    
    @tf.function
    def __random_solarization(self, image): 
        """
            With probability 0.2, solarizes the image: pixels >= 0.5 are inverted (1 - x).
            Params:
                image(type: tf.Tensor)   : image data of type tensor, values in [0, 1].
            Return(type: tf.Tensor)
                the possibly solarized image.
        """
        random_val = tf.random.uniform(shape=[])
        if random_val < 0.2: 
            # Threshold/inversion are expressed for [0, 1] images (not 0-255).
            image = tf.where(image < 0.5, image, 1.0 - image)
        return image
    
    @tf.function
    def __random_gaussian_blur(self, image): 
        """
            With probability 0.2, applies a Gaussian blur (tfa.image.gaussian_filter2d).
            Params:
                image(type: tf.Tensor)   : image data of type tensor.
            Return(type: tf.Tensor)
                the possibly blurred image.
        """
        random_val = tf.random.uniform(shape=[])
        if random_val < 0.2:
            # NOTE(review): np.random.random() runs at trace time, so sigma is fixed per traced
            # graph rather than re-sampled per image — consider a tensor-valued sigma if supported.
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=image, sigma=s)
        return image
        
    def augment(self, image, resize_shape): 
        """
            Full augmentation pipeline used by DataLoader: random crop + resize, random flip,
            color distortion, Gaussian blur, grayscale and solarization.
            Params:
                image(type: tf.Tensor)   : image data of type tensor.
                resize_shape(type: int)  : Size of the image.
            Return(type: tf.Tensor)
                the augmented image of shape (resize_shape, resize_shape, 3).
        """
        image = self.__random_crop_resize(image, resize_shape)
        image = self.__random_flip(image)
        image = self.__random_color_distortion(image)
        image = self.__random_gaussian_blur(image)
        image = self.__random_grayscale(image)
        image = self.__random_solarization(image)
        return image

In [None]:
bt_augmentor: BarlowTwinAugmentor = BarlowTwinAugmentor()  # shared augmentation pipeline

In [None]:
# Build the CIFAR-10 loader and the pre-training dataset of augmented view pairs.
bt_dataloader = DataLoader(dname="cifar10", byol_augmentor=bt_augmentor)
train_ds = bt_dataloader.get_byol_dataset(batch_size=BATCH_SIZE, dataset_type="train")
train_ds

In [None]:
def visualize(train_ds):
    """
        Shows the first pair of augmented views of the first batch, plus their pixel histograms.
        Params:
            train_ds(type: tf.data.Dataset): dataset yielding (view1_batch, view2_batch).
        Raises:
            ValueError: if train_ds yields no batches.
    """
    batch = next(iter(train_ds.take(1)), None)
    if batch is None:
        raise ValueError("train_ds is empty; nothing to visualize")

    view1 = batch[0][0].numpy().astype('float32')
    view2 = batch[1][0].numpy().astype('float32')

    fig = plt.figure(figsize=(7, 7))

    # Top row: the two views rendered as images.
    for position, (view, title) in enumerate(((view1, "View 1"), (view2, "View 2")), start=1):
        ax = fig.add_subplot(2, 2, position)
        ax.grid(False)
        ax.imshow(view, interpolation='none', vmin=0, vmax=1)
        ax.set_title(title)

    # Bottom row: pixel-value histograms on shared axes so they are directly comparable.
    ax_hist1 = fig.add_subplot(2, 2, 3)
    ax_hist1.hist(view1.ravel())
    ax_hist1.set_title("View 1 pixel values")
    ax_hist2 = fig.add_subplot(2, 2, 4, sharey=ax_hist1, sharex=ax_hist1)
    ax_hist2.hist(view2.ravel())
    ax_hist2.set_title("View 2 pixel values")
    plt.show()

In [None]:
visualize(train_ds=train_ds)

In [None]:
class ResNet18:
    """ResNet18 class.

        Builds a ResNet-18 layout ([2, 2, 2, 2] basic residual blocks) as a Keras functional
        model. The model ends with average pooling + Flatten, i.e. it outputs features and has
        no classification head.
    Modified from
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
        View their website for more information.
    """

    def identity_block(self, x, filter):
        """Basic residual block whose shortcut is the unchanged input.

        Params:
            x: input tensor; its channel count must equal `filter` for the Add to be valid.
            filter(dtype: int): number of conv filters (name shadows the builtin; kept for API
                compatibility).
        Return:
            tensor with the same spatial size and `filter` channels.
        """
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filter):
        """Downsampling residual block (stride 2) with a 1x1 strided conv on the shortcut.

        Params:
            x: input tensor.
            filter(dtype: int): number of output filters.
        Return:
            tensor with roughly halved spatial size and `filter` channels.
        """
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Project the shortcut so its shape matches the main branch.
        x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
        # Add Residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        """Builds and returns the encoder model for inputs of the given `shape`."""
        # Step 1 (Setup Input Layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (Initial Conv layer along with maxPool)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define size of sub-blocks and initial filter size
        block_layers = [2, 2, 2, 2]
        filter_size = 64
        # Step 3 Add the Resnet Blocks
        for i in range(4):
            if i == 0:
                # For sub-block 1 Residual/Convolutional block not needed
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One Residual/Convolutional Block followed by Identity blocks
                # The filter size will go on increasing by a factor of 2
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4 Pool and flatten into a feature vector
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet18")
        return model

In [None]:
class BarlowLoss(keras.losses.Loss):
    """
        Barlow Twins loss. It builds the cross-correlation matrix of the two (batch-standardized)
            embeddings and combines an invariance term with a redundancy-reduction term.
        methods:
            __get_off_diag(scope: private)
            __cross_corr_matrix_loss(scope: private)
            __normalize(scope: private)
            __cross_corr_matrix(scope: private)
        
        Property:
            lambda_amt(dtype: float) : trade-off weight between the invariance and the
                redundancy-reduction terms (default 5e-3).
            batch_size(dtype: int)   : configured batch size; kept for reference. The
                cross-correlation is divided by the actual batch dimension of the inputs.
    """

    def __init__(self, batch_size: int, lambda_amt: float = 5e-3):
        super(BarlowLoss, self).__init__()
        self.lambda_amt = lambda_amt
        self.batch_size = batch_size

    def __get_off_diag(self, c):
        """
            Returns `c` with its diagonal replaced by zeros.
            Params:
                c(type: tf.Tensor): cross-correlation matrix (D x D).

            Returns(type: tf.Tensor):
                the cross-correlation matrix with zeros on the diagonal.
        """
        return tf.linalg.set_diag(c, tf.zeros_like(tf.linalg.diag_part(c)))

    def __cross_corr_matrix_loss(self, c: tf.Tensor):
        """
            Loss on the cross-correlation matrix: the diagonal is pushed towards 1 and the
            off-diagonal entries towards 0.
            Params:
                c(type: tf.Tensor): cross-correlation matrix (D x D).

            Returns(type: tf.Tensor):
                scalar sum((1 - c_ii)^2) + lambda_amt * sum_{i != j} c_ij^2.
        """
        # Invariance term: squared distance of the diagonal from 1.
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)

        # Redundancy-reduction term: squared off-diagonal entries, weighted by lambda.
        off_diag = tf.pow(self.__get_off_diag(c), 2) * self.lambda_amt

        return tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)

    def __normalize(self, output):
        """
        Standardizes each embedding dimension across the batch (zero mean, unit std).
        NOTE(review): a dimension with zero variance produces NaN/inf here.
        Params:
            output(dtype: tf.tensor): the model prediction (batch, D).

        Returns(dtype: tf.Tensor):
            the standardized prediction.
        """
        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )

    def __cross_corr_matrix(self, z_a_norm, z_b_norm):
        """cross_corr_matrix method.

        Computes z_a_norm^T @ z_b_norm, a (D x D) matrix, divided by the number of rows
        (the actual batch size of the inputs).

        Arguments:
            z_a_norm: A normalized version of the first prediction.
            z_b_norm: A normalized version of the second prediction.

        Returns:
            Returns a cross correlation matrix.
        """
        n_samples = tf.cast(tf.shape(z_a_norm)[0], z_a_norm.dtype)
        return tf.matmul(z_a_norm, z_b_norm, transpose_a=True) / n_samples
    
    @tf.autograph.experimental.do_not_convert
    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) :
        """call method.

        Standardizes both embeddings, builds their cross-correlation matrix and returns
        the Barlow Twins loss on it.

        Arguments:
            z_a: The prediction of the first set of augmented data.
            z_b: the prediction of the second set of augmented data.

        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """
        z_a_norm, z_b_norm = self.__normalize(z_a), self.__normalize(z_b)
        c = self.__cross_corr_matrix(z_a_norm, z_b_norm)
        return self.__cross_corr_matrix_loss(c)

In [None]:
def get_projector(input_dim, hidden_dims1, 
                        hidden_dims2, hidden_dims3):
    """
        Builds the projection network (g): Dense-BN-ReLU -> Dense-BN-ReLU -> Dense.
        Params:
            input_dim (dtype: int)         : Input vector dimensionality.
            hidden_dims1 (dtype: int)      : units of the first hidden layer.
            hidden_dims2 (dtype: int)      : units of the second hidden layer.
            hidden_dims3 (dtype: int)      : units of the output (embedding) layer.

        Return(type: keras.models.Model):
            The keras model of the projection network (g).
    """
    # Layer names must be valid TF scope names (no spaces).
    _input = Input(input_dim, name='projection_input')
  
    x = Dense(hidden_dims1, name="dense1")(_input)
    x = BatchNormalization(name='bn1')(x)
    x = ReLU(name='relu1')(x)
    # Chain from the first block (previously fed `_input`, which disconnected dense1/bn1/relu1).
    x = Dense(hidden_dims2, name="dense2")(x)
    x = BatchNormalization(name='bn2')(x)
    x = ReLU(name='relu2')(x)
    _output = Dense(hidden_dims3, name='dense3')(x)

    return keras.models.Model(_input, _output, name='Projection')

In [None]:
projection = get_projector(input_dim=2048, hidden_dims1=8192,
                           hidden_dims2=8192, hidden_dims3=8192)
projection.summary()

In [None]:
def get_encoder_projector(input_shape, hidden_dims1, 
                                hidden_dims2, hidden_dims3):
    """
        Builds the ResNet18 encoder (f) followed by the projection network (g).
        Params:
            input_shape (dtype: tuple)     : Input image dimension, e.g. (32, 32, 3).
            hidden_dims1 (dtype: int)      : units of the projector's first hidden layer.
            hidden_dims2 (dtype: int)      : units of the projector's second hidden layer.
            hidden_dims3 (dtype: int)      : units of the projector's output (embedding) layer.

        Return(type: keras.models.Model):
            A keras model mapping images to projection embeddings (named "Barlow-Twin").
    """
    # Build the encoder for the requested image shape (previously the argument was ignored).
    encoder = ResNet18()(input_shape)
    last_layer = encoder.layers[-1].output
    
    projection_input_dims = last_layer.shape[-1]
    projection = get_projector(projection_input_dims, hidden_dims1, hidden_dims2, hidden_dims3)
    embeddings = projection(last_layer)
    
    return keras.models.Model(encoder.input, embeddings, name="Barlow-Twin")

In [None]:
enc_proj_model = get_encoder_projector(input_shape=(32, 32, 3), hidden_dims1=2048,
                                       hidden_dims2=2048, hidden_dims3=2048)
enc_proj_model.summary()

In [None]:
class BarlowModel(keras.Model):
    """BarlowModel class.

    Wraps the encoder + projector and implements a custom training step: both augmented
    views go through the same network, and the compiled loss compares their embeddings.

    Attributes:
        model: the encoder + projection network (see get_encoder_projector).
        loss_tracker: running mean of the training loss.

    Methods:
        call: forward pass through `model`.
        train_step: one train step; do model predictions, loss, and optimizer step.
        metrics: Returns metrics.
    """

    def __init__(self, image_shape=(32, 32, 3), hidden_dims1=2048,
                 hidden_dims2=2048, hidden_dims3=2048):
        super(BarlowModel, self).__init__()
        self.model = get_encoder_projector(image_shape, hidden_dims1, hidden_dims2, hidden_dims3)
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch.
        return [self.loss_tracker]

    def call(self, inputs, training=False):
        """Returns the projection embeddings of `inputs` (enables predict / direct calls)."""
        return self.model(inputs, training=training)

    def train_step(self, inputs) -> dict:
        """train_step method.

        Embeds both views, computes the loss passed to compile(), and applies the
        optimizer update to the wrapped model's trainable variables.

        Arguments:
            inputs: a (view1, view2) pair of image batches.

        Returns:
            Returns a dictionary with the loss metric.
        """
        view1, view2 = inputs
        with tf.GradientTape() as tape:
            z_a = self.model(view1, training=True)
            z_b = self.model(view2, training=True)
            loss = self.loss(z_a, z_b)
        
        params = self.model.trainable_variables
        grads = tape.gradient(loss, params)

        self.optimizer.apply_gradients(zip(grads, params))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

In [11]:
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Polynomial warm-up followed by a wrapped decay schedule.

    While ``step < warmup_steps`` the learning rate is
    ``initial_learning_rate * (step / warmup_steps) ** power``; afterwards it is
    ``decay_schedule_fn(step - warmup_steps)``.
    """

    def __init__(self,
                 initial_learning_rate: float,
                 decay_schedule_fn,
                 warmup_steps: int,
                 power: float = 1.0,
                 name: str = None,):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_schedule_fn = decay_schedule_fn
        self.warmup_steps = warmup_steps
        self.power = power
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as scope_name:
            step_f = tf.cast(step, tf.float32)
            warmup_f = tf.cast(self.warmup_steps, tf.float32)
            fraction_done = step_f / warmup_f
            warmup_lr = self.initial_learning_rate * tf.math.pow(fraction_done, self.power)

            def _warmup_branch():
                return warmup_lr

            def _decay_branch():
                # The decay schedule restarts its own step count after the warm-up.
                return self.decay_schedule_fn(step - self.warmup_steps)

            return tf.cond(step_f < warmup_f, _warmup_branch, _decay_branch, name=scope_name)

    def get_config(self):
        config = dict(
            initial_learning_rate=self.initial_learning_rate,
            decay_schedule_fn=self.decay_schedule_fn,
            warmup_steps=self.warmup_steps,
            power=self.power,
            name=self.name,
        )
        return config

In [12]:
# Schedule lengths are specified in epochs and converted to optimizer steps.
steps_per_epoch = len(train_ds)
# NOTE(review): the cosine decay spans 1000 epochs, but the fit cell trains for 100 epochs,
# so the learning rate barely decays during that run — confirm this is intended.
decay_steps = steps_per_epoch * 1000
warmup_steps = steps_per_epoch * 10
initial_lr = 5e-4

lr_decayed_fn = tf.keras.experimental.CosineDecay(
    initial_learning_rate=initial_lr,
    decay_steps=decay_steps,
)

cosine_with_warmUp = WarmUp(
    initial_learning_rate=initial_lr,
    decay_schedule_fn=lr_decayed_fn,
    warmup_steps=warmup_steps,
)

optimizer = keras.optimizers.Adam(cosine_with_warmUp)

In [None]:
# Wire the custom model to the Barlow Twins loss and the warm-up/cosine optimizer.
bt_model = BarlowModel()
bt_loss = BarlowLoss(batch_size=BATCH_SIZE)
bt_model.compile(loss=bt_loss, optimizer=optimizer)

In [None]:
bt_model.fit(x=train_ds, epochs=100)

In [None]:
 bt_model.model.save(filepath="barlow-twin")