In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [1]:
import torch
import clip
import os
import clip
import torch
import torch.nn as nn
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
from PIL import Image
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.transforms.functional import InterpolationMode
import numpy as np
%matplotlib inline


In [2]:
image_size = 384

preprocess = transforms.Compose([
    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ]) 

train_ds = datasets.CIFAR10('/home/zhilif/incremental_learn/FractalDB-Pretrained-ResNet-PyTorch/data', train=True, transform=preprocess, download=False)
test_ds = datasets.CIFAR10('/home/zhilif/incremental_learn/FractalDB-Pretrained-ResNet-PyTorch/data', train=False, transform=preprocess, download=False)

trainLoader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=100,
    shuffle=True, pin_memory=False)

testLoader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=100,
    shuffle=False, pin_memory=True)

In [3]:
from models.blip import blip_decoder


model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
device = "cuda" if torch.cuda.is_available() else "cpu"

model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

reshape position embedding from 196 to 576
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth


In [5]:
trainLoader.dataset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [6]:
x, y = next(iter(trainLoader))
# x = x.to(device)
# y = y.to(device)
# for i in range(20):
#     with torch.no_grad():
#         # beam search
#         caption = model.generate(x[i:i+1], sample=False, num_beams=3, max_length=20, min_length=5) 
#         # nucleus sampling
#         # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) 
#         print('caption: '+ caption[0], trainLoader.dataset.classes[y[i].item()])
#         plt.imshow(x[i].permute(1,2,0).cpu().numpy())
#         plt.show()

In [7]:
# mean_dict = dict()
# count_dict = dict()
# for i, (X, Y) in tqdm(enumerate(trainLoader)):
#     for j in range(len(X)):
#         if Y[j].item() not in mean_dict:
#             mean_dict[Y[j].item()] = X[j]
#             count_dict[Y[j].item()] = 1
#         else:
#             mean_dict[Y[j].item()] = mean_dict[Y[j].item()] * (count_dict[Y[j].item()] / (count_dict[Y[j].item()]+1)) + X[j] / (count_dict[Y[j].item()]+1)
#             count_dict[Y[j].item()] += 1
# #             print(Y[j], mean_dict[Y[j].item()].shape)

mean_dict = dict()
count_dict = dict()
for i, (X, Y) in tqdm(enumerate(trainLoader)):
    with torch.no_grad():
        X = model.visual_encoder(X.to(device))
        for j in range(len(X)):
            if Y[j].item() not in mean_dict:
                mean_dict[Y[j].item()] = X[j]
                count_dict[Y[j].item()] = 1
            else:
                mean_dict[Y[j].item()] = mean_dict[Y[j].item()] * (count_dict[Y[j].item()] / (count_dict[Y[j].item()]+1)) + X[j] / (count_dict[Y[j].item()]+1)
                count_dict[Y[j].item()] += 1
    #             print(Y[j], mean_dict[Y[j].item()].shape)
    

500it [10:52,  1.31s/it]


In [4]:
def generate_with_embedding(model, image_embeds, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
    batch_size = image_embeds.shape[0]
    if not sample:
        image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
    image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image_embeds.device)
    model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

    prompt = [model.prompt] * batch_size
    input_ids = model.tokenizer(prompt, return_tensors="pt").input_ids.to(image_embeds.device) 
    input_ids[:,0] = model.tokenizer.bos_token_id
    input_ids = input_ids[:, :-1] 

    if sample:
        #nucleus sampling
        outputs = model.text_decoder.generate(input_ids=input_ids,
                                              max_length=max_length,
                                              min_length=min_length,
                                              do_sample=True,
                                              top_p=top_p,
                                              num_return_sequences=1,
                                              eos_token_id=model.tokenizer.sep_token_id,
                                              pad_token_id=model.tokenizer.pad_token_id, 
                                              repetition_penalty=1.1,                                            
                                              **model_kwargs)
    else:
        #beam search
        outputs = model.text_decoder.generate(input_ids=input_ids,
                                              max_length=max_length,
                                              min_length=min_length,
                                              num_beams=num_beams,
                                              eos_token_id=model.tokenizer.sep_token_id,
                                              pad_token_id=model.tokenizer.pad_token_id,     
                                              repetition_penalty=repetition_penalty,
                                              **model_kwargs)            

    captions = []    
    for output in outputs:
        caption = model.tokenizer.decode(output, skip_special_tokens=True)    
        captions.append(caption[len(model.prompt):])
    return captions


