In [None]:
# Imports
import os
import torch.cuda, torch.utils.data, torch.nn, torch
import torchvision.transforms, torchvision.datasets.folder
import numpy as np

In [None]:
# Addresses

class Address:
    '''
    Holds every filesystem location used by the project and offers small helpers
    for creating directories and emptying them.
    '''

    def __init__(self):
        '''
        Stores all the addresses used in project
        '''
        self.data = "../input/clevertex/dataset"
        self.img_train = os.path.join(self.data, "images/train")
        self.img_val = os.path.join(self.data, "images/val")
        self.mask_train = os.path.join(self.data, "masks/train")
        self.mask_val = os.path.join(self.data, "masks/val")

    def create_dir(self, dir_list = None):
        '''
        Creates every directory listed in dir_list that does not exist yet.
        Missing parent directories are created as well.
        If dir_list is None nothing is created.
        '''
        if dir_list is None:
            dir_list = []
        for address in dir_list:
            if not os.path.exists(address):
                # makedirs (instead of mkdir) so nested paths work on a fresh checkout
                os.makedirs(address, exist_ok=True)

    def _delete_folder_content(self, folder_addr):
        '''
        Recursively deletes all files below folder_addr.
        Sub-directories are emptied but left in place. A missing folder_addr is ignored.
        '''
        if not os.path.exists(folder_addr):
            return
        for entry in os.listdir(folder_addr):
            # os.listdir returns bare names, so the check must be made on the joined path;
            # testing the bare name would resolve it against the current working directory.
            entry_addr = os.path.join(folder_addr, entry)
            if os.path.isdir(entry_addr):
                self._delete_folder_content(entry_addr)
            else:
                os.remove(entry_addr)

    def clean(self, file_list = None):
        '''
        Deletes the content of every folder in file_list.
        If file_list is None nothing is deleted.
        '''
        if file_list is None:
            file_list = []
        for address in file_list:
            self._delete_folder_content(address)


# Shared Address instance; later cells read the dataset paths from it.
addr = Address()

In [None]:
class HyperParameters:
    def __init__(self, batch_size=2):
        '''
        Stores all Hyperparameters used for training of model

        batch_size: number of samples per batch used by the data loaders.
                    The default is 2, the value that was previously hard-coded
                    (the argument used to be ignored), so HyperParameters()
                    keeps producing the same configuration.
        '''
        # Training
        self.batch_size = batch_size
        self.resolution = (128, 128)    # (height, width) images are resized to

        # Encoder
        self.dim_input = 64             # channel width of the encoder features

        # Slot Attention
        self.dim_slot = 64              # size of each slot vector
        self.dim_projected = 64         # size of key / query / value projections
        self.dim_mlp_slot = 128         # hidden width of the slot residual MLP
        self.num_slot = 11              # number of slots
        self.num_iter_slot = 3          # attention iterations per forward pass
        self.epsilon = 1e-8             # added to attention weights before normalising

        # Decoder
        self.decoder_channel = 128      # channel width of the decoder CNN

# Shared hyperparameter set; later cells pass it to the data loaders and models.
param = HyperParameters()

In [None]:
# Random Seed and CUDA

# Fix the seed of every random number generator in use so runs are repeatable.
random_seed = 68
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prefer the GPU whenever one is visible; its generators are seeded as well.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.manual_seed_all(random_seed)
print(f"Working with device {device}")

In [None]:
# Dataset

class DataSet(torch.utils.data.Dataset):
    def __init__(self, resolution, address_img, address_mask = None):
        '''
        Dataset of images, optionally paired with segmentation masks.

        resolution: target size passed to torchvision Resize
        address_img: directory containing the images
        address_mask: directory containing the masks, or None for an image-only dataset

        Images and masks are paired by position in the sorted directory listings.
        NOTE(review): this assumes both directories hold matching files in the same
        sorted order - confirm against the dataset layout.
        '''
        self.address_img = address_img
        self.address_mask = address_mask
        self.img_list = sorted(os.listdir(self.address_img))
        if self.address_mask is not None:
            self.mask_list = sorted(os.listdir(self.address_mask))

        # Images: default (bilinear) resizing
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                         torchvision.transforms.Resize(resolution)])
        # Masks: nearest-neighbour resizing, so that no blended values are invented
        # along object borders as bilinear interpolation would do
        nearest = torchvision.transforms.InterpolationMode.NEAREST
        self.transform_mask = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                              torchvision.transforms.Resize(resolution, interpolation=nearest)])

    def __len__(self):
        '''
        Number of images in the dataset
        '''
        return len(self.img_list)

    def __getitem__(self, idx):
        '''
        Returns the transformed image as a float tensor, or a dict with keys
        'img' and 'mask' when a mask directory was given.
        '''
        img_addr = os.path.join(self.address_img, self.img_list[idx])
        img = torchvision.datasets.folder.default_loader(img_addr)

        if self.address_mask is not None:
            mask_addr = os.path.join(self.address_mask, self.mask_list[idx])
            mask = torchvision.datasets.folder.default_loader(mask_addr)

            return {
                'img': self.transform(img).to(torch.float),
                'mask': self.transform_mask(mask).to(torch.float)
            }

        return self.transform(img).to(torch.float)

