In [5]:
import torch
from torchvision.models import resnet101, ResNet101_Weights

backbone = resnet101(weights=ResNet101_Weights.DEFAULT)
net = torch.nn.Sequential(*list(backbone.children())[:-2])
net(torch.randn(4, 3, 224, 224)).shape


torch.Size([4, 2048, 7, 7])

In [2]:
with torch.no_grad():
    res = net(torch.randn(4, 3, 448, 448))
for k, v in res.items():
    print(k, v.shape)

attr_scores torch.Size([4, 312])
class_scores torch.Size([4, 200])
attn_maps torch.Size([4, 9, 28, 28])
part_features torch.Size([4, 3072, 9])


In [55]:
import torch.nn.functional as F

def pres_loss(maps: torch.Tensor):
    loss_pres = F.avg_pool2d(maps[:, :, 2:-2, 2:-2], 3, stride=1).max(-1)[0].max(-1)[0].max(0)[0].mean()
    loss_pres = (1 - loss_pres)
    return loss_pres

def pres_loss1(attn_maps: torch.Tensor):
    maps_pooled = F.avg_pool2d(attn_maps[:, :, 2:-2, 2:-2], 3, stride=1)
    mean_presence = torch.amax(maps_pooled, dim=(0, 2, 3)).mean()
    return 1 - mean_presence

maps = torch.randn(4, 10, 14, 14)

print(pres_loss((torch.clone((maps)))))
print(pres_loss1(torch.clone((maps))))

tensor(0.1175)
tensor(0.1175)


In [70]:
def orth_loss(num_parts: int, landmark_features: torch.Tensor, device) -> torch.Tensor:
    normed_feature = F.normalize(landmark_features, dim=1)
    similarity = torch.matmul(normed_feature.permute(0, 2, 1), normed_feature)
    similarity = torch.sub(similarity, torch.eye(num_parts + 1).to(device))
    loss_orth = torch.mean(torch.square(similarity))
    return loss_orth

def orthogonality_loss(num_parts: int, features: torch.Tensor) -> torch.Tensor:
    features = F.normalize(features, dim=1)
    similarities = features.permute(0, 2, 1) @ features
    similarities = similarities - torch.eye(num_parts + 1).to(features.device)
    return torch.mean(torch.square(similarities))

p, feats = 10, torch.ones(4, 1024, 11)
device = torch.device('cpu')
print(orth_loss(p, feats, device))
print(orth_loss1(p, feats, device))

tensor(0.9091)
tensor(0.9091)


In [128]:
def landmark_coordinates(maps: torch.Tensor, device: torch.device) -> \
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
                                    torch.arange(maps.shape[3]))
    grid_x = grid_x.unsqueeze(0).unsqueeze(0).to(device)
    grid_y = grid_y.unsqueeze(0).unsqueeze(0).to(device)

    map_sums = maps.sum(3).sum(2).detach()
    maps_x = grid_x * maps
    maps_y = grid_y * maps
    loc_x = maps_x.sum(3).sum(2) / map_sums
    loc_y = maps_y.sum(3).sum(2) / map_sums
    return loc_x, loc_y, grid_x, grid_y


def conc_loss(centroid_x: torch.Tensor,
              centroid_y: torch.Tensor,
              grid_x: torch.Tensor,
              grid_y: torch.Tensor,
              maps: torch.Tensor) -> torch.Tensor:
    spatial_var_x = ((centroid_x.unsqueeze(-1).unsqueeze(-1) - grid_x) / grid_x.shape[-1]) ** 2
    spatial_var_y = ((centroid_y.unsqueeze(-1).unsqueeze(-1) - grid_y) / grid_y.shape[-2]) ** 2
    spatial_var_weighted = (spatial_var_x + spatial_var_y) * maps
    loss_conc = spatial_var_weighted[:, 0:-1, :, :].mean()
    return loss_conc


def landmark_coordinates1(attn_maps: torch.Tensor):
    b, k, h, w = attn_maps.shape
    grid_x, grid_y = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
    
    grid_x = grid_x[None, None, ...].to(attn_maps.device)
    grid_y = grid_y[None, None, ...].to(attn_maps.device)

    attn_map_sums = attn_maps.sum((-1, -2)).detach()
    cx = torch.sum(grid_x * attn_maps, dim=(-1, -2)) / attn_map_sums
    cy = torch.sum(grid_y * attn_maps, dim=(-1, -2)) / attn_map_sums
    return cx, cy, grid_x, grid_y


def l_concentration(attn_maps: torch.Tensor) -> torch.Tensor:
    b, k, h, w = attn_maps.shape
    grid_x, grid_y = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
    
    grid_x = grid_x[None, None, ...].to(attn_maps.device)
    grid_y = grid_y[None, None, ...].to(attn_maps.device)

    attn_map_sums = attn_maps.sum((-1, -2)).detach()
    cx = torch.sum(grid_x * attn_maps, dim=(-1, -2)) / attn_map_sums
    cy = torch.sum(grid_y * attn_maps, dim=(-1, -2)) / attn_map_sums
    
    spatial_var_x = ((cx[..., None, None] - grid_x) / w) ** 2
    spatial_var_y = ((cy[..., None, None] - grid_y) / h) ** 2
    spatial_var_weighted = (spatial_var_x + spatial_var_y) * attn_maps
    return spatial_var_weighted[:, 0:-1, :, :].mean()

