In [32]:
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Wrap a labelled image dataset and pair every sample with a tokenized prompt.

    For each sample the prompt is ``template.format(class_name)``, where
    ``class_name`` is looked up in ``dataset.classes`` using the sample's label.
    All prompts are tokenized once, up front, with ``clip.tokenize``.

    Args:
        dataset: Indexable dataset whose items are ``(image, label, ...)`` and
            which exposes a ``classes`` sequence indexed by label.
        template: Format string with one ``{}`` placeholder for the class name.
    """

    def __init__(self, dataset, template):
        self.dataset = dataset
        self.classes = dataset.classes

        prompt = [
            template.format(self.classes[label])
            for label in self._iter_labels(dataset)
        ]

        self.prompt = clip.tokenize(prompt)

    @staticmethod
    def _iter_labels(dataset):
        """Yield the label of every sample, avoiding image loading when possible."""
        targets = getattr(dataset, "targets", None)
        # Fast path: read labels straight from a `targets` attribute, so no image
        # has to be loaded or transformed just to learn its label.  Only taken
        # when no `target_transform` is set and the lengths agree, so the labels
        # match what indexing the dataset would return.
        # NOTE(review): assumes `targets[i]` is the raw integer label of sample i
        # (torchvision convention) — confirm for custom datasets.
        if (
            targets is not None
            and getattr(dataset, "target_transform", None) is None
            and len(targets) == len(dataset)
        ):
            return (int(target) for target in targets)
        # Fallback: iterate the dataset itself (with a progress bar), as before.
        return (sample[1] for sample in tqdm(dataset))

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, tokenized_prompt)`` for sample ``idx``."""
        # Fetch the sample once: indexing the wrapped dataset twice would run
        # its loading/transform pipeline twice for the same item.
        sample = self.dataset[idx]
        image = sample[0]
        label = sample[1]
        text = self.prompt[idx]

        return image, label, text

In [33]:
def zeroshot_classifier(classnames, templates):
    """Build zero-shot classification weights from text prompts.

    For every class name, each template is formatted with that name, tokenized
    with ``clip.tokenize`` and encoded with the module-level ``model``.  The
    per-template embeddings are L2-normalized, averaged, and the average is
    normalized again; the resulting per-class vectors are stacked along dim 1.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one ``{}`` placeholder.

    Returns:
        A tensor with one column per class name, on the same device as the
        model (previously this was hard-coded to CUDA and failed on CPU-only
        machines).
    """
    # Follow the model's own device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        zeroshot_weights = []

        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(target_device)

            class_embeddings = model.encode_text(texts)
            # Normalize each prompt embedding before averaging so every
            # template contributes with equal weight.
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # The mean of unit vectors is generally not unit-length; renormalize.
            class_embedding /= class_embedding.norm()

            zeroshot_weights.append(class_embedding)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(target_device)

    return zeroshot_weights

In [40]:
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from tqdm import tqdm

# Use the GPU when torch can see one, otherwise the CPU, then load the
# CLIP ViT-B/32 model together with its image preprocessing transform.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# CIFAR-10 train and test splits, rooted at ./cifar10_data (download=True),
# with every image passed through CLIP's preprocessing.
train, test = (
    CIFAR10('./cifar10_data', download=True, train=is_train, transform=preprocess)
    for is_train in (True, False)
)

# Prompt template(s); the `{}` placeholder is filled with a class name.
prompt = ['this is a photo of {}']

def get_features(dataset, batch_size=100):
    """Score every image in ``dataset`` against the zero-shot text classifier.

    Each batch is encoded with ``model.encode_image``, L2-normalized, and
    multiplied by the weights from ``zeroshot_classifier``.  The returned
    "features" are therefore per-class similarity scores (one column per class
    name), not raw image embeddings.

    Args:
        dataset: Dataset yielding ``(image_tensor, label)`` pairs that a
            ``DataLoader`` can batch.
        batch_size: Number of images per forward pass (default 100, the value
            that was previously hard-coded).

    Returns:
        ``(scores, labels)`` as NumPy arrays, concatenated over all batches.
    """
    # The weights are always built from `train.classes` (not `dataset.classes`),
    # so every call uses the same class list and column order.
    weights = zeroshot_classifier(train.classes, prompt)

    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=batch_size)):
            features = model.encode_image(images.to(device))

            # Unit-normalize so the product with the weights is a cosine similarity.
            features /= features.norm(dim=-1, keepdim=True)

            all_features.append(features @ weights)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Compute the inputs for the linear probe: get_features() returns each image's
# scores against the zero-shot class weights, plus the ground-truth labels.
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Fit an L2-regularized logistic regression on the training scores.
classifier = LogisticRegression(
    penalty='l2',
    C=0.1,
    max_iter=3000,
    random_state=0,
    verbose=1,
).fit(train_features, train_labels)

# Report test accuracy as a percentage.
predictions = classifier.predict(test_features)
accuracy = 100. * (predictions == test_labels).mean()
print(f"Accuracy = {accuracy:.3f}")

Files already downloaded and verified
Files already downloaded and verified


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 171.63it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.83it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 189.76it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  8.82it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
 This problem is unconstrained.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =          110     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  1.15129D+05    |proj g|=  2.95666D+02

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
  110     49     53      1     0     0   3.518D-01   8.639D+04
  F =   86393.929816777047     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             
Accuracy = 88.920


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.6s finished