class Data:
    def __init__(self, address: Address, param: HyperParameters, device = device):
        '''
        Builds the datasets and data loaders of the train and val splits,
        once without masks and once with masks.
        '''
        self.address = address
        self.device = device

        size = param.resolution
        batch = param.batch_size
        make_loader = torch.utils.data.DataLoader

        # Image-only datasets
        self.dataset_train_without_mask = DataSet(size, address.img_train)
        self.dataset_val_without_mask = DataSet(size, address.img_val)

        # Image-only loaders; only the train loader is shuffled
        self.loader_train_without_mask = make_loader(self.dataset_train_without_mask,
                                                     batch_size=batch,
                                                     shuffle=True)
        self.loader_val_without_mask = make_loader(self.dataset_val_without_mask,
                                                   batch_size=batch,
                                                   shuffle=False)

        # Image + mask datasets
        self.dataset_train_with_mask = DataSet(size, address.img_train, address.mask_train)
        self.dataset_val_with_mask = DataSet(size, address.img_val, address.mask_val)

        # Image + mask loaders; never shuffled, batches are assembled by self.collate
        self.loader_train_with_mask = make_loader(self.dataset_train_with_mask,
                                                  batch_size=batch,
                                                  shuffle=False,
                                                  collate_fn=self.collate)
        self.loader_val_with_mask = make_loader(self.dataset_val_with_mask,
                                                batch_size=batch,
                                                shuffle=False,
                                                collate_fn=self.collate)

    def collate(self, batch):
        '''
        Stacks the 'img' and the 'mask' entries of the samples in batch
        into one tensor each.
        '''
        return {key: torch.stack([sample[key] for sample in batch])
                for key in ('img', 'mask')}

# Builds all datasets and loaders; the dataset directories are listed at construction time.
data = Data(addr, param)

In [None]:
# Models

def create_grid(resolution):
    '''
    Builds a positional grid of shape (1, resolution[0], resolution[1], 4).
    The first two channels rise linearly from 0 to 1 along the first and the second
    spatial axis, the last two channels are their complements (1 - value).
    Returned as a float tensor that does not require gradients.
    '''
    rows = np.linspace(0, 1, num=resolution[0])
    cols = np.linspace(0, 1, num=resolution[1])

    # 'ij' indexing keeps the (rows, cols) layout of the resolution
    row_coord, col_coord = np.meshgrid(rows, cols, indexing='ij')
    ascending = np.stack((row_coord, col_coord), axis=-1)

    # Append the descending ramps and add a leading batch axis
    four_way = np.concatenate((ascending, 1 - ascending), axis=-1)[np.newaxis, ...]
    return torch.tensor(four_way, dtype=torch.float, requires_grad=False)