dummy_maps = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 1, 1 ,0],
    [0, 0, 0, 0, 0, 0, 1, 1, 1 ,1],
    [0, 0, 0, 0, 0, 0, 0, 1, 1 ,0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
], dtype=torch.float32)

dummy_maps1 = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 1, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0 ,0],
], dtype=torch.float32)

dummy_maps2 = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0 ,0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0 ,0],
], dtype=torch.float32)

dmp2 = dummy_maps[None, None, ...].expand(4, 6, -1, -1)
a, b, c, d = landmark_coordinates(dmp2, device='cpu')
print(conc_loss(a, b, c, d, dmp2))
print(conc_loss1(dmp2))

tensor(0.0018)
tensor(0.0018)


In [108]:
c

tensor([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]])

In [71]:
h, w = 14, 14
grid_x, grid_y = torch.meshgrid(torch.arange(h),torch.arange(w))
grid_x

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7],
        [ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8],
        [ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
        [11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
        [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],
        [13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13]])

In [73]:
h, w = 14, 14
grid_w, grid_h = torch.meshgrid(torch.arange(h),torch.arange(w), indexing='xy')
grid_w

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]])

In [57]:
a = torch.randn(4, 1024, 10)
(a.permute(0, 2, 1) @ a).shape

torch.Size([4, 10, 10])

In [13]:
backbone.layer4[0].conv1.in_channels + backbone.fc.in_features

3072

In [19]:
a = torch.nn.Conv2d(1024 + 2048, 8 + 1, 1, bias=False)

In [21]:
torch.rand(1, 3, 224, 224).mean((-1,-2))

tensor([[0.4991, 0.4982, 0.5027]])

In [20]:
b, h, w = 4, 14, 14
torch.equal(a.weight.sum(1).unsqueeze(1).expand(-1, b, h, w).permute(1, 0, 2, 3), a.weight.sum(1).unsqueeze(0).expand(b, -1, h, w))

True

In [17]:
a.weight.sum(1).unsqueeze(0).expand(b, -1, h, w).shape

torch.Size([4, 9, 14, 14])

In [9]:
for k, v in result.items():
    print(k, v.shape)

layer3 torch.Size([1, 1024, 14, 14])
layer4 torch.Size([1, 2048, 7, 7])


In [7]:
backbone.avgpool

AdaptiveAvgPool2d(output_size=(1, 1))

In [3]:
import timm

model = timm.create_model('resnet101', pretrained=True)
model.bn1

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [5]:
from omegaconf import OmegaConf
import torch.nn.functional as F
import torch

cfg = OmegaConf.load('configs/apn_CUB.yaml')
list(name.lower() for name in cfg.MODEL.LOSSES.keys())

['l_cls', 'l_reg', 'l_cpt']

In [31]:
from scipy import stats
def update_prior_dist(batch_size, alpha, beta):
    grid_points = torch.arange(1., 2*batch_size, 2.).float() / (2*batch_size)
    print(grid_points)
    grid_points_np = grid_points.cpu().numpy()
    grid_points_icdf = stats.beta.ppf(grid_points_np, a=alpha, b=beta)
    print(grid_points_icdf)
    prior_dist = torch.tensor(grid_points_icdf).float().unsqueeze(1)
    return prior_dist

In [32]:
prior = update_prior_dist(32, 1, 1e-3)

tensor([0.0156, 0.0469, 0.0781, 0.1094, 0.1406, 0.1719, 0.2031, 0.2344, 0.2656,
        0.2969, 0.3281, 0.3594, 0.3906, 0.4219, 0.4531, 0.4844, 0.5156, 0.5469,
        0.5781, 0.6094, 0.6406, 0.6719, 0.7031, 0.7344, 0.7656, 0.7969, 0.8281,
        0.8594, 0.8906, 0.9219, 0.9531, 0.9844])
[0.99999986 1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.        ]


In [27]:
import numpy as np
sum(prior)

tensor([32.])

In [23]:
prior.shape

AttributeError: 'torch.Size' object has no attribute 'shape'

In [12]:
import torch
b = 16
grid = torch.arange(1., 2*b, 2.).numpy() / 2 * b

In [13]:


stats.beta.ppf(grid, a=1, b=1e-3)

array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan])

In [8]:
x = torch.randn(4, 2048, 7,7)
F.conv2d(x, torch.randn(100, 2048, 1, 1)).shape

torch.Size([4, 100, 7, 7])

In [1]:
import torch
from omegaconf import OmegaConf
import  torch.nn.functional as f
from torchvision.models import resnet101, ResNet101_Weights

from apn import BackBone

In [8]:
import timm
model = timm.create_model('resnet101', pretrained=True)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [6]:
import os
from pathlib import Path
os.path.basename('configs/resnet101_ft_CUB.yaml')

'backbone_ft_CUB.yaml'

In [7]:
Path('configs/resnet101_ft_CUB.yaml').stem

'backbone_ft_CUB'

In [3]:
backbone = BackBone('resnet101', num_classes=200)

In [4]:
set(backbone.state_dict().keys()) == set(resnet101(weights=ResNet101_Weights.DEFAULT).state_dict().keys())

True

In [26]:
conf = OmegaConf.load('configs/resnet101_ft_CUB.yaml')

In [30]:
type(conf.MODEL.LR)

float

In [18]:
model = resnet101(weights=ResNet101_Weights.DEFAULT)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [19]:
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=200)

In [20]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 