## Train Awareness with CLIP encoder on CIFAR10

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import pathlib
import numpy as np
import os
import clip
from awareness import awareness

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cpu


In [3]:
WINDOW_SIZE = 1
BATCH_SIZE = 1024
EPOCHS = 100
DYNAMIC_RAY = True
SAVE_PATH = 'checkpoints/awareness-clip_cifar10'

In [5]:
clip_model, preprocess = clip.load("ViT-B/32")
clip_model.eval().to(device)

awareness_model = awareness.Awareness(learnable=True, dynamic_ray=True)
awareness_model.to(device)

print("Clip model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in clip_model.parameters()]):,}")
print("Awareness model parameters:", f"{int(np.sum([int(np.prod(p.shape)) for p in awareness_model.parameters()])):,}")

Clip model parameters: 151,277,313
Awareness model parameters: 0


In [6]:
cifar10_train = CIFAR10(os.path.expanduser("~/.cache"), train=True, transform=preprocess, download=True)
cifar10_test = CIFAR10(os.path.expanduser("~/.cache"), train=False, transform=preprocess, download=True)

test_loader = DataLoader(
    cifar10_test,
    batch_size=int(BATCH_SIZE),
    shuffle=True
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
best_accuracy = 0.0
best_train_acc = 0.0
    
for epoch in range(EPOCHS):
    
    train_loader = DataLoader(
        cifar10_train,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    awareness_model.__init__(learnable=True, dynamic_ray=True)
    
    clip_model.eval()
    awareness_model.eval()
    
    train_loaders = [train_loader]
    
    with torch.no_grad(): 
    
        for train_loader in train_loaders:

            for i, (images, labels) in enumerate(train_loader):

                train_correct_preds_batches = []
                test_correct_preds_batches = []

                train_count = 0
                test_count = 0

                if torch.cuda.is_available():
                    images = Variable(images.cuda())
                    labels = Variable(labels.cuda())

                features = clip_model.encode_image(images).float()
                preds = awareness_model(torch.unsqueeze(features,1), set_labels=labels, update_ref_insts=True)

                train_correct_preds_batch = np.sum(preds.cpu().numpy() == labels.cpu().numpy())
                train_correct_preds_batches.append(train_correct_preds_batch)
                train_count = train_count+len(images)

                references = awareness_model.awareness.ref_insts
                references_labels = awareness_model.awareness.ref_insts_labels

                n_ref_insts = len(references)

                train_acc = round(np.sum(train_correct_preds_batches)/train_count, 4)

                print(train_acc)
                
        for i, (images, labels) in enumerate(test_loader):

            if torch.cuda.is_available():
                images = Variable(images.cuda())
                labels = Variable(labels.cuda())

            features = clip_model.encode_image(images).float()
            preds = awareness_model(torch.unsqueeze(features,1))

            test_correct_preds_batch = np.sum(preds.cpu().numpy() == labels.cpu().numpy())
            test_correct_preds_batches.append(test_correct_preds_batch)
            test_count = test_count+len(images)

            test_batch_accuracy = round(np.sum(test_correct_preds_batch)/preds.size(0), 4)

            running_test_accuracy = round(np.sum(test_correct_preds_batches)/test_count, 4)

        test_acc = round(np.sum(test_correct_preds_batches)/test_count, 4)

        print(f'Epoch {epoch+1}, Reference instances (N): {n_ref_insts}, Train accuracy: {train_acc}, Test accuracy: {test_acc}')

        if test_accuracy > best_accuracy:

            folder_path = SAVE_PATH

            if(not os.path.exists(folder_path)):
                os.mkdirs(folder_path)

            torch.save(clip_model, f'./{folder_path}/clip.pt') 
            torch.save(awareness_model, f'{folder_path}/awareness.pt')

            best_accuracy = test_accuracy

            print(f'Saved checkpoint: epoch {epoch+1}, Reference instances (N): {n_ref_insts}, Train accuracy: {train_acc}, Test accuracy: {test_accuracy}')

0.9717


KeyboardInterrupt: 