In [1]:
import os 
import cv2
import time 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as ans 
from tqdm import tqdm 
import shutil 
import tensorflow as tf 
from tensorflow import keras 
from tensorflow.keras import Model
from tensorflow.keras import layers 
from tensorflow.keras.layers import *
import tensorflow_datasets as tfds
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import *
from datetime import datetime
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.layers import Dense, Input, UpSampling2D, Conv2DTranspose, Conv2D, add, Add,\
                    Lambda, Concatenate, AveragePooling2D, BatchNormalization, GlobalAveragePooling2D, \
                    Add, LayerNormalization, Activation, LeakyReLU, SeparableConv2D
try:
    import tensorflow_addons as tfa 
except:
    !pip install tensorflow_addons
    import tensorflow_addons as tfa
    from tensorflow_addons.layers import InstanceNormalization
try:
    import gdown 
except:
    !pip install gdown --quiet
    import gdown
    
try: 
    from imutils import paths 
except:
    !pip install imutils  --quiet
    from imutils import paths
from collections import defaultdict
from tqdm import tqdm
import PIL
import glob
from PIL import Image
import cv2
import collections 
from collections import *

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
class DataLoader: 
    """
        Class, will be useful for creating the BYOL dataset or dataset for the DownStream task 
            like classification or segmentation.
        Methods:
            __download_data(scope: private)
            __normalize(scope: private)
            __preprocess_img(scope: private)
             __get_val_data(scope: private)
            get_dataset(scope: public)
        
        Property:
            dname(dtype: str)        : dataset name(supports cifar10, cifar100).
            augmentor(type           : byol augmentor instance/object.
            nval(type: int)          : Number of validation data needed, this will be created by splitting the testing
                                       data.
            resize_shape(dtype: int) : Resize shape, bcoz pretrained models, might have a different required shape. If Resize shape is 0
                                        no resize, will happen.
            normalize(dtype: bool)   : bool value, whether to normalize the data or not. 
    """
    
    def __init__(self, dname="cifar10", augmentor=None, nval=5000, resize_shape=0, normalize=True): 
        assert dname in ["cifar10", 'cifar100'], "dname should be either cifar10 or cifar100"
        assert nval <= 10_000, "ValueError: nval value should be <= 10_000"
        
        __train_data, __test_data = self.__download_data(dname)
        self.__train_X, self.__train_y = __train_data
        self.__train_X, self.__train_y = self.__train_X, self.__train_y
        self.__dtest_X, self.__dtest_y = __test_data 
        self.__dname = dname
        
        self.augmentor = augmentor
        self.__get_val_data(nval)
        self.resize_shape = resize_shape
        
        
        self.__normalize() if normalize else None
        
    def __len__(self): 
        return self.__train_X.shape[0] + self.__dtest_X.shape[0]
    
    def __repr__(self): 
        return f"Training Samples: {self.__train_X.shape[0]}, Testing Samples: {self.__dtest_X.shape[0]}"
    
    def __download_data(self, dname):
        """
            Downloads the data from the tensorflow website using the tensorflw.keras.load_data() method.
            Params:
                dname(type: Str): dataset name, it just supports two dataset cifar10 or cifar100
            Return(type(np.ndarray, np.ndarray))
                returns the training data and testing data
        """
        if dname == "cifar10": 
            train_data, test_data = tf.keras.datasets.cifar10.load_data()
        if dname == "cifar100": 
            train_data, test_data = tf.keras.datasets.cifar100.load_data()
            
        return train_data, test_data
    
    def __normalize(self): 
        """
            this method, will used to normalize the inputs.
        """
        self.__train_X = self.__train_X / 255.0
        self.__dtest_X = self.__dtest_X / 255.0
    
    def __preprocess_img(self, image, label, augment=False): 
        """
            this method, will be used by the get_byol_dataset methos, which does a convertion of 
            numpy data to tensorflow data.
            Params:
                image(type: np.ndarray): image data.
            Returns(type; (np.ndarray, np.ndarray))
                returns the two different augmented views of same image.
        """
        try: 
            image = tf.cast(image, tf.float32)
            if self.__dname == 'cifar10':
                print(label.shape)
                label = tf.one_hot(label, 10)
                label = tf.reshape(label, (label.shape[0], 10))
             #   label = tf.reshape(label, 10)
            else:
                label = tf.one_hot(label, 100)
                label = tf.reshape(label, (label.shape[0], 100))
               # label = tf.reshape(label, 100)
            
            if self.resize_shape > 0:
                image = tf.image.resize(image, (self.resize_shape, self.resize_shape))
         
            if self.augmentor and augment:
                image = self.paugmentor.augment(image)
           
            return image, label
        
        except Exception as err:
            return err
    
    def get_dataset(self, batch_size, dataset_type="train"):
        """
            this method, will gives the byol dataset, which is nothing but a tf.data.Dataset object.
            Params:
                batch_size(dtype: int)    : Batch Size.
                dataset_type(dtype: str)  : which type of dataset needed, (train, test or val)
                
            return(type: tf.data.Dataset)
                returns the tf.data.Dataset for intended dataset_type, by preprocessing and converting 
                the np data.
        """
        try:
            if dataset_type == "train":
                tensorflow_data = tf.data.Dataset.from_tensor_slices((self.__train_X, self.__train_y))
                tensorflow_data = (
                tensorflow_data
                    .shuffle(1024)
                    .batch(batch_size, drop_remainder=True)
                    .map(lambda x, label: (self.__preprocess_img(x, label,  augment=True)),
                                                 num_parallel_calls=tf.data.AUTOTUNE).cache()
                    
                    .prefetch(tf.data.experimental.AUTOTUNE)
                )
                return tensorflow_data  
            
            if dataset_type == "test":
                tensorflow_data = tf.data.Dataset.from_tensor_slices((self.__test_X, self.__test_y))
                tensorflow_data = (
                tensorflow_data
                    .shuffle(1024)
                    .map(lambda x, label: (self.__preprocess_img(x, label, augment=False)),
                                                 num_parallel_calls=tf.data.AUTOTUNE).cache()
                    .batch(batch_size, drop_remainder=True)
                    .prefetch(tf.data.experimental.AUTOTUNE)
                )
                return tensorflow_data  
            
            if dataset_type == "val":
                tensorflow_data = tf.data.Dataset.from_tensor_slices((self.__val_X, self.__val_y))
                tensorflow_data = (
                tensorflow_data
                     .shuffle(1024)
                    .map(lambda x, label: (self.__preprocess_img(x, label, augment=False)),
                                                 num_parallel_calls=tf.data.AUTOTUNE).cache()
                    .batch(batch_size, drop_remainder=True)
                    .prefetch(tf.data.experimental.AUTOTUNE)
                )
                return tensorflow_data  
        
        except Exception as err:
            return err
    
   
    def __get_val_data(self, nval):
        """
            this method is used to create a validation data by randomly sampling from the testing data.
            Params:
                nval(dtype: Int); Number of validation data needed, rest of test_X.shape[0] - nval, will be 
                                  testing data size.
            returns(type; np.ndarray, np.ndarray):
                returns the testing and validation dataset.
        """
        try: 
            ind_arr = np.arange(10_000)
            val_inds = np.random.choice(ind_arr, nval, replace=False)
            test_inds = [i for i in ind_arr if not i in val_inds]

            self.__test_X, self.__test_y = self.__dtest_X[test_inds], self.__dtest_y[test_inds]
            self.__val_X, self.__val_y = self.__dtest_X[val_inds], self.__dtest_y[val_inds]
            
        except Exception as err:
            raise err    


