In [None]:
# Imports
import os
import torch.cuda, torch.utils.data, torch.nn, torch.optim, torch
import torchvision.transforms, torchvision.datasets.folder, torchvision.utils
from sklearn.cluster import KMeans
from cv2 import imread as img_read
import numpy as np
import matplotlib.pyplot as plt
import time

In [None]:
# Addresses

class Address:
    '''
    Central registry of every filesystem location used by the notebook,
    plus small helpers to create and empty the output directories.
    '''
    def __init__(self):
        '''
        Stores all the addresses used in project
        '''
        # Inputs
        self.data = "../input/clevertex/dataset"
        self.img_train = os.path.join(self.data, "images/train")
        self.img_val = os.path.join(self.data, "images/val")
        self.mask_train = os.path.join(self.data, "masks/train")
        self.mask_val = os.path.join(self.data, "masks/val")

        # Models
        self.model = "results/"
        self.slot_attention = os.path.join(self.model, "slot_attention")

        # Temp
        self.temp = "temp/"

    def create_dir(self, dir_list = None):
        '''
        Function to create directories in dir_list. If dir_list is None then create all directories of address.

        Missing parent directories are created as well, so nested paths such as
        "results/slot_attention/model" work even when "results/" does not exist yet.
        Paths that already exist are left untouched.
        '''
        if dir_list is None:
            dir_list = [self.model, self.temp, self.slot_attention]
        for address in dir_list:
            if not os.path.exists(address):
                # exist_ok guards against the directory appearing between the check and the call
                os.makedirs(address, exist_ok=True)

    def _delete_folder_content(self, folder_addr):
        '''
        Deletes all the content of folder_addr while keeping folder_addr itself.

        Does nothing when folder_addr does not exist.
        '''
        if not os.path.exists(folder_addr):
            return
        for entry in os.listdir(folder_addr):
            address = os.path.join(folder_addr, entry)
            if os.path.isdir(address) and not os.path.islink(address):
                self._delete_folder_content(address)
                # os.rmdir removes only this directory. os.removedirs would also prune
                # every parent that became empty, i.e. folder_addr itself and possibly
                # its ancestors.
                os.rmdir(address)
            else:
                # Regular files and symlinks (including links to directories, which are
                # unlinked rather than followed).
                os.remove(address)

    def clean(self, file_list = None):
        '''
        Deletes all the content in file_list (defaults to the temp directory).
        The listed directories themselves are kept.
        '''
        if file_list is None:
            file_list = [self.temp]
        for address in file_list:
            self._delete_folder_content(address)

# Shared path registry; later cells read this module-level instance.
addr = Address()
# With no argument, clean() empties the temp directory so each run starts from a clean scratch area.
addr.clean()
# Uncomment to also wipe the saved slot-attention results (checkpoints, losses, generated images).
# addr.clean([addr.slot_attention])
# With no argument, create_dir() creates the results, temp and slot_attention directories if missing.
addr.create_dir()

