In [1]:
import os
os.environ['http_proxy'] = "http://192.41.170.23:3128" 
os.environ['https_proxy'] = "http://192.41.170.23:3128" 

In [2]:
# !pip install pycocotools
# !pip install faiss-cpu faiss-gpu
# !pip install nltk
# !pip install salesforce-lavis

In [3]:
import model.clip as clip
from model.model import ResidualAttentionBlock
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import torch
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, PILToTensor
import cv2
from captum.attr import visualization
from tqdm import tqdm
from transformers import ViTForImageClassification
import timm
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import json
from faiss import write_index, read_index
from torch.utils.data import DataLoader
import torch
from torchvision.datasets import CocoCaptions

from collections import OrderedDict
from datasets import load_dataset
import faiss
import gc
from typing import Any, Tuple, Callable, Optional, List
from sklearn.metrics import recall_score

# from lavis.models import model_zoo, load_model_and_preprocess

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

cuda_card = 1

In [4]:
model, preprocess = clip.load("ViT-L/14", get_all_token=False)
model.cuda(cuda_card).eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408


In [5]:
def normalize_vector(arr):
    norms = torch.linalg.norm(arr, axis=1, keepdims=True)
    return arr / norms

In [6]:
class CocoCustom(CocoCaptions): 
    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        from pycocotools.coco import COCO

        self.annotations = json.load(open(annFile))
        self.num_captions = len(self.annotations['annotations'])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, id, target

    def getAnnotationRange(self, index: int, count: int) -> List[Any]:
        return [self.annotations['annotations'][index]['caption'] for index in range(index, index+count)]

    def getImgIdFromAnnotationIndex(self, annotation_index: int) -> int:
        return self.annotations['annotations'][annotation_index]['image_id']
    
    def buildFaissIndex(self, text_encoder, tokenize, batch_size, nlist) :
        tokenized = tokenize(self.getAnnotationRange(0, batch_size)).cuda(cuda_card)
        encoded_captions = normalize_vector(text_encoder(tokenized, get_all_token=False).detach().cpu().numpy().astype('float32'))
        vector_dimension = encoded_captions.shape[1]
        
        quantizer = faiss.IndexFlatIP(vector_dimension)
        index = faiss.IndexIVFFlat(quantizer, vector_dimension, nlist)
        index.train(encoded_captions)
        index.add(encoded_captions)
        
        for i in tqdm(range(batch_size, self.num_captions - batch_size, batch_size)):
            tokenized = clip.tokenize(self.getAnnotationRange(i, batch_size)).cuda(cuda_card)
            encoded_captions = normalize_vector(model.encode_text(tokenized, get_all_token=False).detach().cpu().numpy().astype('float32'))
            index.add(encoded_captions)

        return index

    def __len__(self) -> int:
        return len(self.ids)


In [7]:
path = '../../Dataset/CV/mscoco/2017'
cocoCaptions = CocoCustom(root = path + '/val2017',
                        annFile = path + '/annotations/captions_val2017.json',
                        transform=preprocess)

print('Number of samples: ', len(cocoCaptions))
img, img_id, target = cocoCaptions[3]

print("Image Size:", img.size())
print("Captions:", target)
print("Image Id:", img_id)

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Number of samples:  5000
Image Size: torch.Size([3, 224, 224])
Captions: ["A stop sign is mounted upside-down on it's post. ", 'A stop sign that is hanging upside down.', 'An upside down stop sign by the road.', 'a stop sign put upside down on a metal pole ', 'A stop sign installed upside down on a street corner']
Image Id: 724


In [8]:
caption_map = {}
captions = []
caption_idx = 0
images_id_list = []
for i in tqdm(range(len(cocoCaptions))):
    img, img_id, target = cocoCaptions[i]
    images_id_list.append(img_id)
    for idx in range(len(target)):
        capt = target[idx]
        caption_map[caption_idx] = img_id
        caption_idx += 1
        captions.append(capt)

100%|██████████| 5000/5000 [06:07<00:00, 13.61it/s]


In [9]:
device = torch.device('cuda:' + str(cuda_card))
batch_size = 32
encoded_captions = []

for i in tqdm(range(0, len(captions), batch_size)):
    text = captions[i: min(len(captions), i+batch_size)]
    tokenized = clip.tokenize(text).to(device)
    encoded = normalize_vector(model.encode_text(tokenized, get_all_token=False).detach())
    encoded_captions.append(encoded)
encoded_captions = torch.concat(encoded_captions)

100%|██████████| 782/782 [00:49<00:00, 15.73it/s]