In [3]:
data_loader = DataLoader(dname="cifar100", augmentor=None, nval=5000, resize_shape=224, normalize=True)

In [4]:
train_ds = data_loader.get_dataset(8, 'train')
train_ds

Cause: mangled names are not yet supported


<_PrefetchDataset element_spec=(TensorSpec(shape=(8, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(8, 100), dtype=tf.float32, name=None))>

In [5]:
val_ds = data_loader.get_dataset(8, 'val')
val_ds

Cause: mangled names are not yet supported


<_PrefetchDataset element_spec=(TensorSpec(shape=(8, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(8, 1, 100), dtype=tf.float32, name=None))>

In [6]:
test_ds = data_loader.get_dataset(8, 'test')
test_ds

Cause: mangled names are not yet supported


<_PrefetchDataset element_spec=(TensorSpec(shape=(8, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(8, 1, 100), dtype=tf.float32, name=None))>

In [None]:
class MLP(keras.Model):
    """
        this class is a implementation of the mlp block described in the swin transformer paper, which contains 
        2 fully connected layer with GelU activation.
    """
    def __init__(self, input_dims=None, hidden_neurons=None,
                 output_neurons=None, act_type="gelu", dropout_rate=0., prefix=''):
        """
            Params:
                input_neurons(dtype: int)   : input dimension for the mlp block, it needed only for .summary() method.
                hidden_neurons(dtype: int)  : number of neurons in the hidden
                                              layer(fully connected layer).
                output_neurons(dtype: iny)  ; number of neurons in the last
                                              layer(fully connected layer) of mlp.
                act_type(type: str)         ; type of activation needed. in paper, GeLU is used.
                dropout_rate(dtype: float)  : dropout rate in the dropout layer.
                prefix(type: str)           : used for the naming the layers.
        """
        super(MLP, self).__init__()
        
        self.input_dims = input_dims
        self.fc_1 = Dense(units=hidden_neurons, name=f"{prefix}/mlp/dense_1")
        if act_type == "relu":
            self.act = tf.keras.layers.ReLU()
        else:
            self.act = None
        self.fc_2 = Dense(units=hidden_neurons, name=f"{prefix}/mlp/dense_2")
        self.drop = Dropout(dropout_rate)

    def call(self, x):
        x = self.fc_1(x)
        if not self.act:
            x = tf.keras.activations.gelu(x)
        else: 
            x = self.act(x)
        x = self.drop(x)
        x = self.fc_2(x)
        x = self.drop(x)
        return x
    
    def summary(self): 
        inputs = Input(shape=self.input_dims)
        return keras.Model(inputs=inputs, outputs=self.call(inputs))

In [7]:
class MLP(tf.keras.layers.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Dense(hidden_features, name=f'mlp/fc1')
        self.fc2 = Dense(out_features, name=f'mlp/fc2')
        self.drop = Dropout(drop)

    def call(self, x):
        x = self.fc1(x)
        x = tf.keras.activations.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


In [8]:
def window_partition(x: tf.Tensor, window_size: int):
    """
        this function is used to create a local window, with the windo_size.
        Params:
            x: (B, H, W, C)
            window_size (int): window size
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
    """
    # batch height width and channles
    B, H, W, C = tf.shape(x).numpy()
    
    x = tf.reshape(
        x, (B, H // window_size, window_size, W // window_size, window_size, C)
    )
    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
    windows = tf.reshape(x, shape=[-1, window_size, window_size, C])
    return windows

In [9]:
def window_reverse(windows: tf.Tensor, window_size: int, H: int, W: int, C: int):
    """
    Params:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = tf.shape(windows)[0] // tf.cast(H * W / window_size / window_size, dtype="int32")
    x = tf.reshape(windows, (B, H // window_size, W // window_size,
                                             window_size, window_size, -1))
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    return tf.reshape(x, (B, H, W, -1))

In [48]:
def get_relative_position_index(win_h, win_w):
    # get pair-wise relative position index for each token inside the window
    xx, yy = tf.meshgrid(range(win_h), range(win_w))
    coords = tf.stack([yy, xx], axis=0)  # [2, Wh, Ww]
    coords_flatten = tf.reshape(coords, [2, -1])  # [2, Wh*Ww]

    relative_coords = (
        coords_flatten[:, :, None] - coords_flatten[:, None, :]
    )  # [2, Wh*Ww, Wh*Ww]
    relative_coords = tf.transpose(
        relative_coords, perm=[1, 2, 0]
    )  # [Wh*Ww, Wh*Ww, 2]

    xx = (relative_coords[:, :, 0] + win_h - 1) * (2 * win_w - 1)
    yy = relative_coords[:, :, 1] + win_w - 1
    relative_coords = tf.stack([xx, yy], axis=-1)

    return tf.reduce_sum(relative_coords, axis=-1)  # [Wh*Ww, Wh*Ww]


class WindowAttention(layers.Layer):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        head_dim (int): Number of channels per head (dim // num_heads if not set)
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., prefix=''):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.prefix = prefix

        self.qkv = Dense(dim * 3, use_bias=qkv_bias,
                         name=f'{self.prefix}/attn/qkv')
        self.attn_drop = Dropout(attn_drop)
        self.proj = Dense(dim, name=f'{self.prefix}/attn/proj')
        self.proj_drop = Dropout(proj_drop)

    def build(self, input_shape):
        self.relative_position_bias_table = self.add_weight(f'{self.prefix}/attn/relative_position_bias_table',
                                                            shape=(
                                                                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads),
                                                            initializer=tf.initializers.Zeros(), trainable=True)

        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords = np.stack(np.meshgrid(coords_h, coords_w, indexing='ij'))
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).astype(np.int64)
        self.relative_position_index = tf.Variable(initial_value=tf.convert_to_tensor(
            relative_position_index), trainable=False, name=f'{self.prefix}/attn/relative_position_index')
        self.built = True

    def call(self, x, mask=None):
        B_, N, C = x.get_shape().as_list()
        qkv = tf.transpose(tf.reshape(self.qkv(
            x), shape=[-1, N, 3, self.num_heads, C // self.num_heads]), perm=[2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ tf.transpose(k, perm=[0, 1, 3, 2]))
        relative_position_bias = tf.gather(self.relative_position_bias_table, tf.reshape(
            self.relative_position_index, shape=[-1]))
        relative_position_bias = tf.reshape(relative_position_bias, shape=[
                                            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1])
        relative_position_bias = tf.transpose(
            relative_position_bias, perm=[2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]  # tf.shape(mask)[0]
            attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), attn.dtype)
            attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N])
            attn = tf.nn.softmax(attn, axis=-1)
        else:
            attn = tf.nn.softmax(attn, axis=-1)

        attn = self.attn_drop(attn)

        x = tf.transpose((attn @ v), perm=[0, 2, 1, 3])
        x = tf.reshape(x, shape=[-1, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



In [49]:
def drop_path(inputs, drop_prob, is_training):
    if (not is_training) or (drop_prob == 0.):
        return inputs

    # Compute keep_prob
    keep_prob = 1.0 - drop_prob

    # Compute drop_connect tensor
    random_tensor = keep_prob
    shape = (tf.shape(inputs)[0],) + (1,) * \
        (len(tf.shape(inputs)) - 1)
    random_tensor += tf.random.uniform(shape, dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)
    output = tf.math.divide(inputs, keep_prob) * binary_tensor
    return output


class DropPath(tf.keras.layers.Layer):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        return drop_path(x, self.drop_prob, training)

In [50]:
class SwinTransformerBlock(keras.Model):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer='gelu',
                 norm_layer=LayerNormalization): 
    
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim 
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio 
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(epsilon=1e-5, name=f'norm1')

        self.attn = WindowAttention(
                        dim,
                        window_size=self.window_size,
                        num_heads=num_heads,
                        qkv_bias=qkv_bias,
                        attn_drop=attn_drop,
                        proj_drop=drop
                    )

        self.drop_path = DropPath(
            drop_path if (drop_path) > 0. else 0.
        )
        self.norm2 = norm_layer(epsilon=1e-5, name=f'norm2')
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = MLP(input_dims=None,
                        hidden_neurons=mlp_hidden_dim,
                        output_neurons=None,
                        act_type='gelu',
                        dropout_rate=drop,
                        prefix='swin_transformer'
                      )
        if self.shift_size > 0:
                H, W = self.input_resolution
                img_mask = np.zeros([1, H, W, 1])
                h_slices = (slice(0, -self.window_size),
                            slice(-self.window_size, -self.shift_size),
                            slice(-self.shift_size, None))
                w_slices = (slice(0, -self.window_size),
                            slice(-self.window_size, -self.shift_size),
                            slice(-self.shift_size, None))
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, h, w, :] = cnt
                        cnt += 1

                img_mask = tf.convert_to_tensor(img_mask)
                mask_windows = window_partition(img_mask, self.window_size)
                mask_windows = tf.reshape(
                    mask_windows, shape=[-1, self.window_size * self.window_size])
                attn_mask = tf.expand_dims(
                    mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)
                attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
                attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
                self.attn_mask = tf.Variable(
                    initial_value=attn_mask, trainable=False)
        else:
            self.attn_mask = None

    def call(self, x, return_attns=False):
        H, W = self.input_resolution
        B, L, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        
        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, (B, H, W, C))
        
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # [num_win*B, window_size, window_size, C]
        x_windows = tf.reshape(
            x_windows, (-1, self.window_size * self.window_size, C)
        )  # [num_win*B, window_size*window_size, C]

        # W-MSA/SW-MSA
        if not return_attns:
            attn_windows = self.attn(
                x_windows, mask=self.attn_mask
            )  # [num_win*B, window_size*window_size, C]
        else:
            attn_windows, attn_scores = self.attn(
                x_windows, mask=self.attn_mask, return_attns=True
            )  # [num_win*B, window_size*window_size, C]
        # merge windows
        attn_windows = tf.reshape(
            attn_windows, (-1, self.window_size, self.window_size, C)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, H, W
        )  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x,
                shift=(self.shift_size, self.shift_size),
                axis=(1, 2),
            )
        else:
            x = shifted_x
        x = tf.reshape(x, (B, H * W, C))

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if return_attns:
            return x, attn_scores
        else:
            return x
        

In [51]:
class PatchMerging(keras.layers.Layer): 
    """ Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, input_resolution, dim, norm_layer=LayerNormalization):
        super(PatchMerging, self).__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = Dense(2 * dim, use_bias=False, name=f'downsample/reduction')
        self.norm = norm_layer(epsilon=1e-5, name=f'downsample/norm')
    
    def call(self, x): 
        """
        x: B, H * W, C
        """
        H, W = self.input_resolution
        B, L, C = x.get_shape().as_list()
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = tf.reshape(x, shape=[-1, H, W, C])

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = tf.concat([x0, x1, x2, x3], axis=-1)
        x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C])

        x = self.norm(x)
        x = self.reduction(x)

        return x

In [52]:
class BasicLayer(keras.Model):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                         mlp_ratio=4., qkv_bias=True, qk_scale=None,
                         drop=0., attn_drop=0., drop_path=0., norm_layer=LayerNormalization, downsample=None):
        super(BasicLayer, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        
        self.blocks = [
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
            )
            for i in range(depth)]
       
        # patch merging layer
        self.downsample=None
        if downsample != None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        
    def call(self, x):
        for blk in self.blocks:
            x = blk(x)
        
        if self.downsample != None:
            x = self.downsample(x)
        return x

In [53]:
class PatchEmbed(keras.layers.Layer):
    """ Image to Patch Embedding
    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, img_size=(224, 224), patch_size=(4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = (img_size)
        self.patch_size = (patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = Conv2D(filters=embed_dim,
                           kernel_size=patch_size,
                           strides=patch_size
                          )
        if norm_layer is not None:
            self.norm = norm_layer(epsilon=1e-5, name='norm')
        else:
            self.norm = None
            
    def call(self, x):
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
      #  assert H == self.img_size[0] and W == self.img_size[1], \
       #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        
        x = self.proj(x)
        x = tf.reshape(x, shape=[-1, (H // self.patch_size[0]) * (W // self.patch_size[0]), self.embed_dim])
        if self.norm is not None:
            x = self.norm(x)
        return x

In [54]:
class SwinTransformer(keras.Model): 
    """ Swin Transformer
        A Tensorflow impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        
    """
    def __init__(self, img_size=(224, 224), patch_size=(4, 4), in_chans=3, num_classes=100,
                embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
                attn_drop=0., drop_path_rate=0.1, norm_layer=LayerNormalization, ape=False,
                patch_norm=False, include_top=True, **kwargs):
        
        super(SwinTransformer, self).__init__()
        
        self.num_classes=num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.include_top = include_top
        
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
                
        # absolute postion embedding
        if self.ape:
            self.absolute_pos_embed = self.add_weight('absolute_pos_embed',
                                                      shape=(
                                                          1, num_patches, embed_dim),
                                                      initializer=tf.initializers.Zeros())

        self.pos_drop = Dropout(drop_rate)
        
        # stochastic depth
        dpr = [x for x in np.linspace(0., drop_path_rate, sum(depths))]
        
        # build layers
        basic_layers = tf.keras.Sequential([BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                                input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                                  patches_resolution[1] // (2 ** i_layer)),
                                                depth=depths[i_layer],
                                                num_heads=num_heads[i_layer],
                                                window_size=window_size,
                                                mlp_ratio=self.mlp_ratio,
                                                qkv_bias=qkv_bias,
                                                qk_scale=qk_scale,
                                                drop=drop_rate,
                                                attn_drop=attn_drop,
                                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                                norm_layer=norm_layer,
                                                downsample=PatchMerging if (
                                                    i_layer < self.num_layers - 1) else None) for i_layer in range(self.num_layers)])
        
        
        self.basic_layers = basic_layers
        self.norm = norm_layer(epsilon=1e-5, name='norm')
        self.avgpool = GlobalAveragePooling1D()
        if self.include_top:
            self.head = Dense(num_classes, name='head')
        else:
            self.head = None
            
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
    
        x = self.basic_layers(x)
        x = self.norm(x)
        x = self.avgpool(x)
        return x
        
    def call(self, x):
        x = self.forward_features(x)
        if self.include_top:
            x = self.head(x)
        return x
        
  #  def build(self, input_shape):
   #     x = Input(shape=input_shape)
    #    return Model(inputs=[x], outputs=self.call(x))
        
    def __summary__(self, input_shape):
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))

In [55]:
class SwinTransformerBlock(tf.keras.layers.Layer):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=LayerNormalization):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        

        self.norm1 = norm_layer(epsilon=1e-5, name=f'norm1')
        self.attn = WindowAttention(
                        dim,
                        window_size=(self.window_size, self.window_size),
                        num_heads=num_heads,
                        qkv_bias=qkv_bias,
                        attn_drop=attn_drop,
                        proj_drop=drop
                    )
        self.drop_path = DropPath(
            drop_path if drop_path > 0. else 0.)
        self.norm2 = norm_layer(epsilon=1e-5, name=f'norm2')
        mlp_hidden_dim = int(dim * mlp_ratio)
        
        self.mlp = self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim,
                       drop=drop)
        #self.mlp = Mlp(input_dims=dim,
         #               hidden_neurons=mlp_hidden_dim,
          #              output_neurons=None,
           #             act_type='gelu',
            #            dropout_rate=drop,
             #           prefix='swin_transformer')

    def build(self, input_shape):
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = np.zeros([1, H, W, 1])
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            img_mask = tf.convert_to_tensor(img_mask)
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size])
            attn_mask = tf.expand_dims(
                mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(
                initial_value=attn_mask, trainable=False, name=f'attn_mask', dtype=tf.float64)
        else:
            self.attn_mask = None

        self.built = True

    def call(self, x):
        H, W = self.input_resolution
        B, L, C = x.get_shape().as_list()

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=[-1, H, W, C])

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2])
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=[-1, self.window_size * self.window_size, C])
        
   #     x_windows = tf.cast(x_windows, dtype=tf.float32)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # merge windows
        attn_windows = tf.reshape(
            attn_windows, shape=[-1, self.window_size, self.window_size, C])
        
        shifted_x = window_reverse(attn_windows, self.window_size, H, W, C)
        

        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x, shift=[
                        self.shift_size, self.shift_size], axis=[1, 2])
        
        else:
            x = shifted_x
        x = tf.reshape(x, shape=[-1, H * W, C])
        
        # FFN
        x = shortcut + self.drop_path(x)
        temp_x = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = temp_x + self.drop_path(x)

        return x



In [56]:
m = SwinTransformer()


In [57]:
m(np.zeros((1, 224, 224, 3)))

<tf.Tensor: shape=(1, 100), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float32)>

In [58]:
opt = keras.optimizers.Adam()
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
metric_func = keras.metrics.CategoricalAccuracy()

In [59]:
loss_tracker = keras.metrics.Mean()
def train_step(inputs, model, opt, loss_func, metric_func):
    with tf.GradientTape() as tape:
        X, y = inputs 
        pred_y = model(X)

        loss_val = loss_func(y, pred_y)
        
    params = model.trainable_weights
    grads = tape.gradient(loss_val, params)
    opt.apply_gradients(zip(grads, params))
    
    loss_tracker.update_state(loss_val)
    metric_func.update_state(y, pred_y)
    
    return loss_tracker.result(), metric_func.result()


In [None]:
for epoch in range(1):
    for step, training_batch in tqdm(enumerate(train_ds), total=len(train_ds)):
        res = train_step(training_batch, m, opt, loss_func, metric_func)
    
    print(f'Epoch: {epoch}, loss: {res[0]}, acc: {res[1]}')
    loss_tracker.reset_state()
    metric_func.reset_state()


 10%|▉         | 621/6250 [07:57<1:08:20,  1.37it/s]