In [None]:
class HyperParameters:
    '''
    Container for every tunable value of the experiment, the learning-rate
    schedule derived from them, and a plain-text report writer.
    '''
    def __init__(self):
        '''
        Stores all Hyperparameters used for training of model
        '''
        # Training
        self.batch_size = 2
        self.resolution = (128, 128)
        self.num_epoch = 38
        self.grad_clip = 1.0

        # Data
        self.num_train = 40000
        self.num_val = 10000
        # Total number of optimizer steps over the whole run
        self.train_step = self.num_epoch*(self.num_train//self.batch_size)

        # Learning Rate
        self.lr = 4e-4
        self.warmup_step = self.train_step//50
        self.decay_step = self.train_step//5
        self.decay_rate = 0.5

        # Encoder
        self.dim_input = 64
        self.shift = 3

        # Slot Attention
        self.dim_slot = 64
        self.dim_projected = 64
        self.dim_mlp_slot = 128
        self.num_slot = 11
        self.num_iter_slot = 3
        self.epsilon = 1e-8

        # Decoder
        self.decoder_channel = 64

        # Evaluation
        self.ari_batch_size = 2

    def lr_schedule(self, step):
        '''
        Getting learning rate as function of train steps completed

        Returns the multiplicative factor applied to the base learning rate:
        a linear ramp from 0 to 1 over the first `warmup_step` steps, then an
        exponential decay that multiplies by `decay_rate` every `decay_step` steps.

        When `warmup_step` is 0 (very short runs, train_step < 50) the warmup phase
        is skipped instead of dividing by zero.
        '''
        if self.warmup_step > 0 and step <= self.warmup_step:
            return step/self.warmup_step
        return self.decay_rate**((step-self.warmup_step)/self.decay_step)

    def create_report(self, addr):
        '''
        Writes every hyperparameter to `param.txt` inside directory `addr`,
        overwriting any existing file.
        '''
        with open(os.path.join(addr, 'param.txt'), 'w') as file:
            file.writelines([
                f'Training:',
                f'\n\tBatch Size:       {self.batch_size}',
                f'\n\tResolution:       {self.resolution}',
                f'\n\tNum Epoch:        {self.num_epoch}',
                f'\n\tGrad Clip:        {self.grad_clip}',
                f'\n\nData:',
                f'\n\tNum Train:        {self.num_train}',
                f'\n\tNum Val:          {self.num_val}',
                f'\n\tTrain Step:       {self.train_step}',
                f'\n\nLearning Rate:',
                f'\n\tlr:               {self.lr}',
                f'\n\tWarmup Step:      {self.warmup_step}',
                f'\n\tDecay Step:       {self.decay_step}',
                f'\n\tDecay Rate:       {self.decay_rate}',
                f'\n\nEncoder:',
                f'\n\tDim Input:        {self.dim_input}',
                f'\n\tShift:            {self.shift}',
                f'\n\nSlot Attention:',
                f'\n\tDim Slot:         {self.dim_slot}',
                f'\n\tDim Projected:    {self.dim_projected}',
                f'\n\tDim MLP slot:     {self.dim_mlp_slot}',
                f'\n\tNum Slot:         {self.num_slot}',
                f'\n\tNum Iter Slot:    {self.num_iter_slot}',
                f'\n\tEpsilon:          {self.epsilon}',
                f'\n\nDecoder:',
                f'\n\tDecoder Channel:  {self.decoder_channel}',
                f'\n\nEvaluation:',
                f'\n\tARI Batch Size:   {self.ari_batch_size}'
            ])

# Single shared hyperparameter instance; later cells pass it around and use it as a default argument.
param = HyperParameters()

In [None]:
# Random Seed and CUDA

random_seed = 68

# Seed the CPU-side generators of torch and numpy.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prefer the GPU when one is visible; in that case seed every CUDA device as well.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.manual_seed_all(random_seed)

print(f"Working with device {device}")

In [None]:
# Dataset

class DataSet(torch.utils.data.Dataset):
    '''
    Image dataset with optional segmentation masks.

    Images are read from `address_img`; when `address_mask` is given, the mask file
    at the same sorted position is read too and its colours are converted to
    integer labels. Image and mask are paired purely by position in the two
    sorted directory listings.

    The colour -> label table (`mask_dict`) is shared by all samples and grows as
    new colours are encountered, so label ids depend on the order in which samples
    are read. NOTE(review): with DataLoader worker processes each worker presumably
    mutates its own copy of this table, so `num_category` on the parent-process
    instance would not be updated - confirm before relying on it with num_workers > 0.
    '''
    def __init__(self, resolution, address_img, address_mask = None):
        '''
        resolution:   target (height, width) passed to the Resize transform
        address_img:  directory containing the images
        address_mask: optional directory containing the masks
        '''
        self.address_img = address_img
        self.address_mask = address_mask
        self.img_list = sorted(os.listdir(self.address_img))
        if self.address_mask is not None:
            self.mask_list = sorted(os.listdir(self.address_mask))
            # Black is pre-registered as label 0 (treated as background by the ARI metric)
            self.mask_dict = {(0, 0, 0): 0}
            self.num_category = 1
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                         torchvision.transforms.Resize(resolution, antialias=False)])

    def __len__(self):
        return len(self.img_list)

    def _encode_mask(self, mask):
        '''
        Converts an (H, W, C) colour mask into an (H, W) torch.long label map.

        Unseen colours are registered in raster-scan order of their first
        occurrence, exactly as a pixel-by-pixel scan would do, but the table is
        consulted once per distinct colour instead of once per pixel.
        '''
        height, width = mask.shape[0], mask.shape[1]
        if height * width == 0:
            return torch.zeros(height, width, dtype=torch.long)

        colors = mask.reshape(height * width, -1)
        uniq, first_seen, inverse = np.unique(colors, axis=0, return_index=True, return_inverse=True)
        inverse = inverse.reshape(-1)       # inverse shape for axis != None varies across numpy versions

        lut = np.empty(len(uniq), dtype=np.int64)
        for k in np.argsort(first_seen):    # np.unique sorts colours; restore first-occurrence order
            color = tuple(int(channel) for channel in uniq[k])
            if color not in self.mask_dict:
                self.mask_dict[color] = self.num_category
                self.num_category += 1
            lut[k] = self.mask_dict[color]

        return torch.from_numpy(lut[inverse].reshape(height, width))

    def __getitem__(self, idx):
        '''
        Returns the transformed float image, or a dict {'img', 'mask'} when masks are enabled.
        The mask keeps the resolution of the mask file (it is not resized).

        Raises FileNotFoundError when the mask file cannot be read.
        '''
        img_addr = os.path.join(self.address_img, self.img_list[idx])
        img = torchvision.datasets.folder.default_loader(img_addr)

        if self.address_mask is not None:
            mask_addr = os.path.join(self.address_mask, self.mask_list[idx])
            mask = img_read(mask_addr)
            if mask is None:
                # cv2.imread signals a missing or undecodable file by returning None
                raise FileNotFoundError(f"Could not read mask image: {mask_addr}")

            return {
                'img': self.transform(img).to(torch.float),
                'mask': self._encode_mask(mask)
            }

        return self.transform(img).to(torch.float)