# for k, v in mean_dict.items():
#     with torch.no_grad():
#         # beam search
#         caption = generate_with_embedding(model, (v[None, ...]*np.random.rand()).to(device), sample=False, num_beams=3, max_length=20, min_length=5) 
#         # nucleus sampling
#         # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) 
#         print('caption: '+ caption[0], trainLoader.dataset.classes[k])
# #         plt.imshow(v.permute(1,2,0).cpu().numpy())
# #         plt.show()

In [10]:
for k1, v1 in mean_dict.items():
    for k2, v2 in mean_dict.items():
        if (k1==k2):
            continue
        with torch.no_grad():
            # beam search
            caption = generate_with_embedding(model, (v1-v2)[None, ...].to(device), sample=False, num_beams=3, max_length=20, min_length=5) 
            # nucleus sampling
            # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) 
            print(f'{trainLoader.dataset.classes[k1]} - {trainLoader.dataset.classes[k2]}')
            print('caption: '+ caption[0])

bird - frog
caption: a bird standing on a pole
bird - deer
caption: a woman in a blue dress
bird - truck
caption: a bird bird - bird - bird - bird - bird - bird - bird -
bird - automobile
caption: a bird bird bird - nature nature bird bird bird bird bird bird bird bird bird
bird - cat
caption: a bird in the sky
bird - airplane
caption: a black - crested crested crested crested crested crested crested crested crested crested crested crested crested
bird - ship
caption: a young - brown - crested - crested - crested - crested - crested - crested
bird - horse
caption: a bird's face
bird - dog
caption: a bird in the sky
frog - bird
caption: a man's face, with the skin and eyes
frog - deer
caption: a woman's face with a black and white background
frog - truck
caption: a close a a a a a a a a a a a a a a
frog - automobile
caption: a close a a a a a a a a a a a a a a
frog - cat
caption: a plant with a green background
frog - airplane
caption: a green leopard leopard leopard leopard leopard leo

In [19]:
trainLoader.dataset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [5]:



# nc is the number of samples that have been seen
# mu_x is the mean of the seen samples
# new_data should be of shape (bsz, num_patches, embed_dim)
# all data should belong to the same class
def updateStat(nc, mu_x, cov_x, new_data):
    lc = new_data.shape[0]
    mu_y = new_data.mean(0)
    
    cov_y = torch.einsum('ijk, ijh->jkh', new_data-mu_y, new_data-mu_y)
    cov_update = cov_x + cov_y + (nc**2*lc+lc**2*nc)/((nc+lc)**2)*(torch.einsum('ij,ik->ijk', mu_y-mu_x,mu_y-mu_x))
    
    return (mu_y*lc + mu_x*nc)/(nc+lc), cov_update, nc+lc


# (num_patch, embed_dim) is the shape of the vision transformer output. Notice that 
# ViT chops images into patches, and applies transformation to each patch. 
# Usually num_patch = 1 + number of patches, and the first dimension is extracted
# for classification task (such as CLIP does).
# For base ViT, num_patch = 577, embed_dim=768
# trainLoader contain sample from the same class
def getStat(trainLoader, model, num_patch=577, embed_dim=768):
    nc = 0
    mu_x = torch.zeros(num_patch, embed_dim).to(device)
    cov_x = torch.zeros(num_patch, embed_dim, embed_dim).to(device)
    for i, (X, Y) in tqdm(enumerate(trainLoader)):
        with torch.no_grad():
            feature = model.visual_encoder(X.to(device))
            mu_x, cov_x, nc = updateStat(nc, mu_x, cov_x, feature)
    return mu_x, cov_x, nc

In [6]:
from torch.utils.data import Subset
num_classes = 10
trainLoaders = [[] for i in range(num_classes)]

for k in range(num_classes):
    trainSubset = Subset(train_ds, (torch.tensor(train_ds.targets)==k).nonzero().squeeze()[0:2000])
    trainLoaders[k] =torch.utils.data.DataLoader(
        trainSubset,
        batch_size=20,
        shuffle=False, pin_memory=False)

In [7]:
mu_x0, cov_x0, nc0 = getStat(trainLoaders[4], model, num_patch=577, embed_dim=768)

100it [00:29,  3.45it/s]


In [8]:
mu_x1, cov_x1, nc1 = getStat(trainLoaders[5], model, num_patch=577, embed_dim=768)

100it [00:28,  3.49it/s]