In [10]:
torch.cuda.empty_cache()

In [11]:
def collate_coco(data):
    images = torch.stack([image for image, image_id, text in data])
    images_id = torch.tensor([image_id for image, image_id, text in data])
    return images, images_id

batch_size = 50
data_loader = DataLoader(cocoCaptions, collate_fn=collate_coco, batch_size=batch_size)

In [15]:
model.eval()

predicted = []
score_matrix_i2t = []


with torch.no_grad():
    for images, images_id in tqdm(data_loader):
        images = images.cuda(cuda_card)
        image_encodes = normalize_vector(model.encode_image(images).detach())
        similarity_score = image_encodes.matmul(encoded_captions.T)
        score_matrix_i2t.append(similarity_score)
        # _, indexes = faissIndex.search(normalize_vector(image_encodes.cpu().numpy().astype(np.float32)), 1)
        # predicted_image_id = [caption_map[int(predicted.cpu())] for predicted in torch.argmax(similarity_score, dim=1)]
        # predicted.extend(predicted_image_id)
        # ground_truth.extend(images_id.tolist())
score_matrix_i2t = torch.concat(score_matrix_i2t).cpu()

100%|██████████| 100/100 [05:45<00:00,  3.45s/it]


In [16]:
img2txt = dict([(value, []) for value in images_id_list])
for key, value in caption_map.items():
    img2txt[value].append(key)

In [19]:
ranks = np.zeros(score_matrix_i2t.shape[0])
for index, score in enumerate(score_matrix_i2t):
    score = np.array(score)
    inds = np.flip(np.argsort(score))
    # Score
    rank = 1e20
    for i in img2txt[images_id_list[index]]:
        tmp = np.where(inds == i)[0] # search for text i location in score array tmp is single index number
        if len(tmp) != 0: 
            tmp = tmp[0]
        else:
            tmp = 1e20
        if tmp < rank:
            rank = tmp
    ranks[index] = rank

# Compute metrics
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)

In [20]:
tr1

51.82

In [21]:
tr5

76.68

# Cifar10

In [None]:
ds = load_dataset('cifar10')
train_ds = ds['train']
test_ds = ds['test']

In [None]:
print('Train dataset: ', train_ds)
print('Test dataset: ', test_ds)

In [None]:
label_dict_key = 'label'

In [None]:
id2label = {id:label.replace('_', ' ') for id, label in enumerate(train_ds.features[label_dict_key].names)}
label2id = {label:id for id,label in id2label.items()}
id2label

In [None]:
def train_transforms(examples):
    examples['pixel_values'] = [preprocess(image.convert("RGB")) for image in examples['img']]
    return examples

In [None]:
train_ds.set_transform(train_transforms)
test_ds.set_transform(train_transforms)

In [None]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example[label_dict_key] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

batch_size = 60
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=batch_size)
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=batch_size)

In [None]:
label_prompted = []
for label_id, label_name in id2label.items():
    label_prompted.append('This is an image of ' + label_name)

tokenized = clip.tokenize(label_prompted).cuda(cuda_card)
encoded_label = model.encode_text(tokenized, get_all_token=False).detach()
encoded_label.shape

# Training