class Encoder(torch.nn.Module):
    def __init__(self, num_channel, device=device):
        '''
        Encodes input image into feature vectors

        num_channel: channel width of the CNN and size of every output feature vector
        device: stored on the module; the positional grid follows the device of the input
        '''
        super(Encoder, self).__init__()

        self.device = device

        # CNN backbone of Encoder.
        # PyTorch rejects padding='same' on strided convolutions, so the two stride-2
        # layers use the explicit padding 2 (= kernel_size // 2), which halves each
        # spatial size (rounding up).
        self.cnn = torch.nn.Sequential(torch.nn.Conv2d(3, num_channel, 5, stride=1, padding='same', dtype=torch.float),
                                       torch.nn.ReLU(),
                                       torch.nn.Conv2d(num_channel, num_channel, 5, stride=2, padding=2, dtype=torch.float),
                                       torch.nn.ReLU(),
                                       torch.nn.Conv2d(num_channel, num_channel, 5, stride=2, padding=2, dtype=torch.float),
                                       torch.nn.ReLU(),
                                       torch.nn.Conv2d(num_channel, num_channel, 5, stride=1, padding='same', dtype=torch.float),
                                       torch.nn.ReLU())

        # Positional Embedding.
        # The grid is created lazily in forward; it is kept out of the state_dict so the
        # saved keys do not depend on whether forward has already been called.
        self.register_buffer('grid', None, persistent=False)
        self.embed = torch.nn.Linear(4, num_channel, dtype=torch.float)

        # Flatten the two spatial axes into one
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=2)

        # Layer Norm
        self.layer_norm = torch.nn.LayerNorm(num_channel, dtype=torch.float)

        # MLP
        self.mlp = torch.nn.Sequential(torch.nn.Linear(num_channel, num_channel, dtype=torch.float),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(num_channel, num_channel, dtype=torch.float))

    def forward(self, x: torch.Tensor):
        '''
        x: batch of images in (batch, 3, height, width) layout
        Returns features of shape (batch, height' * width', num_channel), where
        height' and width' are the spatial sizes after the CNN.
        '''
        x = self.cnn(x)     # CNN, output is (batch, channel, height', width')

        x = x.permute((0, 2, 3, 1))     # Permuting Channel axis at the end

        # (Re)build the grid when it is missing or does not fit the current feature map
        spatial = x.shape[1:3]
        if self.grid is None or self.grid.shape[1:3] != spatial or self.grid.device != x.device:
            self.grid = create_grid(spatial).to(x.device)
        x = x + self.embed(self.grid)   # Positional Embedding

        x = self.flatten(x)     # Flatten into feature vector

        x = self.layer_norm(x)  # Layer Normalization

        x = self.mlp(x)     # MLP

        return x

class SlotAttention(torch.nn.Module):
    def __init__(self, dim_input, dim_slot, dim_projected, dim_mlp, num_slot, num_iter, epsilon, device=device):
        '''
        Implements Slot Attention

        dim_input: size of each input feature vector
        dim_slot: size of each slot vector
        dim_projected: size of the key / query / value projections
        dim_mlp: hidden width of the residual MLP applied to the slots
        num_slot: number of slots
        num_iter: number of attention iterations
        epsilon: constant added to the attention weights before normalisation
        device: stored on the module; tensors created in forward follow the input
        '''
        super(SlotAttention, self).__init__()

        self.device = device
        self.dim_input = dim_input
        self.dim_slot = dim_slot
        self.dim_projected = dim_projected
        self.dim_mlp = dim_mlp
        self.num_slot = num_slot
        self.num_iter = num_iter
        self.epsilon = epsilon

        # Layer Norm
        self.layer_norm_inp = torch.nn.LayerNorm(dim_input, dtype=torch.float)
        self.layer_norm_slot = torch.nn.LayerNorm(dim_slot, dtype=torch.float)

        # Slot Initialization Parameters
        self.mean_slot = torch.nn.Parameter(torch.zeros(1, 1, self.dim_slot, dtype=torch.float))
        self.log_std_slot = torch.nn.Parameter(torch.zeros(1, 1, self.dim_slot, dtype=torch.float))
        torch.nn.init.xavier_uniform_(self.mean_slot)
        torch.nn.init.xavier_uniform_(self.log_std_slot)

        # Projection matrix
        self.project_key = torch.nn.Linear(self.dim_input, self.dim_projected, bias=False, dtype=torch.float)
        self.project_value = torch.nn.Linear(self.dim_input, self.dim_projected, bias=False, dtype=torch.float)
        self.project_query = torch.nn.Linear(self.dim_slot, self.dim_projected, bias=False, dtype=torch.float)

        # Slot update
        self.gru = torch.nn.GRUCell(self.dim_projected, self.dim_slot)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(self.dim_slot, self.dim_mlp),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(self.dim_mlp, self.dim_slot))

    def forward(self, x):
        '''
        x: input features of shape (batch, num_input, dim_input)
        Returns the slots, shape (batch, num_slot, dim_slot).
        '''
        x = self.layer_norm_inp(x)          # Layer Normalization

        # Initializing slot. The noise is created on the device (and with the dtype) of the
        # input; a plain torch.randn would live on the CPU and clash with CUDA parameters.
        noise = torch.randn((x.shape[0], self.num_slot, self.dim_slot), device=x.device, dtype=x.dtype)
        slots = self.mean_slot + torch.exp(self.log_std_slot)*noise

        for t in range(self.num_iter):
            slots_prev = slots
            slots = self.layer_norm_slot(slots)         # Layer Normalization

            # Computing Attention
            attn = torch.matmul(self.project_key(x), self.project_query(slots).permute(0, 2, 1))   # Dot product of key and query
            attn = attn / np.sqrt(self.dim_projected)                                              # Setting SoftMax temperature
            attn = torch.nn.functional.softmax(attn, dim=-1) + self.epsilon                        # Softmax over slots, with numerical stability

            # Updating Slot
            weights = attn/torch.sum(attn, dim=1, keepdim=True)                         # Normalising over the inputs to get update weights
            updates = torch.bmm(weights.permute(0, 2, 1), self.project_value(x))        # Update array
            slots = self.gru(updates.reshape(-1, self.dim_projected), slots_prev.reshape(-1, self.dim_slot))    # GRU update
            slots = slots.reshape(-1, self.num_slot, self.dim_slot)
            # Residual MLP update. Must not be done in place: the first Linear of the MLP
            # keeps its input for the backward pass, and overwriting it makes backward() fail.
            slots = slots + self.mlp(slots)

        return slots

