In [17]:
import clip
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet101, ResNet
from clip.model import CLIP, ModifiedResNet, AttentionPool2d

b, k, c, h, w = 4, 8, 1024, 14, 14
clip_model, clip_preproess = clip.load('RN50', device=torch.device('cpu'))

In [2]:
a = torch.randn(b, k, h, w)
b = torch.randn(b, c, h, w)

In [10]:
clip_model.visual(torch.randn(1,3,224,224)).shape

torch.Size([1, 1024])

In [12]:
class CLIPSpatial(nn.Module):
    def __init__(self, clip_model: CLIP):
        super().__init__()
        assert type(clip_model.visual) == ModifiedResNet
        assert type(clip_model.visual.attnpool) == AttentionPool2d
        self.visual_backbone = copy.deepcopy(clip_model.visual)
        self.visual_backbone.attnpool = nn.Identity()
        self.attnpool = copy.deepcopy(clip_model.visual.attnpool)

    def forward(self, x: torch.Tensor):
        x = self.visual_backbone(x)
        b, c, h, w = x.size()  # [BS, 2048, 7, 7]
        x = x.permute(0, 2, 3, 1)  # [BS, 7, 7, 2048]
        x = x.reshape(-1, c)  # [b*7*7, 2048]
        x = x[..., None, None]  # [b*7*7, 2048, 1, 1]
        x = x.expand(-1, -1, 7, 7)  # [b*7*7, 2048, 7, 7]
        x = self.attnpool(x)  # [b*7*7, 1024]
        x = x.reshape(b, h, w, -1)  # [b, 7, 7, 1024]
        x = x.permute(0, 3, 1, 2)  # [b, 1024, 7, 7]
        return x

In [31]:
class PartCEMClip(nn.Module):
    def __init__(self, backbone: nn.Module, prototypes: torch.Tensor, *, num_parts=8, num_classes=200, dropout=0.1) -> None:
        super().__init__()
        self.k = num_parts + 1
        self.backbone = backbone
        self.dim = 1024
        self.prototypes = nn.Parameter(prototypes.unsqueeze(0)) # shape: [k, dim]
        self.modulations = torch.nn.Parameter(torch.ones((1, self.k, self.dim)))

        self.softmax2d = nn.Softmax2d()
        self.dropout = nn.Dropout1d(p=dropout)
        self.class_fc = nn.Linear(self.dim, num_classes)
    
    def forward(self, x):
        # Pretrained ResNet part of the model
        x = self.backbone(x)

        b, c, h, w = x.shape
        h, w = h*2, w*2
        x = torch.nn.functional.interpolate(x, size=(h, w), mode='bilinear') # shape: [b, 2048, h, w], e.g. h=w=14

        x_flat = x.view(b, c, h*w).permute(0, 2, 1) # shape: [b,h*w,c]
        x_flat_norm = x_flat / F.normalize(x_flat, p=2, dim=-1) # shape: [b,h*w,c]
        proto_norm = self.prototypes / F.normalize(self.prototypes, p=2, dim=-1) # shape: [1,k,c]
        print(x_flat_norm.shape, proto_norm.shape)
        maps = torch.einsum('bnc,bkc->bnk', x_flat_norm, proto_norm.expand(b, -1, -1)) # shape: [b,h*w,k]
        maps = maps.permute(0, 2, 1).reshape(b, -1, h, w) # shape: [b,k,h,w]
        maps = self.softmax2d(maps) # shape: [b,k,h,w]

        parts = torch.einsum('bkhw,bchw->bkchw', maps, x).mean((-1,-2)) # shape: [b,k,h,w], [b,c,h,w] -> [b,k,c]
        parts_modulated = parts * self.modulations # shape: [b,k,c]
        parts_modulated_dropped = self.dropout(parts_modulated) # shape: [b,k,c]
        class_logits = self.class_fc(parts_modulated_dropped) # shape: [b,k,|y|]

        return parts, maps, class_logits

In [32]:
texts = ['back', 'beak', 'belly', 'breast', 'leg', 'tail', 'wing', 'throat', 'background']
with torch.no_grad():
    texts_tokenized = clip.tokenize(texts)
    texts_encoded = clip_model.encode_text(texts_tokenized)
clip_spatial = CLIPSpatial(clip_model)
model = PartCEMClip(backbone=clip_spatial, prototypes=texts_encoded.to(torch.float32))

In [34]:
parts, maps, class_logits = model(torch.randn(4,3,224,224))
parts.shape, maps.shape, class_logits.shape

torch.Size([4, 196, 1024]) torch.Size([1, 9, 1024])


(torch.Size([4, 9, 1024]), torch.Size([4, 9, 14, 14]), torch.Size([4, 9, 200]))