class Data:
    def __init__(self, address: Address, param: HyperParameters, device = device):
        '''
        Builds the four datasets (train/val, with and without masks) and a
        DataLoader for each of them.

        The mask-free pair feeds training and validation of the reconstruction
        loss (batch size `param.batch_size`); the masked pair feeds evaluation
        (batch size `param.ari_batch_size`).
        '''
        self.address = address
        self.device = device
        loader_cls = torch.utils.data.DataLoader

        # Image-only datasets and their loaders; only the training split is shuffled
        self.dataset_train_without_mask = DataSet(param.resolution, address.img_train)
        self.dataset_val_without_mask = DataSet(param.resolution, address.img_val)
        self.loader_train_without_mask = loader_cls(self.dataset_train_without_mask,
                                                    batch_size=param.batch_size,
                                                    shuffle=True)
        self.loader_val_without_mask = loader_cls(self.dataset_val_without_mask,
                                                  batch_size=param.batch_size,
                                                  shuffle=False)

        # Image + mask datasets and their loaders; order is fixed and batches go through collate()
        self.dataset_train_with_mask = DataSet(param.resolution, address.img_train, address.mask_train)
        self.dataset_val_with_mask = DataSet(param.resolution, address.img_val, address.mask_val)
        self.loader_train_with_mask = loader_cls(self.dataset_train_with_mask,
                                                 batch_size=param.ari_batch_size,
                                                 shuffle=False,
                                                 collate_fn=self.collate,
                                                 num_workers=4)
        self.loader_val_with_mask = loader_cls(self.dataset_val_with_mask,
                                               batch_size=param.ari_batch_size,
                                               shuffle=False,
                                               collate_fn=self.collate)

    def collate(self, batch):
        '''
        Stacks a list of {'img', 'mask'} samples into one dict of batched tensors.
        '''
        return {
            key: torch.stack([sample[key] for sample in batch])
            for key in ('img', 'mask')
        }

# Module-level dataset/loader bundle; read by ARI() and used as the default `data` argument of LearnModel.
data = Data(addr, param)

In [None]:
# Models

def create_grid(resolution):
    '''
    Builds a positional grid for a (height, width) resolution.

    Returns a float tensor of shape (1, height, width, 4). The four channels are
    the row coordinate, the column coordinate, and their complements (1 - value),
    each running linearly over [0, 1].
    '''
    height, width = resolution[0], resolution[1]
    rows, cols = np.meshgrid(np.linspace(0, 1, height),
                             np.linspace(0, 1, width),
                             indexing='ij')
    forward = np.stack((rows, cols), axis=-1)
    both_ways = np.concatenate((forward, 1 - forward), axis=-1)
    # Leading singleton axis lets the grid broadcast over the batch dimension
    return torch.tensor(both_ways[np.newaxis, ...], dtype=torch.float, requires_grad=False)

class ResBlock(torch.nn.Module):
    def __init__(self, num_channel, kernel_size=5, stride=1, padding='same'):
        '''
        Residual block: two conv + InstanceNorm stages with ReLU activations and a
        skip connection added before the final activation.

        `stride` and `padding` apply to the first convolution only. When stride is
        not 1 the skip path goes through its own convolution (same kernel, stride
        and padding) so that both branches have matching spatial size.
        '''
        super(ResBlock, self).__init__()

        conv_kwargs = dict(dtype=torch.float)

        # Main branch, stage 1 (may downsample)
        self.conv1 = torch.nn.Conv2d(num_channel, num_channel, kernel_size,
                                     stride=stride, padding=padding, **conv_kwargs)
        self.norm1 = torch.nn.InstanceNorm2d(num_channel)
        self.activation1 = torch.nn.ReLU()

        # Main branch, stage 2 (always resolution-preserving)
        self.conv2 = torch.nn.Conv2d(num_channel, num_channel, kernel_size,
                                     stride=1, padding='same', **conv_kwargs)
        self.norm2 = torch.nn.InstanceNorm2d(num_channel)
        self.activation2 = torch.nn.ReLU()

        # Skip branch needs a projection only when stage 1 changes the resolution
        self.project = stride != 1
        if self.project:
            self.conv_project = torch.nn.Conv2d(num_channel, num_channel, kernel_size,
                                                stride=stride, padding=padding, **conv_kwargs)

    def forward(self, x):
        # Skip path: projected input when downsampling, identity otherwise
        shortcut = self.conv_project(x) if self.project else x

        # Main path
        out = self.activation1(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        return self.activation2(out + shortcut)

class Encoder(torch.nn.Module):
    def __init__(self, num_channel, shift, device=device):
        '''
        Encodes input image into feature vectors

        num_channel: number of feature channels used throughout the encoder
        shift:       number of stride-2 stages; each one halves the spatial resolution
        device:      kept for backward compatibility; the positional grid follows the
                     device of the input tensor
        '''
        super(Encoder, self).__init__()

        self.device = device
        self.num_channel = num_channel
        self.shift = shift

        # CNN backbone of Encoder
        cnn = [torch.nn.Conv2d(3, num_channel, 5, stride=1, padding='same', dtype=torch.float), torch.nn.ReLU()]
        for _ in range(shift):
            cnn.append(ResBlock(num_channel, 5, stride=2, padding=2))
            cnn.append(ResBlock(num_channel, 5, stride=1, padding='same'))
        cnn.append(ResBlock(num_channel, 5, stride=1, padding='same'))
        self.cnn = torch.nn.Sequential(*cnn)

        # Positional Embedding (the grid is built lazily, once the feature-map size is known)
        self.register_buffer('grid', None)
        self.embed = torch.nn.Linear(4, num_channel, dtype=torch.float)

        # Flatten
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=2)

        # Layer Norm
        self.layer_norm = torch.nn.LayerNorm(num_channel, dtype=torch.float)

        # MLP
        self.mlp = torch.nn.Sequential(torch.nn.Linear(num_channel, num_channel, dtype=torch.float),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(num_channel, num_channel, dtype=torch.float))

    def forward(self, x: torch.Tensor):
        '''
        Maps an image batch in channel-first layout to a sequence of feature
        vectors of shape (batch, height*width of the feature map, num_channel).
        '''
        x = self.cnn(x)     # CNN

        x = x.permute((0, 2, 3, 1))     # Permuting Channel axis at the end

        # Rebuild the cached grid when it is missing, was built for a different
        # feature-map size, or lives on another device than the input.
        if self.grid is None or self.grid.shape[1:3] != x.shape[1:3] or self.grid.device != x.device:
            self.grid = create_grid(x.shape[1:3]).to(x.device)
        x = x + self.embed(self.grid)   # Positional Embedding

        x = self.flatten(x)     # Flatten into feature vector

        x = self.layer_norm(x)  # Layer Normalization

        x = self.mlp(x)     # MLP

        return x

