In [1]:
import os
import clip
import torch
import argparse
import numpy as np
import pandas as pd
from torch import nn
from torchvision import transforms as T

from models.lm4cv import mean_mahalanobis_distance

In [2]:
D = 64
K = 100
N = 1000
T = torch.randn(N, D)
E = torch.randn(K, D)

In [3]:
embedder, preprocess = clip.load('ViT-B/32', device='cpu')
raw_concepts = open("data/LM4CV/cub_attributes.txt", 'r').read().strip().split("\n")

full_concept_emb = []   # Matrix T
batch_size = 32

prompt_prefix = 'The bird has '
num_batches = len(raw_concepts) // batch_size + 1
for i in range(num_batches):
    batch_concepts = raw_concepts[i * batch_size: (i + 1) * batch_size]
    batch_concept_emb = clip.tokenize([prompt_prefix + attr for attr in batch_concepts])
    full_concept_emb.append(embedder.encode_text(batch_concept_emb).detach().cpu())

full_concept_emb = torch.concat(full_concept_emb).float()
full_concept_emb = full_concept_emb / full_concept_emb.norm(dim=-1, keepdim=True)   # Matrix T

In [21]:
def mahalanobis_distance(x, mu, sigma_inv):
    x = x - mu.unsqueeze(0)
    print('x', x.size())
    print('x @ sigma_inv @ x.T', x @ sigma_inv @ x.T)
    print('torch.diag(x @ sigma_inv @ x.T)', torch.diag(x @ sigma_inv @ x.T))
    return torch.diag(x @ sigma_inv @ x.T).mean()

In [6]:
mu = torch.mean(full_concept_emb, dim=0)
sigma_inv = torch.tensor(np.linalg.inv(torch.cov(full_concept_emb.T)))

# for embed in full_concept_emb:
#     print(mahalanobis_distance(embed, mu, sigma_inv))
np.mean([mahalanobis_distance(embed, mu, sigma_inv) for embed in full_concept_emb])

510.35193

In [23]:
mahalanobis_distance(full_concept_emb[0], mu, sigma_inv)

x torch.Size([1, 512])
x @ sigma_inv @ x.T tensor([[611.2968]])
torch.diag(x @ sigma_inv @ x.T) tensor([611.2968])


tensor(611.2968)

In [19]:
(full_concept_emb[0] - mu.unsqueeze(0)).size()

torch.Size([1, 512])

In [None]:
t_mu = torch.mean(T, dim=0)
t_sigma_inv = torch.linalg.inv(torch.cov(T.T))
mean_distance = torch.tensor([mahalanobis_distance(t_row, t_mu, t_sigma_inv) for t_row in T]).mean()
mean_distance

In [None]:
t_sigma_inv.size()

In [None]:
mahalanobis_loss = (mahalanobis_distance(E / torch.linalg.norm(E, dim=-1, keepdim=True), t_mu, t_sigma_inv) - mean_distance) / (mean_distance ** 3)
mahalanobis_loss

In [None]:
mahalanobis_distance(E / torch.linalg.norm(E, dim=-1, keepdim=True), t_mu, t_sigma_inv)

In [None]:
mean_distance

In [None]:
T[0].unsqueeze(0) @ t_sigma_inv @ T[0].T

In [None]:
torch.diag(T[0].unsqueeze(0) @ t_sigma_inv @ T[0].T).mean()