In [None]:
class CrossModalClassifier(nn.Module):
    """Linear layer to train on top of frozen features"""
    def __init__(self, embed_dim, n_token, num_class, num_heads=1):
        super(CrossModalClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_class = num_class
        self.n_token = n_token
        self.cross_att = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.linear = nn.Linear(embed_dim * n_token, num_class)

    def forward(self, image_emb, text_emb):
        x, _ = self.cross_att(text_emb, image_emb, image_emb)
        x = x.reshape((x.shape[0], -1))
        x = self.linear(x)
        return x

## CLIP Similarity matching

In [None]:
sample_count = 0
correct_count = 0
sum_score = torch.zeros(encoded_label.shape[0]).cuda(cuda_card)
count_class = torch.zeros(encoded_label.shape[0]).cuda(cuda_card)
score_per_image = [[]for i in range(encoded_label.shape[0])]

# with torch.backends.cuda.sdp_kernel(enable_flash=True) as disable:
for mini_batch in tqdm(test_dataloader):
    images = mini_batch['pixel_values'].cuda(cuda_card)
    labels = mini_batch['labels']
    images_encode = model.encode_image(images)
    norm = torch.norm(images_encode, dim=1, keepdim=True) @ torch.norm(encoded_label, dim=1, keepdim=True).T
    images_encode = images_encode.detach()
    
    max_similarity = torch.max((images_encode @ encoded_label.T) / norm, dim=1)
    predicted = max_similarity.indices
    max_score = max_similarity.values.detach()
    
    correct_count += (predicted.cpu() == labels).sum()
    sample_count += len(labels)
    
    for index, predict_class in enumerate(predicted):
        score_per_image[predict_class].append(max_score[index])
        sum_score[predict_class] += max_score[index]
        count_class[predict_class] += 1

In [None]:
avg_per_class = sum_score/count_class
print("Average Score: ", torch.mean(avg_per_class))
print(avg_per_class)

In [None]:
accuracy = correct_count / sample_count * 100
print('Accuracy score: ', accuracy)

## Train CrossModal Classifier layer

In [None]:
# crossAttLinearClassifier = CrossModalClassifier(encoded_label.shape[2], encoded_label.shape[1], encoded_label.shape[0]).cuda(cuda_card)
# optimizer = torch.optim.SGD(
#     crossAttLinearClassifier.parameters(),
#     lr=0.05,
#     momentum=0.9,
#     weight_decay=0, # we do not apply weight decay
# )
# ce_loss = nn.CrossEntropyLoss()

In [None]:
# batches_per_epoch = len(train_dataloader)
# n_epochs = 5
# # with torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:
# for epoch in range(n_epochs):
#     with tqdm(train_dataloader, unit="batch") as tepoch:
#         tepoch.set_description(f"Epoch {epoch}/{n_epochs}")

#         crossAttLinearClassifier.train()
#         for mini_batch in tepoch:
#             images = mini_batch['pixel_values'].cuda(cuda_card)
#             labels = mini_batch['labels'].cuda(cuda_card)
            
#             text_encode = []
#             for i in labels:
#                 text_encode.append(encoded_label[i])
#             text_encode = torch.stack(text_encode).cuda(cuda_card).float()
#             with torch.no_grad():
#                 images_encode = model.encode_image(images)
#                 images_encode = images_encode.float()
#             output = crossAttLinearClassifier(images_encode, text_encode)
#             predictions = output.argmax(dim=1, keepdim=True).squeeze()
#             loss = ce_loss(output, labels)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             correct = (predictions == labels).sum().item()
#             accuracy = correct / batch_size * 100
#             tepoch.set_postfix(loss=loss.item(), accuracy=accuracy)

# torch.save(crossAttLinearClassifier.state_dict(), './weight/crossAttention')

## Fine tune with CrossModal Output

In [None]:
# crossAttLinearClassifier = CrossModalClassifier(encoded_label.shape[2], encoded_label.shape[1], encoded_label.shape[0]).cuda(cuda_card)
# crossAttLinearClassifier.load_state_dict(torch.load('./weight/crossAttention'))

# # vitClassifier = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# # vitClassifier.classifier = nn.Linear(768, 10)
# # vitClassifier.cuda(cuda_card)

# vit = timm.create_model('vit_small_patch16_224.dino', pretrained=True)
# vit.cuda(cuda_card)
# vitClassifier = nn.Linear(384, 10).cuda(cuda_card)
# # vitClassifier = nn.Sequential(vit, vitClassifier)

# print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in vit.parameters()]):,}")

# optimizer = torch.optim.SGD(
#     vitClassifier.parameters(),
#     lr=0.05,
#     momentum=0.9,
#     weight_decay=0, # we do not apply weight decay
# )
# ce_loss = nn.CrossEntropyLoss()
# softmax = nn.Softmax(dim=1)

In [None]:
# batches_per_epoch = len(train_dataloader)
# n_epochs = 5
# # with torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:
# crossAttLinearClassifier.eval()
# model.eval()
# for epoch in range(n_epochs):
#     with tqdm(train_dataloader, unit="batch") as tepoch:
#         tepoch.set_description(f"Epoch {epoch}/{n_epochs}")
#         vit.eval()
#         vitClassifier.train()
#         for mini_batch in tepoch:
#             images = mini_batch['pixel_values'].cuda(cuda_card)
#             labels = mini_batch['labels'].cuda(cuda_card)
#             text_encode = []
#             for i in labels:
#                 text_encode.append(encoded_label[i])
#             text_encode = torch.stack(text_encode).cuda(cuda_card).float().detach()
            
#             with torch.no_grad():
#                 images_encode = model.encode_image(images)
#                 images_encode = images_encode.float().detach()
                
#             target = softmax(crossAttLinearClassifier(images_encode, text_encode).detach())
#             # output = softmax(vitClassifier(images))
            
#             vit_output = vit(images).detach()
#             output = softmax(vitClassifier(vit_output))
            
#             predictions = output.argmax(dim=1, keepdim=True).squeeze()
            
#             loss = ce_loss(output, target)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             correct = (predictions == labels).sum().item()
#             accuracy = correct / batch_size * 100
#             tepoch.set_postfix(loss=loss.item(), accuracy=accuracy)
#         torch.save(vitClassifier.state_dict(), './weight/vitClassifier{}'.format(epoch))
            


In [None]:
# sample_count = 0
# correct_count = 0

# vitClassifier.eval()
# vit.eval()
# for mini_batch in tqdm(test_dataloader):
#     images = mini_batch['pixel_values'].cuda(cuda_card)
#     labels = mini_batch['labels'].cuda(cuda_card)

#     vit_output = vit(images).detach()
#     output = softmax(vitClassifier(vit_output).detach())
#     predictions = output.argmax(dim=1, keepdim=True).squeeze()
    
#     correct_count += (predictions == labels).sum()
#     sample_count += len(labels)

# accuracy = correct_count / sample_count * 100
# print('Accuracy score: ', accuracy)

In [None]:
vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# for param in vit.parameters():
#     param.requires_grad = False
vit.classifier = nn.Linear(768, 10)
vit.classifier.weight.requires_grad = True
vit.classifier.bias.requires_grad = True
vit.cuda(cuda_card)

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in vit.parameters()]):,}")

