In [124]:
import os
import math
import numpy as np
import torch
from torch import nn
import pandas as pd
from timm import create_model, list_models
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [85]:
resnet = create_model('resnet18', pretrained=True)
conv1x1 = nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))

In [197]:
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list):
        x = tensor_list
        b, c, h, w = x.shape
        mask = torch.zeros((b, h, w), dtype=torch.bool, device=x.device)
        not_mask = ~mask
        # y_embed = not_mask.cumsum(1, dtype=torch.float32)
        # x_embed = not_mask.cumsum(2, dtype=torch.float32)

        x_embed = torch.arange(w).unsqueeze(0).repeat(h, 1).unsqueeze(0).repeat(b, 1, 1)
        y_embed = torch.arange(w).unsqueeze(1).repeat(1, h).unsqueeze(0).repeat(b, 1, 1)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t


        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos.to(x.dtype)

In [122]:
def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)

class PositionalEncoding2D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros(
            (x, y, self.channels * 2),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)
        return self.cached_penc

In [258]:
torch.randn(1, 1, 128).expand(-1,10,-1).size()

torch.Size([1, 10, 128])

In [259]:
slots_mu = nn.Parameter(torch.randn(1, 1, 128))
slots_sigma = nn.Parameter(torch.abs(torch.randn(1, 1, 128)))
mu = slots_mu.expand(1, 10, -1)
sigma = slots_sigma.expand(1, 10, -1)
mu.shape, sigma.shape

(torch.Size([1, 10, 128]), torch.Size([1, 10, 128]))

In [261]:
initial_slots = nn.Parameter(torch.normal(mu, sigma))
initial_slots.shape

torch.Size([1, 10, 128])

In [267]:
class ScouterAttention(nn.Module):
    def __init__(self, dim, num_concept, iters=3, eps=1e-8, vis=False, power=1, to_k_layer=3):
        super().__init__()
        self.num_slots = num_concept
        self.iters = iters
        self.eps = eps
        self.scale = dim ** (-0.5)

        # random seed init
        slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        slots_sigma = nn.Parameter(torch.abs(torch.randn(1, 1, dim)))
        mu = slots_mu.expand(1, self.num_slots, -1)
        sigma = slots_sigma.expand(1, self.num_slots, -1)
        self.initial_slots = nn.Parameter(torch.normal(mu, sigma))

        # K layer init
        to_k_layer_list = [nn.Linear(dim, dim)]
        for to_k_layer_id in range(1, to_k_layer):
            to_k_layer_list.append(nn.ReLU(inplace=True))
            to_k_layer_list.append(nn.Linear(dim, dim))
        self.to_k = nn.Sequential(
            *to_k_layer_list
        )

        self.vis = vis
        self.power = power

    def forward(self, inputs_pe, inputs):
        b, n, d = inputs_pe.shape
        slots = self.initial_slots.expand(b, -1, -1)
        k, v = self.to_k(inputs_pe), inputs_pe
        print('slot.k.shape:', k.shape)
        for _ in range(self.iters):
            q = slots
            print('slot.q.shape:', q.shape)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            dots = torch.div(
                dots,
                (dots.sum(2)
                 .expand_as(dots.permute([2, 0, 1]))
                 .permute([1, 2, 0]))
                 * dots.sum(2).sum(1).expand_as(dots.permute([1, 2, 0]))
                    .permute([2, 0, 1]))
            attn = torch.sigmoid(dots)
            print('slot.dots.shape:', dots.shape)
            print('slot.attn.shape:', attn.shape)

            attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
            print('slot.attn2.shape:', attn2.shape)
            updates = torch.einsum('bjd,bij->bid', inputs, attn2)
            print('slot.updates.shape:', updates.shape)
            break
        return updates, attn


class MainModel(nn.Module):
    def __init__(self, args, vis=False):
        super(MainModel, self).__init__()
        self.args = args
        self.pre_train = args.pre_train
        if "18" not in args.base_model:
            self.num_features = 2048
        else:
            self.num_features = 512
        self.feature_size = args.feature_size
        self.drop_rate = 0.0
        hidden_dim = 128
        num_concepts = args.num_cpt
        num_classes = args.num_classes
        self.back_bone = create_model('resnet18', pretrained=True)
        self.activation = nn.Tanh()
        self.vis = vis

        if not self.pre_train:
            self.conv1x1 = nn.Conv2d(self.num_features, hidden_dim, kernel_size=(1, 1), stride=(1, 1))
            self.norm = nn.BatchNorm2d(hidden_dim)
            self.position_emb = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
            self.slots = ScouterAttention(hidden_dim, num_concepts, vis=self.vis)
            self.cls = torch.nn.Linear(num_concepts, num_classes)
        else:
            self.fc = nn.Linear(self.num_features, num_classes)
            self.drop_rate = 0

    def forward(self, x):
        x = self.back_bone.forward_features(x)
        features = x
        # x = x.view(x.size(0), self.num_features, self.feature_size, self.feature_size)

        if not self.pre_train:
            x = self.conv1x1(x)
            x = self.norm(x)
            x = torch.relu(x)
            pe = self.position_emb(x)
            x_pe = x + pe

            b, c, h ,w = x.shape
            x = x.reshape((b, c, -1)).permute((0, 2, 1)) # shape: b, c, h*w
            x_pe = x_pe.reshape((b, c, -1)).permute((0, 2, 1)) # shape: b, c, h*w
            print('x.shape:', x.shape, 'x_pe.shape:', x_pe.shape)

            updates, attn = self.slots(x_pe, x)
            if self.args.cpt_activation == "att":
                cpt_activation = attn
            else:
                cpt_activation = updates
            attn_cls = self.scale * torch.sum(cpt_activation, dim=-1)
            cpt = self.activation(attn_cls)
            cls = self.cls(cpt)
            return (cpt - 0.5) * 2, cls, attn, updates
        else:
            x = F.adaptive_max_pool2d(x, 1).squeeze(-1).squeeze(-1)
            if self.drop_rate > 0:
                x = F.dropout(x, p=self.drop_rate, training=self.training)
            x = self.fc(x)
            return x, features

In [270]:
class Arguments:
    def __init__(self) -> None:
        self.pre_train = False
        self.base_model = 'resnet18'
        self.feature_size = 7
        self.num_cpt = 15
        self.num_classes = 200
        self.cpt_activation = 'att'

args = Arguments()
model = MainModel(args)

In [271]:
a, b, c, d = model(torch.randn(32, 3, 224, 224))

x.shape: torch.Size([32, 49, 128]) x_pe.shape: torch.Size([32, 49, 128])
slot.k.shape: torch.Size([32, 49, 128])
slot.q.shape: torch.Size([32, 15, 128])
slot.dots.shape: torch.Size([32, 15, 49])
slot.attn.shape: torch.Size([32, 15, 49])
slot.attn2.shape: torch.Size([32, 15, 49])
slot.updates.shape: torch.Size([32, 15, 128])


In [279]:
dots = torch.rand([32, 15, 49])
a = dots.sum(2).expand_as(dots.permute([2, 0, 1])).permute([1, 2, 0]) * dots.sum(2).sum(1).expand_as(dots.permute([1, 2, 0])).permute([2, 0, 1])

In [281]:
a.size()

torch.Size([32, 15, 49])