## Train Awareness with ViT-B-16 encoder on CIFAR-10

In [7]:
import os
import clip
import yaml
import torch
import pathlib
import numpy as np
import torchvision
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from awareness import awareness
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import create_feature_extractor

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
ENCODER_PATH = './checkpoints/vit-b-16_cifar-10/experiments/exp1_tl0.0082_ta0.9526/weights/best.pt'
WINDOW_SIZE = 1
#BATCH_SIZE = 1024
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 100
IMG_SIZE = 224
DYNAMIC_RAY = True
MODEL_NAME = 'Awareness+ViT-B-16'
DATASET_NAME = 'CIFAR-10'

In [4]:
encoder_model = torch.load(ENCODER_PATH)
encoder_model.eval().to(device)

feature_extractor = create_feature_extractor(encoder_model, return_nodes=['encoder'])

awareness_model = awareness.Awareness(learnable=True, dynamic_ray=True)
awareness_model.to(device)

print("Encoder model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in encoder_model.parameters()]):,}")
print("Awareness model parameters:", f"{int(np.sum([int(np.prod(p.shape)) for p in awareness_model.parameters()])):,}")

Encoder model parameters: 86,567,656
Awareness model parameters: 0


In [5]:
preprocess = transforms.Compose(
    [transforms.Resize((IMG_SIZE,IMG_SIZE)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

cifar10_train = CIFAR10(os.path.expanduser("~/.cache"), train=True, transform=preprocess, download=True)
cifar10_test = CIFAR10(os.path.expanduser("~/.cache"), train=False, transform=preprocess, download=True)

test_loader = DataLoader(
    cifar10_test,
    batch_size=int(BATCH_SIZE),
    shuffle=True
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
save_path = f'./checkpoints/{MODEL_NAME}_{DATASET_NAME}'.lower()

config = {
    'model_name': MODEL_NAME,
    'encoder_path': ENCODER_PATH,
    'dataset': DATASET_NAME,
    'batch_size': BATCH_SIZE,
    'epochs': EPOCHS,
    'img_size': IMG_SIZE,
    'num_classes': NUM_CLASSES
}

results_data = {
    'epoch': [], 
    'ref_instances': [],
    'train_acc': [], 
    'test_acc': []
}

res_df = pd.DataFrame(results_data)

if(not os.path.exists(f'{save_path}/')):
    os.makedirs(f'{save_path}')

with open(f'{save_path}/config.yaml', 'w') as yaml_file:
    yaml.dump(config, yaml_file, default_flow_style=False)

best_accuracy = 0.0
best_train_acc = 0.0
    
for epoch in range(EPOCHS):
    
    train_loader = DataLoader(
        cifar10_train,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    awareness_model.__init__(learnable=True, dynamic_ray=True)
    
    encoder_model.eval()
    awareness_model.eval()
    
    train_loaders = [train_loader]
    
    with torch.no_grad(): 
    
        for train_loader in train_loaders:

            for i, (images, labels) in enumerate(train_loader):
                
                train_correct_preds_batches = []
                test_correct_preds_batches = []

                train_count = 0
                test_count = 0

                if torch.cuda.is_available():
                    images = Variable(images.cuda())
                    labels = Variable(labels.cuda())

                #features = encoder_model(images).float()
                features = torch.mean(feature_extractor(images)['encoder'].float(), 1)
                
                preds = awareness_model(torch.unsqueeze(features,1), set_labels=labels, update_ref_insts=True)

                train_correct_preds_batch = np.sum(preds.cpu().numpy() == labels.cpu().numpy())
                train_correct_preds_batches.append(train_correct_preds_batch)
                train_count = train_count+len(images)

                references = awareness_model.awareness.ref_insts
                references_labels = awareness_model.awareness.ref_insts_labels

                n_ref_insts = len(references)

                train_acc = round(np.sum(train_correct_preds_batches)/train_count, 4)

                print(f'Train ({n_ref_insts} refs) --> {train_acc}')
        
        print('####')
        for i, (images, labels) in enumerate(test_loader):

            if torch.cuda.is_available():
                images = Variable(images.cuda())
                labels = Variable(labels.cuda())

            #features = encoder_model(images).float()
            features = torch.mean(feature_extractor(images)['encoder'].float(), 1)
            
            preds = awareness_model(torch.unsqueeze(features,1))

            test_correct_preds_batch = np.sum(preds.cpu().numpy() == labels.cpu().numpy())
            test_correct_preds_batches.append(test_correct_preds_batch)
            test_count = test_count+len(images)

            test_acc = round(np.sum(test_correct_preds_batches)/test_count, 4)

            print(f'Test ({n_ref_insts} refs) --> {test_acc}')

        print(f'Epoch {epoch+1}, Reference instances (N): {n_ref_insts}, Train accuracy: {train_acc}, Test accuracy: {test_acc}')

        results_data = {
            'epoch': epoch+1, 
            'ref_instances': n_ref_insts, 
            'train_acc': train_acc, 
            'test_acc': test_acc
        }

        if(not os.path.exists(f'{save_path}/weights')):
            os.makedirs(f'{save_path}/weights')
    
        res_df.loc[len(res_df)] = results_data
        res_df.to_csv(f'{save_path}/results.csv', index=False)
    
        torch.save(model, f'{save_path}/weights/last.pt')
    
        if(test_acc > best_test_acc):
            torch.save(model, f'{save_path}/weights/best.pt')
            best_test_acc = test_acc
    
            print(f'Saved checkpoint related to better accuracy score: {best_test_acc}')

Train (27 refs) --> 1.0
Train (48 refs) --> 1.0
Train (66 refs) --> 0.9375
Train (79 refs) --> 0.8125
Train (94 refs) --> 0.875
Train (109 refs) --> 0.9688
Train (124 refs) --> 0.9375
Train (143 refs) --> 0.875
Train (159 refs) --> 0.9375
Train (176 refs) --> 0.9375
Train (193 refs) --> 0.9375
Train (204 refs) --> 0.9375
Train (222 refs) --> 0.9688
Train (236 refs) --> 0.9375
Train (248 refs) --> 0.9062
Train (260 refs) --> 0.9062
Train (274 refs) --> 0.9375
Train (290 refs) --> 0.9062
Train (304 refs) --> 0.9062
Train (316 refs) --> 0.8438
Train (332 refs) --> 0.8125
Train (349 refs) --> 0.9688
Train (354 refs) --> 0.9062
Train (369 refs) --> 0.9062
Train (382 refs) --> 0.9688
Train (393 refs) --> 0.9688
Train (410 refs) --> 0.9375
Train (419 refs) --> 0.9688
Train (426 refs) --> 0.8438
Train (438 refs) --> 0.8438
Train (444 refs) --> 0.9062
Train (455 refs) --> 0.9375
Train (466 refs) --> 0.9062
Train (473 refs) --> 0.9062
Train (478 refs) --> 0.9062
Train (489 refs) --> 0.9375
Train