In [1]:
# Removed `!kill -9 <pid>`: the PID was hard-coded from an earlier session, so on
# any later run it may belong to an unrelated process. To free GPU memory, restart
# the kernel, or look up the stale PID with `!nvidia-smi` before killing anything.

In [None]:
# Inspect GPU state (e.g. free memory, running processes) before loading the model.
!nvidia-smi

In [1]:
import numpy as np
import argparse
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10

In [2]:
import clip

In [3]:
def zeroshot_classifier(classnames, templates, clip_model=None):
    """Build CLIP zero-shot classifier weights from class names and prompt templates.

    For every class name, each template is filled in with ``str.format``,
    tokenized with ``clip.tokenize`` and encoded with ``encode_text``. The
    per-template embeddings are L2-normalised and averaged, and the mean is
    L2-normalised again, giving one unit-length text embedding per class.

    Args:
        classnames: Iterable of class-name strings (e.g. ``dataset.classes``).
        templates: Sequence of format strings with one ``{}`` placeholder.
        clip_model: CLIP model used to encode the text. Defaults to the
            notebook-global ``model``, so existing two-argument calls keep
            their behaviour.

    Returns:
        Tensor built by ``torch.stack(..., dim=1)``, one column per class, on
        the same device as the model. Presumably (embed_dim, num_classes);
        confirm against ``encode_text``'s output shape.
    """
    if clip_model is None:
        # Fall back to the notebook-level model for existing callers.
        clip_model = model
    # Follow the model's device instead of hard-coding .cuda().
    device = next(clip_model.parameters()).device

    class_columns = []
    with torch.no_grad():
        for classname in tqdm(classnames):
            tokens = clip.tokenize(
                [template.format(classname) for template in templates]
            ).to(device)

            class_embeddings = clip_model.encode_text(tokens)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # Re-normalise: the mean of unit vectors is generally not unit length.
            class_embedding /= class_embedding.norm()

            class_columns.append(class_embedding)

        # Embeddings already live on `device`, so no extra transfer is needed.
        return torch.stack(class_columns, dim=1)

In [4]:
def accuracy(output, target, topk=(1,), ensemble=False):
    """Count correct top-k predictions in one batch.

    Args:
        output: Logits of shape (batch, num_classes). When ``ensemble`` is
            True, logits of shape (batch, num_prompts, num_classes), which are
            reduced by taking each class's maximum logit across prompts.
        target: Integer class labels, one per sample in the batch.
        topk: Tuple of k values to evaluate. A k larger than the number of
            classes counts every sample as correct.
        ensemble: Whether ``output`` carries the extra prompt axis (dim 1).

    Returns:
        List of floats, one per k in ``topk``: the NUMBER of samples whose
        target is among the top-k predictions (a count, not a percentage).
    """
    if ensemble:
        # Reduce the prompt axis directly; this works for any batch size and
        # class count (a fixed reshape would only fit one specific shape).
        output = output.max(dim=1).values

    # torch.topk requires k <= num_classes; larger k values still slice fine below.
    maxk = min(max(topk), output.size(1))
    pred = output.topk(maxk, dim=1, largest=True, sorted=True).indices.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # .item() avoids NumPy's deprecated ndim>0 array-to-scalar conversion.
    return [float(correct[:k].reshape(-1).float().sum().item()) for k in topk]

In [5]:
def _append_results(results_path, dataset_name, prompt, ensemble, top1, top5):
    """Append one evaluation entry to ``results_path``, creating the file with a title if needed."""
    is_new_file = not os.path.exists(results_path)
    with open(results_path, 'a') as f:
        # A new file starts with the title; later entries are separated by a blank line.
        f.write('zeroshot learning results\n\n' if is_new_file else '\n')
        f.write(f'---------- {dataset_name} ----------\n')
        for p in prompt:
            f.write(f'{p.format("[CLASS]")} \n')
        f.write(f'ensemble: {ensemble}   top1: {top1:.2f}    top5: {top5:.2f} \n')