class SlotAttention(torch.nn.Module):
    def __init__(self, dim_input, dim_slot, dim_projected, dim_mlp, num_slot, num_iter, epsilon, device=device):
        '''
        Implements Slot Attention

        dim_input:     size of each input feature vector
        dim_slot:      size of each slot vector
        dim_projected: size of the key / query / value projections
        dim_mlp:       hidden size of the residual MLP applied to the slots
        num_slot:      number of slots
        num_iter:      number of attention iterations
        epsilon:       small constant added to the attention weights for numerical stability
        '''
        super(SlotAttention, self).__init__()

        self.device = device
        self.dim_input = dim_input
        self.dim_slot = dim_slot
        self.dim_projected = dim_projected
        self.dim_mlp = dim_mlp
        self.num_slot = num_slot
        self.num_iter = num_iter
        self.epsilon = epsilon

        # Layer Norm
        self.layer_norm_inp = torch.nn.LayerNorm(dim_input, dtype=torch.float)
        self.layer_norm_slot = torch.nn.LayerNorm(dim_slot, dtype=torch.float)

        # Slot Initialization Parameters (learned, shared by every sample in the batch)
        self.slots = torch.nn.Parameter(torch.randn(1, self.num_slot, self.dim_slot, dtype=torch.float))

        # Projection matrix
        self.project_key = torch.nn.Linear(self.dim_input, self.dim_projected, bias=False, dtype=torch.float)
        self.project_value = torch.nn.Linear(self.dim_input, self.dim_projected, bias=False, dtype=torch.float)
        self.project_query = torch.nn.Linear(self.dim_slot, self.dim_projected, bias=False, dtype=torch.float)

        # Slot update
        self.gru = torch.nn.GRUCell(self.dim_projected, self.dim_slot)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(self.dim_slot, self.dim_mlp),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(self.dim_mlp, self.dim_slot))

    def forward(self, x):
        '''
        x: input features of shape (batch, num_inputs, dim_input)
        Returns the refined slots of shape (batch, num_slot, dim_slot).
        '''
        x = self.layer_norm_inp(x)          # Layer Normalization

        # Keys and values depend only on the inputs, so they are projected once
        # instead of once per iteration.
        keys = self.project_key(x)
        values = self.project_value(x)
        temperature = np.sqrt(self.dim_projected)       # SoftMax temperature

        slots = torch.tile(self.slots, dims = (x.shape[0], 1, 1))                  # Initializing Slot

        for _ in range(self.num_iter):
            slots_prev = slots
            slots = self.layer_norm_slot(slots)         # Layer Normalization

            # Computing Attention: dot product of key and query, softmax over the slot axis
            # so that slots compete for every input location
            queries = self.project_query(slots)
            attn = torch.matmul(keys, queries.permute(0, 2, 1)) / temperature
            attn = torch.nn.functional.softmax(attn, dim=-1) + self.epsilon     # Softmax with numerical stability

            # Updating Slot: normalise over the input axis to get a weighted mean per slot
            weights = attn/torch.sum(attn, dim=1, keepdim=True)
            updates = torch.bmm(weights.permute(0, 2, 1), values)

            # GRUCell works on 2-D input, so batch and slot axes are merged for the update
            slots = self.gru(updates.reshape(-1, self.dim_projected), slots_prev.reshape(-1, self.dim_slot))
            slots = slots.reshape(-1, self.num_slot, self.dim_slot)
            slots = slots + self.mlp(slots)             # Residual MLP update

        return slots

