In [1]:
import os
import clip
import torch
import argparse
import numpy as np
import pandas as pd
from torch import nn
from pprint import pprint
from torchvision import transforms as T
from datasets.cub_dataset import CUBDataset

In [2]:
cos = nn.CosineSimilarity()
a = torch.randn(100, 10)
b = torch.randn(10)
res = cos(b, a)
res.size()

torch.Size([100])

# Test implementation of loss function

Generated concept embeddings

In [3]:
torch.manual_seed(42)
np.random.seed(42)
embedder, preprocess = clip.load('ViT-B/32', device='cpu')
raw_concepts = open("data/LM4CV/cub_attributes.txt", 'r').read().strip().split("\n")

full_concept_emb = []   # Matrix T
batch_size = 32

prompt_prefix = 'The bird has '
num_batches = len(raw_concepts) // batch_size + 1
for i in range(num_batches):
    batch_concepts = raw_concepts[i * batch_size: (i + 1) * batch_size]
    batch_concept_emb = clip.tokenize([prompt_prefix + attr for attr in batch_concepts])
    full_concept_emb.append(embedder.encode_text(batch_concept_emb).detach().cpu())

full_concept_emb = torch.concat(full_concept_emb).float()
full_concept_emb = full_concept_emb / full_concept_emb.norm(dim=-1, keepdim=True)   # Matrix T

Init some random values

In [4]:
D = 512
K = 100
N = 1000
T = torch.randn(N, D)
E = torch.randn(K, D)

New Implementation

In [5]:
def _mean_squared_mahalanobis(x: torch.Tensor, mu: torch.Tensor, sigma_inv):
    '''Computes the mean of squared mahalanobis distances from a vector or a set of vectors to the distribution
    with mean my and and inverse covariant matrix sigma_inv.
    Implementation from https://github.com/wangyu-ustc/LM4CV/blob/main/utils/train_utils.py#L263
    
    Args:
        x (Tensor[M, D]) or (Tensor[D]): a vector or a set of vector of length D.
        distribution (Tensor[N, D]) a matrix of N vectors of length D
    
    Returns:
        Tensor[]: a scaler tensor, which is the mahalanobis distance from vec to the distribution.
    '''
    delta = x - mu.unsqueeze(0)
    return torch.diag(delta @ sigma_inv @ delta.T).mean()


class Stage1Criterion(nn.Module):
    def __init__(self, regularization=True, division_power=3) -> None:
        super().__init__()
        self.xe = nn.CrossEntropyLoss()
        self.regularization = regularization
        self.division_power = division_power
    
    def forward(self, outputs: torch.Tensor, targets, weights, full_concept_emb):
        # xe_loss = self.xe(outputs, targets)
        # if not self.regularization:
        #     return xe_loss

        # Original implementation from https://github.com/wangyu-ustc/LM4CV/blob/main/utils/train_utils.py#L208
        # which is different to the one described in the paper.
        weights_norm = torch.linalg.norm(weights, dim=-1, keepdim=True)
        mu = torch.mean(full_concept_emb, dim=0)
        sigma_inv = torch.tensor(np.linalg.inv(torch.cov(full_concept_emb.T)))    # Using torch.inverse will have different result
        # Alternate implementation: sigma_inv = torch.inverse(torch.cov(distribution.T))

        mean_distance = torch.stack([_mean_squared_mahalanobis(embed, mu, sigma_inv)
                                     for embed
                                     in full_concept_emb]).mean().to(outputs.device)

        mahalanobis_loss = _mean_squared_mahalanobis(weights / weights_norm, mu, sigma_inv)
        mahalanobis_loss_scaled = (mahalanobis_loss - mean_distance) / (mean_distance ** self.division_power)

        return torch.abs(mahalanobis_loss_scaled)

In [6]:
loss_layer = Stage1Criterion()
loss_layer(torch.tensor([]), torch.tensor([]), E, full_concept_emb)

tensor(0.0946)

Implementation from original paper

In [7]:
attribute_embeddings = full_concept_emb

def mahalanobis_distance(x, mu, sigma_inv):
    x = x - mu.unsqueeze(0)
    return torch.diag(x @ sigma_inv @ x.T).mean()

model = [E]

mu = torch.mean(attribute_embeddings, dim=0)
sigma_inv = torch.tensor(np.linalg.inv(torch.cov(attribute_embeddings.T)))
configs = {
    'mu': mu,
    'sigma_inv': sigma_inv,
    'mean_distance': np.mean([mahalanobis_distance(embed, mu, sigma_inv) for embed in attribute_embeddings])
}

mahalanobis_loss = (mahalanobis_distance(E/E.norm(dim=-1, keepdim=True), configs['mu'],
                    configs['sigma_inv']) - configs['mean_distance']) / (configs['mean_distance']**3)
torch.abs(mahalanobis_loss)

tensor(0.0946)

In [12]:
torch.nn.Linear(100,200).weight.data.dtype

torch.float32

In [18]:
torch.load('data/CUB_200_2011/train_images_encoded.pt').dtype

torch.float32

In [8]:
torch.load('data/CUB_200_2011/train_filename2idx.pt')

['001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0074_59.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0014_89.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0031_100.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0051_796103.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0010_796097.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0023_796059.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0040_796066.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0089_796069.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0067_170.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0060_796076.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0056_796078.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0080_796096.jpg',
 '001.Black_footed_Albatross/Black_Footed_Albatross_0047_796064.jpg',
 '001.Black_footed_Albatross/Black_Foo

In [25]:
cos = torch.nn.CosineSimilarity(dim=-1)

In [27]:
cos(torch.ones(512), torch.ones(1105, 512))

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000])

In [22]:
torch.ones(10).size(), torch.ones(100, 10).size()

(torch.Size([10]), torch.Size([100, 10]))