optimizer = torch.optim.Adam(
    vit.parameters(),
    lr=0.05,
    # momentum=0.9,
    weight_decay=0, # we do not apply weight decay
)
ce_loss = nn.CrossEntropyLoss()
softmax = nn.Softmax(dim=1)

In [None]:
batches_per_epoch = len(train_dataloader)
n_epochs = 10
# with torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:
vit.train()
for epoch in range(n_epochs):
    with tqdm(train_dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}/{n_epochs}")
        for mini_batch in tepoch:
            images = mini_batch['pixel_values'].cuda(cuda_card)
            labels = mini_batch['labels'].cuda(cuda_card)
            
            output = softmax(vit(images).logits)
            
            predictions = output.argmax(dim=1, keepdim=True).squeeze()
            
            loss = ce_loss(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            correct = (predictions == labels).sum().item()
            accuracy = correct / batch_size * 100
            tepoch.set_postfix(loss=loss.item(), accuracy=accuracy)
        torch.save(vit.state_dict(), './weight_supervised/vitBaseClassifierNoFreeze{}'.format(epoch))

In [None]:
sample_count = 0
correct_count = 0

vit.eval()
for mini_batch in tqdm(test_dataloader):
    images = mini_batch['pixel_values'].cuda(cuda_card)
    labels = mini_batch['labels'].cuda(cuda_card)

    output = softmax(vit(images).logits.detach())
    predictions = output.argmax(dim=1, keepdim=True).squeeze()
    
    correct_count += (predictions == labels).sum()
    sample_count += len(labels)

accuracy = correct_count / sample_count * 100
print('Accuracy score: ', accuracy)


# Pototypical Test

In [None]:
# prototypes = torch.ones(encoded_label.shape, dtype=torch.float16).cuda(cuda_card)
# check_class = torch.zeros(encoded_label.shape[0])

In [None]:
# prototype_weight = 0.85
# for mini_batch in tqdm(train_dataloader):
#     images = mini_batch['pixel_values'].cuda(cuda_card)
#     labels = mini_batch['labels']
#     images_encode = model.encode_image(images)

#     for index, image_encode in enumerate(images_encode.detach()):
#         class_no = labels[index]
#         if check_class[class_no] == 0:
#             prototypes[class_no] = image_encode
#             check_class[class_no] = 1
#         else:
#             prototypes[class_no] = (prototypes[class_no] * prototype_weight) + (image_encode * (1 - prototype_weight))

In [None]:
# prototypes

In [None]:
# sample_count = 0
# correct_count = 0
# for mini_batch in tqdm(test_dataloader):
#     images = mini_batch['pixel_values'].cuda(cuda_card)
#     labels = mini_batch['labels'].cuda(cuda_card)
#     images_encode = model.encode_image(images)
    
#     predicted = torch.argmax(images_encode @ prototypes.T, dim=1)
#     correct_count += (predicted == labels).sum()
#     sample_count += len(labels)

In [None]:
# accuracy = correct_count / sample_count * 100
# print('Accuracy score: ', accuracy)