class SBDecoder(torch.nn.Module):
    def __init__(self, num_slot, dim_slot, resolution, num_channel, device=device):
        '''
        Spatial Broadcast Decoder: every slot is broadcast over the full resolution,
        decoded by a CNN into RGB + mask logit, and the per-slot images are blended
        with the softmax of the masks.
        '''
        super().__init__()

        self.device = device
        self.num_slot = num_slot
        self.dim_slot = dim_slot
        self.resolution = resolution
        self.num_channel = num_channel

        # Two coordinate channels, rescaled from [0, 1] to [-1, 1], in channel-first layout
        coords = create_grid(resolution).permute(0, 3, 1, 2)[:, :2, :, :]
        self.register_buffer('grid', 2*coords - 1)

        # CNN Decoder: the three convolutions share every setting but their channel counts
        conv_setup = dict(kernel_size=5, stride=1, padding='same', dtype=torch.float)
        self.cnn = torch.nn.Sequential(torch.nn.Conv2d(2+dim_slot, num_channel, **conv_setup),
                                       torch.nn.ReLU(),
                                       torch.nn.Conv2d(num_channel, num_channel, **conv_setup),
                                       torch.nn.ReLU(),
                                       torch.nn.Conv2d(num_channel, 4, **conv_setup))

    def forward(self, slots):
        '''
        Returns (reconstructed image, per-slot images, per-slot masks, slots).
        '''
        height = self.resolution[0]
        width = self.resolution[1]

        # One latent per slot, broadcast over every pixel
        latents = slots.reshape((-1, self.dim_slot, 1, 1)).expand(-1, -1, height, width)

        # Attach the coordinate channels to every broadcast latent
        coords = self.grid.expand(latents.shape[0], -1, -1, -1)
        decoded = self.cnn(torch.cat([latents, coords], dim=1))

        # Back to (batch, slot, channel, height, width), then split RGB from mask logit
        decoded = decoded.reshape(-1, self.num_slot, 4, height, width)
        img, mask = torch.split(decoded, [3, 1], dim=2)

        # Masks compete across slots; the reconstruction is the mask-weighted sum
        mask = torch.nn.functional.softmax(mask, dim=1)
        recon_img = (img*mask).sum(dim=1)

        return recon_img, img, mask, slots

class ObjectDiscovery(torch.nn.Module):
    def __init__(self, param: HyperParameters, device=device):
        '''
        Object discovery model: encoder, followed by slot attention, followed by
        the spatial broadcast decoder. All sizes are read from param.
        '''
        super().__init__()

        self.device = device

        # Image -> feature vectors
        self.encoder = Encoder(num_channel=param.dim_input, device=device)

        # Feature vectors -> slots
        self.slot_attention = SlotAttention(dim_input=param.dim_input,
                                            dim_slot=param.dim_slot,
                                            dim_projected=param.dim_projected,
                                            dim_mlp=param.dim_mlp_slot,
                                            num_slot=param.num_slot,
                                            num_iter=param.num_iter_slot,
                                            epsilon=param.epsilon,
                                            device=device)

        # Slots -> reconstruction
        self.decoder = SBDecoder(num_slot=param.num_slot,
                                 dim_slot=param.dim_slot,
                                 resolution=param.resolution,
                                 num_channel=param.decoder_channel,
                                 device=device)

    def forward(self, x):
        '''
        Runs the three stages in sequence and returns the output of the decoder.
        '''
        features = self.encoder(x)
        slots = self.slot_attention(features)
        return self.decoder(slots)


In [None]:
# Learning Model

# Placeholder class: the visible code defines no behavior for it yet.
class LearnModel:
    def __init__(self):
        # Nothing is initialised here
        pass