class SBDecoder(torch.nn.Module):
    def __init__(self, num_slot, dim_slot, resolution, num_channel, shift, device=device):
        '''
        Spatial Broadcast Decoder.

        Every slot is broadcast over a coarse grid of size resolution >> shift,
        concatenated with two coordinate channels scaled to [-1, 1], and upsampled
        by `shift` stride-2 transposed convolutions to a 4-channel map per slot
        (3 image channels + 1 mask logit).
        '''
        super(SBDecoder, self).__init__()

        self.device = device
        self.num_slot = num_slot
        self.dim_slot = dim_slot
        self.resolution = resolution
        self.num_channel = num_channel
        self.shift = shift
        self.shifted_resolution = (resolution[0]>>shift, resolution[1]>>shift)

        # Coordinate channels: first two channels of the positional grid, channel-first, mapped to [-1, 1]
        coords = create_grid(self.shifted_resolution).permute(0, 3, 1, 2)[:, :2, :, :]
        self.register_buffer('grid', 2*coords-1)

        # CNN Decoder: input conv, `shift` upsampling stages, one refinement layer, output layer
        conv_t = torch.nn.ConvTranspose2d
        layers = [torch.nn.Conv2d(2+dim_slot, num_channel, 5, stride=1, padding='same', dtype=torch.float),
                  torch.nn.ReLU()]
        for _ in range(shift):
            layers += [conv_t(num_channel, num_channel, 5, stride=2, padding=2, output_padding=1, dtype=torch.float),
                       torch.nn.ReLU()]
        layers += [conv_t(num_channel, num_channel, 5, stride=1, padding=2, dtype=torch.float),
                   torch.nn.ReLU(),
                   conv_t(num_channel, 4, 5, stride=1, padding=2, dtype=torch.float)]
        self.cnn = torch.nn.Sequential(*layers)

    def forward(self, slots):
        '''
        Returns (reconstructed image, per-slot images, per-slot masks, slots).
        The masks are softmax-normalised across slots and used as mixing weights.
        '''
        coarse_h, coarse_w = self.shifted_resolution

        # One "pixel" per slot, then broadcast over the coarse grid
        latents = slots.reshape((-1, self.dim_slot, 1, 1))
        tiled = latents.repeat(1, 1, coarse_h, coarse_w)
        coords = self.grid.repeat(tiled.shape[0], 1, 1, 1)

        # Decoding
        decoded = self.cnn(torch.cat((tiled, coords), dim=1))

        # Separate the slot axis again and split RGB from the mask logit
        decoded = decoded.reshape(-1, self.num_slot, 4, self.resolution[0], self.resolution[1])
        img, mask_logit = decoded.split([3, 1], dim=2)

        # Slots compete per pixel; the reconstruction is the mask-weighted sum of slot images
        mask = torch.softmax(mask_logit, dim=1)
        recon_img = (img*mask).sum(dim=1)

        return recon_img, img, mask.squeeze(2), slots

class ObjectDiscovery(torch.nn.Module):
    def __init__(self, param: HyperParameters, device=device):
        '''
        Full object-discovery model: encoder -> slot attention -> spatial broadcast decoder.
        All sizes are read from `param`.
        '''
        super(ObjectDiscovery, self).__init__()

        self.device = device

        # Image -> feature vectors
        self.encoder = Encoder(param.dim_input, param.shift, device=device)

        # Feature vectors -> slots
        self.slot_attention = SlotAttention(
            param.dim_input,
            param.dim_slot,
            param.dim_projected,
            param.dim_mlp_slot,
            param.num_slot,
            param.num_iter_slot,
            param.epsilon,
            device=device,
        )

        # Slots -> reconstructed image and masks
        self.decoder = SBDecoder(
            param.num_slot,
            param.dim_slot,
            param.resolution,
            param.decoder_channel,
            param.shift,
            device=device,
        )

    def forward(self, x, slot_only = False):
        '''
        Returns the decoder output tuple, or only the slots when `slot_only` is True.
        '''
        slots = self.slot_attention(self.encoder(x))
        return slots if slot_only else self.decoder(slots)


In [None]:
# Metric
def ARI(original_mask, predicted_mask, transform):
    '''
    Mean foreground Adjusted Rand Index of a batch.

    original_mask:  (batch, H, W) integer labels, 0 being background
    predicted_mask: (batch, num_slot, h, w) per-slot mask scores
    transform:      callable that brings predicted_mask to the (H, W) resolution
                    of original_mask

    Background pixels (label 0) are excluded. The index is computed separately
    for every image and then averaged. Degenerate images where the index is
    0/0 (no foreground pixel, or all foreground pixels in a single ground-truth
    object and a single predicted slot) are scored as 1.0.

    Returns a Python float.
    '''
    num_slot = predicted_mask.shape[1]
    # Number of classes is taken from the labels actually present, so the metric does
    # not depend on any dataset-level label counter. Unused classes contribute nothing.
    num_category = int(original_mask.max().item()) + 1

    # Flattening and Reshaping to get masks of shape (Batch Size, H*W)
    orig_mask = torch.flatten(original_mask, start_dim=1, end_dim=2)
    pred_mask = torch.flatten(torch.argmax(transform(predicted_mask), dim=1), start_dim=1, end_dim=2)

    # One Hot Encoding
    pred_mask_oh = torch.nn.functional.one_hot(pred_mask, num_classes=num_slot).to(torch.float)
    orig_mask_oh = torch.nn.functional.one_hot(orig_mask, num_classes=num_category)[:, :, 1:].to(torch.float)       # Removing Background

    # Number of non Background Points, per image (not pooled over the batch)
    n_points = torch.sum(orig_mask_oh, dim=(1, 2))

    # Contingency table: nij[b, i, j] = pixels of object i assigned to slot j
    nij = torch.bmm(orig_mask_oh.permute(0, 2, 1), pred_mask_oh)
    ai = torch.sum(nij, dim=1)      # pixels per predicted slot
    bj = torch.sum(nij, dim=2)      # pixels per ground-truth object

    # Calculating ARI (pair counts are used as n*(n-1); the factor 1/2 cancels out)
    rindex = torch.sum(nij*(nij-1), dim=(1, 2))
    aindex = torch.sum(ai*(ai-1), dim=1)
    bindex = torch.sum(bj*(bj-1), dim=1)
    total_pairs = torch.clamp(n_points*(n_points-1), min=1)     # avoids 0/0 when fewer than 2 points
    expected_rindex = aindex*bindex / total_pairs
    max_rindex = (aindex + bindex)/2

    denominator = max_rindex - expected_rindex
    degenerate = denominator == 0
    safe_denominator = torch.where(degenerate, torch.ones_like(denominator), denominator)
    ari = torch.where(degenerate,
                      torch.ones_like(denominator),
                      (rindex - expected_rindex)/safe_denominator)

    return torch.mean(ari).item()

