In [24]:
import os
import clip
import torch
import argparse
import numpy as np
import pandas as pd
from torch import nn
from torchvision import transforms as T

from models.lm4cv import mean_mahalanobis_distance

In [26]:
embedder, preprocess = clip.load('ViT-B/32', device='cpu')
raw_concepts = open("data/LM4CV/cub_attributes.txt", 'r').read().strip().split("\n")

full_concept_emb = []   # Matrix T
batch_size = 32

prompt_prefix = 'The bird has '
num_batches = len(raw_concepts) // batch_size + 1
for i in range(num_batches):
    batch_concepts = raw_concepts[i * batch_size: (i + 1) * batch_size]
    batch_concept_emb = clip.tokenize([prompt_prefix + attr for attr in batch_concepts])
    full_concept_emb.append(embedder.encode_text(batch_concept_emb).detach().cpu())

full_concept_emb = torch.concat(full_concept_emb).float()
full_concept_emb = full_concept_emb / full_concept_emb.norm(dim=-1, keepdim=True)   # Matrix T

In [42]:
D = 512
K = 100
N = 1000
T = torch.randn(N, D)
E = torch.randn(K, D)

In [45]:
# Implementation 1
def mahalanobis_distance(x, mu, sigma_inv):
    x = x - mu.unsqueeze(0)
    return torch.diag(x @ sigma_inv @ x.T).mean()

mu = torch.mean(full_concept_emb, dim=0)
sigma_inv = torch.tensor(np.linalg.inv(torch.cov(full_concept_emb.T)))
mean_distance = np.mean([mahalanobis_distance(embed, mu, sigma_inv) for embed in full_concept_emb])
mean_distance

510.35193

In [47]:
e_norm = torch.linalg.norm(E, dim=-1, keepdim=True)
mahalanobis_loss = (mahalanobis_distance(E / e_norm, mu, sigma_inv) - mean_distance) / (mean_distance ** 3)
mahalanobis_loss

tensor(-0.0358)

In [49]:
# Implementation 2
def mahalanobis_distance2(x, mu, sigma_inv):
    x = x - mu
    return torch.sqrt(x @ sigma_inv @ x.T)

torch.tensor([mahalanobis_distance2(e_row, mu, sigma_inv) for e_row in E]).mean()

tensor(nan)

In [86]:
def is_psd(mat):
    return bool((mat == mat.T).all() and (torch.linalg.eigvals(mat).real>=0).all())

In [90]:
is_psd(torch.cov(full_concept_emb.T))

False

In [93]:
def mahalanobis(u, v, cov):
    delta = u - v
    m = torch.dot(delta, torch.matmul(torch.inverse(cov), delta))
    return m
    # return torch.sqrt(m)

In [94]:
def mahalanobis_distance2(x, mu, sigma_inv):
    x = x - mu
    return x @ sigma_inv @ x.T

mu = torch.mean(full_concept_emb, dim=0)

sigma = torch.cov(full_concept_emb.T)
sigma_inv = torch.linalg.inv(torch.cov(full_concept_emb.T))

distances = [mahalanobis(e_row, mu, sigma) for e_row in E]
distances