In [7]:
for k in range(num_classes):
    trainLoader = trainLoaders[k]
    for i, (x, y) in enumerate(trainLoader):
        with torch.no_grad():
            v = model.visual_encoder(x.to(device)).mean(0)
            caption = generate_with_embedding(model, v[None, ...].to(device), sample=False, num_beams=3, max_length=20, min_length=5) 
            print('caption: '+ caption[0])

caption: flight flight flight flight flight flight flight flight flight flight flight flight flight flight flight flight
caption: car car car car car car car car car car car car car car car car
caption: bird bird bird bird bird bird bird bird bird bird bird bird bird bird bird bird


KeyboardInterrupt: 

In [6]:
# visual_embeds = []
# for k in range(num_classes):
#     for i, (x, y) in enumerate(trainLoaders[k]):
#         with torch.no_grad():
#             if (len(visual_embeds)<=k):
#                 visual_embeds.append([model.visual_encoder(x.to(device))])
#             else:
#                 visual_embeds[k].append(model.visual_encoder(x.to(device)))

visual_embeds = []
for k in range(num_classes):
    x, y = next(iter(trainLoaders[k]))
    with torch.no_grad():
        visual_embeds.append(model.visual_encoder(x.to(device)))


In [10]:
for k in range(num_classes):
    for j in range(num_classes):
        if k==j:
            continue
        v1 = visual_embeds[k].mean(0)
        v2 = visual_embeds[j].mean(0)
        caption = generate_with_embedding(model, (v1-v2)[None, ...].to(device), sample=False, num_beams=3, max_length=20, min_length=5) 
        print('caption: '+ caption[0], (k, j))

caption: a plane in the sky (0, 1)
caption: an airplane in the sky (0, 2)
caption: the sky aero aero aero aero aero aero aero aero aero aero aero aero aero aero (0, 3)
caption: an airplane in the sky (0, 4)
caption: the blue sky sky sky sky sky an an an an an an an an an (0, 5)
caption: an airplane in the sky sky sky sky sky sky sky sky sky sky sky sky (0, 6)
caption: a plane in the sky (0, 7)
caption: a man in a black and white photo of a man in a black and white (0, 8)
caption: a plane flying in the sky (0, 9)
caption: a car'car'''''''''''' (1, 0)
caption: the car stock car stock car stock cars cars cars cars cars cars cars cars cars (1, 2)
caption: the $ $ $ $ $ $ $ $ $ $ $ $ $ $ $ (1, 3)
caption: the car nissan nissan nissan nissan nissan nissan nissan nissan nissan nissan nissan nissan nissan nissan (1, 4)
caption: the car'car'car'car'car'car'car'car (1, 5)
caption: the car in the car in the car in the car in the car in the (1, 6)
caption: the mazda mazda mazda mazda mazda mazda m

In [7]:
v0 = visual_embeds[2]
v1 = visual_embeds[3]

mu0 = v0.mean(0)
mu1 = v1.mean(0)

print(mu0.shape, v0.shape, v1.shape, mu1.shape)

v0_ = v0 - mu0
v1_ = v1 - mu1

Sigma0 = torch.einsum('ijk,ijh->jkh', v0_, v0_) / (v0_.shape[0]-1)
Sigma1 = torch.einsum('ijk,ijh->jkh', v1_, v1_) / (v1_.shape[0]-1)

torch.Size([577, 768]) torch.Size([100, 577, 768]) torch.Size([100, 577, 768]) torch.Size([577, 768])


In [1]:
w = []
lam = 0
for i in tqdm(range(577)):
#     w.append(torch.linalg.inv((1-lam)*(cov_x0[i]+cov_x1[i]) + lam * torch.eye(cov_x0[i].shape[1]).to(device))@(mu_x1[i]-mu_x0[i]))
    w.append(torch.linalg.solve((1-lam)*(cov_x0[i]+cov_x1[i]) + lam * torch.eye(cov_x0[i].shape[1]).to(device), (mu_x1[i]-mu_x0[i])))
h = torch.stack(w)

NameError: name 'tqdm' is not defined

In [9]:
a = cov_x0[3]+cov_x1[3]
b = torch.linalg.inv(a)
torch.linalg.matrix_rank(a)

tensor(764, device='cuda:0')

In [11]:
caption = generate_with_embedding(model, (h[None, ...]).to(device), sample=False, num_beams=10, max_length=20, min_length=5) 
print('caption: '+ caption[0])

caption: a person standing in front of a building


In [7]:
a=torch.rand(2,5)
b = a.T@a
b.shape

torch.Size([5, 5])