def zeroshot_learning(model, dataset, dataloader, prompt, ensemble=False,
                      results_path='./results.txt'):
    """Evaluate CLIP zero-shot classification on ``dataset`` and log the scores.

    Args:
        model: CLIP model providing ``encode_image``. Batches are moved to the
            device of its parameters.
        dataset: Dataset exposing ``classes``. Its lower-cased type name
            (e.g. ``cifar10``, ``cifar100``) labels the entry in the results file.
        dataloader: Iterable of ``(images, targets)`` batches drawn from ``dataset``.
        prompt: List of format-string templates with one ``{}`` placeholder.
        ensemble: If False, one classifier whose class embeddings average all
            templates. If True, one classifier per template, and each class
            keeps its maximum logit across templates.
        results_path: Text file the prompts and scores are appended to.

    Returns:
        ``(top1, top5)`` accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    device = next(model.parameters()).device

    with torch.no_grad():
        # zeroshot_classifier is called without a model, so the text side is
        # encoded with the notebook-global `model`.
        # NOTE(review): keep that global and this `model` argument the same object.
        if ensemble:
            zeroshot_weights = [
                zeroshot_classifier(dataset.classes, [p]) for p in prompt
            ]
        else:
            zeroshot_weights = zeroshot_classifier(dataset.classes, prompt)

        top1, top5, n = 0., 0., 0.
        for images, targets in tqdm(dataloader):
            images, targets = images.to(device), targets.to(device)

            images_features = model.encode_image(images)
            images_features /= images_features.norm(dim=-1, keepdim=True)

            if ensemble:
                # Stack to (batch, num_prompts, num_classes), then keep each
                # class's best logit across prompts -> (batch, num_classes).
                per_prompt = [100. * images_features @ w for w in zeroshot_weights]
                logits = torch.stack(per_prompt, dim=1).max(dim=1).values
            else:
                logits = 100. * images_features @ zeroshot_weights

            # accuracy() returns correct-sample counts, normalised below.
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    if n == 0:
        raise ValueError('dataloader yielded no samples')
    top1 = (top1 / n) * 100
    top5 = (top5 / n) * 100

    _append_results(results_path, type(dataset).__name__.lower(),
                    prompt, ensemble, top1, top5)
    return top1, top5

In [6]:
# Which CLIP checkpoint to evaluate, and the evaluation batch size.
model_name = 'ViT-L/14@336px'
batch_size = 100

available_model = clip.available_models()

print(available_model)

# Fail early with an actionable message (a bare assert gives no hint and is
# stripped under `python -O`).
if model_name not in available_model:
    raise ValueError(
        f'Unknown CLIP model {model_name!r}; choose one of {available_model}'
    )

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [7]:
# Load the CLIP checkpoint plus its image preprocessing pipeline, and
# report the model's basic configuration.
model, preprocess = clip.load(model_name)
model.cuda().eval()

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print(preprocess)

Model parameters: 427,944,193
Input resolution: 336
Context length: 77
Vocab size: 49408
Compose(
    Resize(size=336, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(336, 336))
    <function _convert_image_to_rgb at 0x7f083b127ee0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)


In [8]:
# 80 prompt templates, one `{}` placeholder each for the class name.
# NOTE(review): this looks like the ImageNet template ensemble from the CLIP
# repository's prompt-engineering notebook; confirm provenance.
# NOTE(review): the next cell rebinds `prompt` to a 2-template list, so on
# Restart & Run All this list is never used. Give it a distinct name (or drop
# one of the two cells) to make the chosen prompt set explicit.
prompt = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

In [9]:
# Prompt templates used by the evaluation cells below. With ensemble=False,
# their text embeddings are averaged into one classifier per class.
# This rebinds `prompt`, replacing the 80-template list defined above.
# NOTE(review): unlike that list, these templates have no trailing period.
prompt = [
    'This is a photo of a {}',
    'This is a photo of {}'
]

In [9]:
# Zero-shot evaluation on the CIFAR-10 test split (expects the data to be
# present locally already, since download is disabled).
cifar10 = CIFAR10(
    root='./cifar10_data',
    train=False,
    transform=preprocess,
    download=False,
)
cifar10_dataloader = DataLoader(cifar10, batch_size=batch_size, shuffle=False)

top1, top5 = zeroshot_learning(
    model, cifar10, cifar10_dataloader, prompt, ensemble=False
)

print(f'cifar10 Top-1 accuracy: {top1:.2f}')
print(f'cifar10 Top-5 accuracy: {top5:.2f}')

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 13.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [01:41<00:00,  1.02s/it]

cifar10 Top-1 accuracy: 95.17
cifar10 Top-5 accuracy: 99.66





In [10]:
# Zero-shot evaluation on the CIFAR-100 test split (expects the data to be
# present locally already, since download is disabled).
cifar100 = CIFAR100(
    root='./cifar100_data',
    train=False,
    transform=preprocess,
    download=False,
)
cifar100_dataloader = DataLoader(cifar100, batch_size=batch_size, shuffle=False)

top1, top5 = zeroshot_learning(
    model, cifar100, cifar100_dataloader, prompt, ensemble=False
)

print(f'cifar100 Top-1 accuracy: {top1:.2f}')
print(f'cifar100 Top-5 accuracy: {top5:.2f}')

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 33.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [01:41<00:00,  1.02s/it]

cifar100 Top-1 accuracy: 75.65
cifar100 Top-5 accuracy: 93.27