[tensor(6.4656e+09),
 tensor(-4.8790e+09),
 tensor(8.4279e+09),
 tensor(-6.2947e+09),
 tensor(-2.2306e+10),
 tensor(2.2523e+10),
 tensor(2.6177e+09),
 tensor(3.9648e+10),
 tensor(-2.6618e+09),
 tensor(2.6030e+09),
 tensor(-7.3518e+09),
 tensor(3.6626e+08),
 tensor(5.9266e+08),
 tensor(2.4756e+10),
 tensor(6.8798e+09),
 tensor(5.3886e+09),
 tensor(6.8810e+08),
 tensor(1.7023e+10),
 tensor(1.0625e+10),
 tensor(-1.9214e+10),
 tensor(2.1618e+10),
 tensor(-2.0238e+09),
 tensor(1.2442e+10),
 tensor(8.0594e+08),
 tensor(-4.4186e+09),
 tensor(5.0296e+09),
 tensor(1.7522e+10),
 tensor(-1.6372e+09),
 tensor(2.4280e+09),
 tensor(9.8822e+09),
 tensor(-3.6955e+10),
 tensor(5.1362e+09),
 tensor(7.4042e+08),
 tensor(-3.5569e+09),
 tensor(-5.0281e+09),
 tensor(3.9498e+09),
 tensor(-8.1825e+09),
 tensor(-2.8514e+10),
 tensor(3.7873e+09),
 tensor(8.2648e+09),
 tensor(1.2701e+10),
 tensor(-8.7634e+09),
 tensor(-9.5503e+09),
 tensor(-5.9256e+09),
 tensor(7.4480e+09),
 tensor(1.6485e+10),
 tensor(5.1022e+1

In [64]:
torch.cov(full_concept_emb.T)

tensor([[ 1.3909e-04,  1.0566e-05, -4.4615e-06,  ..., -4.1785e-05,
          1.5213e-05, -1.6853e-05],
        [ 1.0566e-05,  2.7116e-04,  5.5303e-05,  ...,  5.3137e-05,
         -9.1934e-05, -2.8730e-05],
        [-4.4615e-06,  5.5303e-05,  1.4856e-04,  ...,  5.7498e-05,
         -2.5527e-05, -2.8638e-05],
        ...,
        [-4.1785e-05,  5.3137e-05,  5.7498e-05,  ...,  5.3831e-04,
         -1.3482e-04,  2.9747e-05],
        [ 1.5213e-05, -9.1934e-05, -2.5527e-05,  ..., -1.3482e-04,
          2.2928e-04, -3.7178e-06],
        [-1.6853e-05, -2.8730e-05, -2.8638e-05,  ...,  2.9747e-05,
         -3.7178e-06,  2.2496e-04]])

In [56]:
sigma_inv = torch.linalg.inv(torch.cov(full_concept_emb.T))
sigma_inv

tensor([[ 5248255.0000,  5302723.0000,  -696952.1875,  ...,
          6577713.5000,  4749001.0000, -2328161.2500],
        [ 5299397.0000,  6869843.0000,  -698490.3750,  ...,
         10132456.0000,  6002626.5000, -2560786.5000],
        [ -711532.6250,  -716567.8750,  1068546.5000,  ...,
          2252066.5000,  -300771.9688,  -500953.2812],
        ...,
        [ 6409478.5000,  9936476.0000,  2330944.0000,  ...,
         45804552.0000, 11349803.0000, -6523319.5000],
        [ 4721849.5000,  5973281.5000,  -276246.5000,  ...,
         11499574.0000,  6372226.5000, -3210215.5000],
        [-2303082.2500, -2531706.7500,  -514123.4688,  ...,
         -6549449.5000, -3192587.7500,  2575318.0000]])

In [59]:
torch.cov(full_concept_emb.T)

tensor([[ 1.3909e-04,  1.0566e-05, -4.4615e-06,  ..., -4.1785e-05,
          1.5213e-05, -1.6853e-05],
        [ 1.0566e-05,  2.7116e-04,  5.5303e-05,  ...,  5.3137e-05,
         -9.1934e-05, -2.8730e-05],
        [-4.4615e-06,  5.5303e-05,  1.4856e-04,  ...,  5.7498e-05,
         -2.5527e-05, -2.8638e-05],
        ...,
        [-4.1785e-05,  5.3137e-05,  5.7498e-05,  ...,  5.3831e-04,
         -1.3482e-04,  2.9747e-05],
        [ 1.5213e-05, -9.1934e-05, -2.5527e-05,  ..., -1.3482e-04,
          2.2928e-04, -3.7178e-06],
        [-1.6853e-05, -2.8730e-05, -2.8638e-05,  ...,  2.9747e-05,
         -3.7178e-06,  2.2496e-04]])

In [61]:
np.cov(np.array(full_concept_emb.T))

array([[ 1.39089103e-04,  1.05659480e-05, -4.46153526e-06, ...,
        -4.17851290e-05,  1.52127005e-05, -1.68525966e-05],
       [ 1.05659480e-05,  2.71160301e-04,  5.53028536e-05, ...,
         5.31374681e-05, -9.19336796e-05, -2.87299555e-05],
       [-4.46153526e-06,  5.53028536e-05,  1.48562504e-04, ...,
         5.74976233e-05, -2.55273265e-05, -2.86382370e-05],
       ...,
       [-4.17851290e-05,  5.31374681e-05,  5.74976233e-05, ...,
         5.38313309e-04, -1.34823570e-04,  2.97474284e-05],
       [ 1.52127005e-05, -9.19336796e-05, -2.55273265e-05, ...,
        -1.34823570e-04,  2.29284547e-04, -3.71777220e-06],
       [-1.68525966e-05, -2.87299555e-05, -2.86382370e-05, ...,
         2.97474284e-05, -3.71777220e-06,  2.24960100e-04]])