# Learning Model
class LearnModel:
    def __init__(self, model: torch.nn.Module, model_addr, data=data, param=param, device=device):
        '''
        Train, Evaluate and Predict

        model:      network to train; moved to `device`
        model_addr: directory that receives checkpoints, losses, metrics and images
        data:       object exposing the train/val loaders and datasets
        param:      hyperparameters
        device:     device used for the model and for every batch
        '''

        self.data = data
        self.param = param
        self.device = device
        self.model = model.to(device)
        self.model_addr = model_addr

        # Addresses
        self.loss_addr = os.path.join(self.model_addr, 'loss.npz')
        self.epoch_addr = lambda epoch: os.path.join(self.model_addr, f'model/{epoch}.pth')
        self.scheduler_addr = lambda epoch: os.path.join(self.model_addr, f'scheduler/{epoch}.pth')
        self.ari_addr = os.path.join(model_addr, "ARI.txt")
        self.slot_addr = os.path.join(model_addr, "slots.npz")
        self.img_addr = os.path.join(model_addr, "gen_img")
        self.vis_addr = os.path.join(model_addr, "visualize")
        self._make_output_dirs()

    def _make_output_dirs(self):
        '''
        Creates every output sub-directory of model_addr (including missing parents).
        '''
        for folder in (os.path.join(self.model_addr, 'model'),
                       os.path.join(self.model_addr, 'scheduler'),
                       self.img_addr,
                       self.vis_addr):
            os.makedirs(folder, exist_ok=True)

    def train(self, epoch_log = True, batch_log = True, overwrite = False):
        '''
        Trains for `param.num_epoch` epochs, saving model and scheduler state
        after every epoch together with the running train/val loss history.

        Epochs whose model and scheduler checkpoints both exist are loaded
        instead of trained, which makes the call resumable. With overwrite=True
        everything under model_addr is deleted first and training starts from scratch.

        NOTE: only model and scheduler state are checkpointed; the Adam optimizer
        state is not, so its running moments restart from zero after a resume.
        '''
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.param.lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = self.param.lr_schedule)
        loss_fn = torch.nn.MSELoss().to(self.device)
        start_time = time.time()

        # Wipe previous results BEFORE reading the loss history, otherwise the stale
        # losses would be carried into the new run. The checkpoint folders are
        # recreated because cleaning removes them.
        if overwrite:
            addr.clean([self.model_addr])
            self._make_output_dirs()

        # Loss arr
        if os.path.exists(self.loss_addr):
            with np.load(self.loss_addr) as loss_arr:
                train_loss_arr = list(loss_arr['train'])
                val_loss_arr = list(loss_arr['val'])
        else:
            train_loss_arr = []
            val_loss_arr = []

        for epoch in range(self.param.num_epoch):
            epoch_addr = self.epoch_addr(epoch)
            scheduler_addr = self.scheduler_addr(epoch)

            # Loading Model if present
            if os.path.exists(epoch_addr) and os.path.exists(scheduler_addr):
                # map_location lets a checkpoint written on GPU be resumed on CPU and vice versa
                self.model.load_state_dict(torch.load(epoch_addr, map_location=self.device), strict=False)
                scheduler.load_state_dict(torch.load(scheduler_addr, map_location=self.device))
                print(f"Loaded model and scheduler at epoch {epoch}")
                continue

            # Training Model
            train_loss = self.train_epoch(optimizer, scheduler, loss_fn, batch_log=batch_log)
            if epoch_log:
                print(f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTime: {time.time()-start_time}')

            # Validating Model
            val_loss = self.validate_epoch(loss_fn, self.data.loader_val_without_mask, batch_log=batch_log)
            if epoch_log:
                print(f'Epoch: {epoch}\tVal Loss: {val_loss}\tTime: {time.time()-start_time}')

            # Saving data
            train_loss_arr.append(train_loss)
            val_loss_arr.append(val_loss)
            np.savez_compressed(self.loss_addr, train=np.array(train_loss_arr), val=np.array(val_loss_arr))     # Saving Loss Array
            torch.save(self.model.state_dict(), epoch_addr)     # Saving Model
            torch.save(scheduler.state_dict(), scheduler_addr)  # Saving Scheduler

            # Printing blank line between each epoch in Log
            if epoch_log:
                print()

    def train_epoch(self, optimizer, scheduler, loss_fn, batch_log):
        '''
        Trains model for one epoch and returns the mean batch loss.
        The scheduler is stepped once per batch.
        '''
        epoch_loss = 0
        batch_ct = 0
        dataloader = self.data.loader_train_without_mask
        start_time = time.time()

        self.model.train()          # Set Model to train Mode

        for batch in dataloader:
            # Copying data to the model's device
            batch = batch.to(self.device)

            # Forward Propagation
            recon_img, img, mask, slots = self.model(batch)

            # Computing Loss
            loss = loss_fn(batch, recon_img)
            epoch_loss += loss.item()

            # Back Propagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.param.grad_clip)
            optimizer.step()
            scheduler.step()

            # Update batch count
            batch_ct += 1

            if batch_log and batch_ct%50 == 0:
                print(f"\tBatch {batch_ct}\tLoss: {epoch_loss/batch_ct}\tTime: {time.time()-start_time}")

        return epoch_loss/batch_ct

    def validate_epoch(self, loss_fn, dataloader, batch_log):
        '''
        Calculates mean batch Loss on data in given dataloader, without gradients.
        '''
        epoch_loss = 0
        batch_ct = 0
        start_time = time.time()

        self.model.eval()           # Set Model to eval mode

        with torch.no_grad():
            for batch in dataloader:
                # Copying data to the model's device
                batch = batch.to(self.device)

                # Forward Propagation
                recon_img, img, mask, slots = self.model(batch)

                # Computing Loss
                loss = loss_fn(batch, recon_img)
                epoch_loss += loss.item()

                # Update batch count
                batch_ct += 1

                if batch_log and batch_ct%125 == 0:
                    print(f"\tBatch {batch_ct}\tLoss: {epoch_loss/batch_ct}\tTime: {time.time()-start_time}")

        return epoch_loss/batch_ct

    def plot_loss(self, addr = None):
        '''
        Plots Loss vs number of epochs and saves the figure.

        addr: output path; defaults to 'loss_curve' inside model_addr.
        Raises FileNotFoundError when no loss history has been saved yet.
        '''
        if not os.path.exists(self.loss_addr):
            raise FileNotFoundError("No Loss Array")
        with np.load(self.loss_addr) as loss_arr:
            train_arr, val_arr = loss_arr['train'], loss_arr['val']
        num_epoch = train_arr.shape[0]
        x_arr = np.linspace(1, num_epoch, num_epoch)

        if addr is None:
            addr = os.path.join(self.model_addr, 'loss_curve')

        # A dedicated figure keeps repeated calls from drawing on top of each other
        fig, ax = plt.subplots()
        ax.set_title("Loss Curve")
        ax.set_xlabel("Number of Epochs")
        ax.set_ylabel("MSE Loss")
        ax.plot(x_arr, train_arr, label='Train')
        ax.plot(x_arr, val_arr, label='Val')
        ax.legend()
        fig.savefig(addr)

    def best_model(self):
        '''
        Returns Best Model as well as changes self.model in place to best model

        "Best" is the epoch with the lowest validation loss in the saved history.
        Raises FileNotFoundError when no loss history has been saved yet.
        '''
        if not os.path.exists(self.loss_addr):
            raise FileNotFoundError("No Loss Array")
        with np.load(self.loss_addr) as loss_arr:
            best_epoch = int(np.argmin(loss_arr['val']))
        state = torch.load(self.epoch_addr(best_epoch), map_location=self.device)
        self.model.load_state_dict(state, strict=False)
        return self.model

    def ARI_score(self, dataloader = None, log=True):
        '''
        Calculates ARI score for given dataloader (default validation dataset)

        The mean over batches is written to ARI.txt and also returned.
        '''
        self.model.eval()       # Set Model to eval mode
        ari_score = 0
        batch_ct = 0
        start_time = time.time()

        transform = None

        if dataloader is None:
            dataloader = self.data.loader_val_with_mask

        with torch.no_grad():
            for batch in dataloader:
                # Loading Data
                img_orig, mask_orig = batch['img'].to(self.device), batch['mask'].to(self.device)

                # Forward Propagation
                recon_img, img, mask_pred, slots = self.model(img_orig)

                # Predicted masks are resized to the resolution of the ground-truth masks;
                # the resize op is built once from the first batch
                if transform is None:
                    transform = torchvision.transforms.Resize(mask_orig.shape[1:], antialias=False)
                ari_score += ARI(mask_orig, mask_pred, transform)
                batch_ct += 1

                # Log
                if log and batch_ct%2 == 0:
                    print(f"Batch: {batch_ct}\tARI score: {ari_score/batch_ct}\tTime: {time.time()-start_time}")

        mean_score = ari_score/batch_ct
        with open(self.ari_addr, mode='w') as file:
            file.write(f"ARI Score: {mean_score}\n")
        return mean_score

    def _kmean_slot(self, log):
        '''
        Use scikit Kmean to cluster slots

        Slots of the whole training set are computed, clustered into `num_slot`
        groups and saved (slots + cluster labels) to slots.npz.
        '''
        batch_ct = 0
        num_rows = 0
        start_time = time.time()
        dataloader = self.data.loader_train_without_mask

        self.model.eval()           # Set Model to eval mode
        # Per-batch chunks are moved to the CPU and joined once at the end; growing a
        # single tensor with concat on every batch copies all previous slots each time.
        slot_chunks = []

        with torch.no_grad():
            for batch in dataloader:
                # Copying data to the model's device
                batch = batch.to(self.device)

                # Forward Propagation
                slots = self.model(batch, slot_only = True)

                # Collecting slots, one row per slot
                chunk = slots.reshape(-1, self.param.dim_slot).to("cpu")
                slot_chunks.append(chunk)
                num_rows += chunk.shape[0]

                # Update batch count
                batch_ct += 1

                if log and batch_ct%125 == 0:
                    print(f"Computed Slot for Batch {batch_ct}\tNumSlot: {torch.Size((num_rows, self.param.dim_slot))}\tTime: {time.time()-start_time}")

        # K-Mean Cluster
        slot_arr = torch.cat(slot_chunks).numpy()
        kmean = KMeans(n_clusters=self.param.num_slot, n_init='auto', random_state=random_seed)
        kmean.fit(slot_arr)
        label = kmean.labels_
        np.savez_compressed(self.slot_addr, slot=slot_arr, label=label)

    def _sample_slot(self, slot_dict, slot_arr, batch_size):
        '''
        Randomly Samples slot from each cluster of kmean

        slot_dict maps a cluster label to the row indices of slot_arr in that cluster.
        Returns an array of shape (batch_size, number of clusters, slot dimension).
        '''
        batch_slots = []
        for _ in range(batch_size):
            picked = [slot_arr[np.random.choice(slot_dict[val])] for val in slot_dict]
            batch_slots.append(np.stack(picked))
        return np.stack(batch_slots)

    def slot_lib(self, overwrite = False, log=True):
        '''
        Creates a slot library and generates images from it

        The library (slots.npz) is rebuilt when missing or when overwrite=True.
        One image is generated per validation sample by decoding one randomly
        chosen slot from every cluster; images are saved as gen_img/<index>.png.
        '''
        # Checking if trained model of kmean already present
        if overwrite or not os.path.exists(self.slot_addr):
            self._kmean_slot(log=log)
        with np.load(self.slot_addr) as saved_arr:
            slot_arr = saved_arr['slot']
            label = saved_arr['label']
        print("Loaded slot array and labels")

        # Storing slots labelwise
        slot_dict = {val: np.argwhere(label == val).flatten() for val in np.unique(label)}

        # Generating images via Random Slot Sampling
        self.model.eval()
        with torch.no_grad():
            num_generate = len(self.data.dataset_val_without_mask)
            ct_img = 0
            for i in range(0, num_generate, self.param.batch_size):
                batch_size = min(num_generate-i, self.param.batch_size)     # last batch may be smaller
                generated_slots = torch.tensor(self._sample_slot(slot_dict, slot_arr, batch_size), dtype=torch.float32, device=self.device)
                generated_img, _, _, _ = self.model.decoder(generated_slots)
                for img in generated_img:
                    torchvision.utils.save_image(img, os.path.join(self.img_addr, f'{ct_img}.png'))
                    ct_img += 1

    def visualize(self):
        '''
        Saves, for every image of the first training batch, the original, the
        reconstruction, each slot's image and each slot's image restricted to the
        pixels where that slot has the highest mask value, under visualize/<index>/.
        '''
        batch = next(iter(self.data.loader_train_without_mask), None)
        if batch is None:       # empty loader: nothing to visualise
            return

        self.model.eval()
        with torch.no_grad():
            recon_img, img, mask, slots = self.model(batch.to(self.device))
            # Hard assignment: 1 where the slot wins the pixel, 0 elsewhere
            mask_oh = torch.nn.functional.one_hot(torch.argmax(mask, dim=1), self.param.num_slot).permute(0, 3, 1, 2)

        for i in range(batch.shape[0]):
            img_path = os.path.join(self.vis_addr, f'{i}')
            os.makedirs(img_path, exist_ok=True)
            torchvision.utils.save_image(batch[i], os.path.join(img_path, 'orig.png'))
            torchvision.utils.save_image(recon_img[i], os.path.join(img_path, 'regen.png'))
            for slot_num in range(img.shape[1]):
                torchvision.utils.save_image(img[i, slot_num], os.path.join(img_path, f'img{slot_num}.png'))
                torchvision.utils.save_image(mask_oh[i, slot_num]*img[i, slot_num], os.path.join(img_path, f'threshold_img{slot_num}.png'))

In [None]:
# End-to-end run: build the model, train (resuming from saved epochs if present), then evaluate.
model = ObjectDiscovery(param, device=device).to(device)
learner = LearnModel(model, addr.slot_attention)
# Write the hyperparameters to param.txt next to the results.
learner.param.create_report(learner.model_addr)
learner.train()
# Save the train/val loss curve.
learner.plot_loss()
# Reload the weights of the epoch with the lowest validation loss before evaluating.
learner.best_model()
# ARI on the validation loader (the default); the result is written to ARI.txt.
learner.ARI_score()
# Cluster training slots with k-means and generate images from sampled slots.
learner.slot_lib()
# Save the original, reconstruction and per-slot images for the first training batch.
learner.visualize()
