## Zero Shot Learning

In [1]:
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from random import shuffle
import pandas as pd
import pickle
import time
import math
import os
import copy
import sys
from imagenetv2_pytorch import ImageNetV2Dataset
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
from pytorch_metric_learning import losses
import wandb

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

PyTorch Version:  1.8.1
Torchvision Version:  0.9.1


In [2]:
def initialize_model(model_name, embedding_dim, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if "resnet" in model_name:
        """ Resnet18
        """
        if model_name == "beefyresnet":
            num_layers = 3
        elif model_name == "verybeefyresnet":
            num_layers = 5
        else:
            num_layers = 1

        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features

        if num_layers > 1:
            model_ft.fc = nn.Linear(num_ftrs, num_ftrs)
            for _ in range(num_layers - 2):
                model_ft = nn.Sequential(
                    model_ft,
                    nn.ReLU(),
                    nn.Linear(num_ftrs, num_ftrs)
                )
            model_ft = nn.Sequential(
                    model_ft,
                    nn.ReLU(),
                    nn.Linear(num_ftrs, embedding_dim)
                )
        else:
            model_ft.fc = nn.Linear(num_ftrs, embedding_dim)
            
        input_size = 224

    else:
        print("Invalid model name, exiting...")
        sys.exit(0)
    
    return model_ft, input_size

In [3]:
dst = ["imagenetv2", "cifar10"]
losses=["cliplossv2", "cliploss", "cosineloss", "xeloss"]
embeddings=["onehot", "label", "wiki", "clip"]
networks = ["resnet", "beefyresnet", "verybeefyresnet"]

In [4]:
def closest_labelembedding(image_embeddings, label_embeddings):
    return (image_embeddings @ label_embeddings.T).argmax(-1)


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [5]:
def is_a(wnid, group_wnids, imagenet_ancestors):
    r = False
    for g_wnid in group_wnids:
        if g_wnid in imagenet_ancestors[wnid]:
            r = True
    return r

def sparse2coarse(targets):
    """Convert Pytorch CIFAR100 sparse targets to coarse targets.
    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)
    """
    coarse_labels = np.array([ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3,  
                               3, 14,  9, 18,  7, 11,  3,  9,  7, 11,
                               6, 11,  5, 10,  7,  6, 13, 15,  3, 15,  
                               0, 11,  1, 10, 12, 14, 16,  9, 11,  5, 
                               5, 19,  8,  8, 15, 13, 14, 17, 18, 10, 
                               16, 4, 17,  4,  2,  0, 17,  4, 18, 17, 
                               10, 3,  2, 12, 12, 16, 12,  1,  9, 19,  
                               2, 10,  0,  1, 16, 12,  9, 13, 15, 13, 
                              16, 19,  2,  4,  6, 19,  5,  5,  8, 19, 
                              18,  1,  2, 15,  6,  0, 17,  8, 14, 13])
    return coarse_labels[targets]

# Data augmentation and normalization for training, just normalization for validation
def create_dataloaders(dataset_name, data_transforms, input_size, batch_size):
    
    dataloaders_dict = {}

    print("Initializing Datasets and Dataloaders...")
    if dataset_name == "imagenetv2" or dataset_name == "imagenetv2coarse":
        # Create training and validation datasets
        train_dataset = ImageNetV2Dataset(transform=data_transforms['train'])
        test_dataset = ImageNetV2Dataset(transform=data_transforms['val'])

        train_test_splits_file = 'split_indices.pkl'

        if os.path.exists(train_test_splits_file):
            indices_split = pickle.load(open(train_test_splits_file, 'rb'))
        else:
            index_to_class = {idx: cl for idx, (_, cl) in enumerate(train_dataset)}
            class_to_index = {idx: [] for idx in range(1000)}
            for idx, cl in index_to_class.items():
                class_to_index[cl].append(idx)

            indices_split = {'train': [], 'val': [], 'test': []}
            for cl in class_to_index:
                shuffle(class_to_index[cl])
                indices_split['train'].extend(class_to_index[cl][:int(0.7 * len(class_to_index[cl]))])
                indices_split['val'].extend(class_to_index[cl][int(0.7 * len(class_to_index[cl])):int(0.9 * len(class_to_index[cl]))])
                indices_split['test'].extend(class_to_index[cl][int(0.9 * len(class_to_index[cl])):])
            
            pickle.dump(indices_split, open(train_test_splits_file, 'wb'))
            

        # Create training and validation dataloaders
        dataloaders_dict = {x: torch.utils.data.DataLoader(Subset(test_dataset,indices_split[x]), batch_size=batch_size, shuffle=True, num_workers=4) for x in ['val', 'test']}
        dataloaders_dict['train'] = torch.utils.data.DataLoader(Subset(train_dataset, indices_split['train']), batch_size=batch_size, shuffle=True, num_workers=4)

    else:
        if dataset_name == "imagenet":
            train_dataset = torchvision.datasets.ImageNet(root='./data', split='train', download=True, transform=data_transforms['train'])
            test_dataset = torchvision.datasets.ImageNet(root='./data', split='val', download=True, transform=data_transforms['val'])
        elif dataset_name == "cifar10":
            train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
            test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['val'])
        elif dataset_name == "cifar100":
            train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=data_transforms['train'])
            test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=data_transforms['val'])
        elif dataset_name == "cifar100super":
            train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=data_transforms['train'])
            train_dataset.targets = sparse2coarse(train_dataset.targets)
            test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=data_transforms['val'])
            test_dataset.targets = sparse2coarse(test_dataset.targets)
        else:
            print("Error dataset not found")
            sys.exit(0)

        dataloaders_dict['train'] = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        dataloaders_dict['val'] = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return dataloaders_dict

In [6]:
def generate_label_embeddings(embedding_type, embedding_dim, dataset_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if embedding_type == "onehot":
        # One-hot text embeddings
        label_embeddings = torch.from_numpy(np.eye(embedding_dim, dtype=np.float32)).to(device)
    else:
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        # Load labeltext
        
        if dataset_name == "imagenet" or dataset_name == "imagenetv2" or dataset_name == "imagenetv2coarse":
            classidx_to_label = pickle.load(open('classidx_to_label.pkl', 'rb'))
        elif dataset_name == "cifar10":
            classidx_to_label = pickle.load(open('./data/cifar-10-batches-py/batches.meta', 'rb'))['label_names']
        elif dataset_name == "cifar100":
            classidx_to_label = pickle.load(open('./data/cifar-100-python/meta', 'rb'))['fine_label_names']
        elif dataset_name == "cifar100super":
            classidx_to_label = pickle.load(open('./data/cifar-100-python/meta', 'rb'))

            coarse = classidx_to_label['coarse_label_names']
            fine = classidx_to_label['fine_label_names']
            classidx_to_label = []

            for i in range(len(coarse)):
                classidx_to_label.append('')

            for i in range(len(fine)):
                if classidx_to_label[sparse2coarse(i)] == '':
                    classidx_to_label[sparse2coarse(i)] = str(fine[i])
                else:
                    classidx_to_label[sparse2coarse(i)] += " or " + str(fine[i])
        else:
            print("ERROR: Dataset without labels")
            return None
        
        if embedding_type == "label":
            # Text embeddings based on class label text
            if dataset_name == "imagenet" or dataset_name == "imagenetv2" or dataset_name == "imagenetv2coarse":
                labels = list(classidx_to_label.values())
            else:
                labels = list(classidx_to_label)
                
        elif embedding_type == 'wiki':
            if dataset_name == "imagenet" or dataset_name == "imagenetv2" or dataset_name == "imagenetv2coarse":
                # Load wikitext
                wiki_path = 'ImageNet-Wiki_dataset/class_article_text_descriptions/class_article_text_descriptions_trainval.pkl'
                wiki_articles = pickle.load(open(wiki_path, 'rb'))
                wiki_label_map = pd.read_csv('LOC_synset_mapping.txt', sep=': ', names=['wnid', 'labels'])


                wiki_labels = {}
                for i in classidx_to_label.keys():
                    wiki_labels[wiki_label_map.iloc[i]['wnid']] = classidx_to_label[i]

                for i in wiki_articles.keys():
                    try:
                        wiki_labels[wiki_articles[i]['wnid']] = wiki_articles[i]['articles'][0]
                    except:
                        pass
                labels = list(wiki_labels.values())
            else:
                return None
        elif embedding_type == 'clip':
            # Text embeddings based on class label text
            clip_labels={}
            if dataset_name == "imagenet" or dataset_name == "imagenetv2" or dataset_name == "imagenetv2coarse":
                for i in classidx_to_label.keys():
                    clip_labels[i] = 'A photo of ' + str(classidx_to_label[i].split(",")[0].rstrip().lstrip())
            else:
                for i in range(len(classidx_to_label)):
                    clip_labels[i] = 'A photo of ' + str(classidx_to_label[i].split(",")[0].rstrip().lstrip())
            labels = list(clip_labels.values())

        label_embeddings = torch.from_numpy(model.encode(labels)).to(device)
        label_embeddings = label_embeddings / label_embeddings.norm(dim=-1, keepdim=True)

    return label_embeddings

In [7]:
def morph_labels_imagenetv2(inputlabels):
    
    ancestor_label_file = 'imagenetv2_ancestor.pkl'
    
    if os.path.exists(ancestor_label_file):
        labels = pickle.load(open(ancestor_label_file, 'rb'))
    else:
        classidx_to_label = pickle.load(open('classidx_to_label.pkl', 'rb'))
        imagenet_ancestors = pickle.load(open('imagenet_wordnet_ancestor_categories.pkl', 'rb'))

        wnid_animals = ['n00015388']
        wnid_plants = ['n00017222', 'n07707451', 'n13134947']

        label_map = pd.read_csv('LOC_synset_mapping.txt', sep=': ', names=['wnid', 'labels'])

        ancestor_labels = {}
        for i in classidx_to_label.keys():
            wnid = label_map.iloc[i]['wnid']

            if is_a(wnid, wnid_animals, imagenet_ancestors):
                ancestor_labels[wnid] = 0
            elif is_a(wnid, wnid_plants, imagenet_ancestors):
                ancestor_labels[wnid] = 1
            else:
                ancestor_labels[wnid] = 2
        
        labels = list(ancestor_labels.values())

        pickle.dump(labels, open(ancestor_label_file, 'wb'))
    
    result = []
    for i in inputlabels:
        result.append(labels[i])
    return result

def morph_labels_cifar100(inputlabels):
    result = []
    for i in inputlabels:
        result.append(sparse2coarse(i))
    return result

In [8]:
def zsl_test(model, testloader, label_embeddings=None, embedding_to_label=None, morph_labels=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    running_corrects = 0
    total_num = 0
    for inputs, labels in testloader:
        inputs = inputs.to(device)

        outputs = model(inputs)      
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)

        if embedding_to_label is not None and label_embeddings is not None:
            preds = embedding_to_label(outputs, label_embeddings)
        else:
            _, preds = torch.max(outputs, 1)

        # statistics
        if morph_labels is None:
            running_corrects += torch.sum(preds == labels.data)
        else:
            labels = morph_labels(labels.data)
            preds = morph_labels(preds)
            running_corrects += torch.sum(torch.Tensor(preds) == torch.Tensor(labels))
        total_num += inputs.size(0)

    accuracy = running_corrects.double() / total_num

    return accuracy

In [9]:
def evaluate_zsl(modelpath, loss, embedding_type, load_model_name, dataset_trained, csvfile, batch_size = 256):
    
    if dataset_trained == "cifar10" or dataset_trained=="imagenetv2":
        print("Trained on cifar10, imagenetv2 ignoring")
        return
    
    if loss == "cliplossv2":
        print("Cliploss v2")
        return
    
#     if loss == "xeloss":
#         print("Xeloss")
#         return
    
    if load_model_name == "beefyresnet":
        print("Beefy")
        return
    
    if (loss != "xeloss" and embedding_type == "onehot") or (loss == "xeloss" and embedding_type != "onehot"):
        print("Not XELoss and onehot")
        return
    
    
#     print(modelpath)
#     return
    
    dataset_label_sizes = {'imagenet':1000, 'imagenetv2': 1000, 'cifar10': 10, 'cifar100': 100, 'imagenetv2coarse': 1000,}

    mean = {'cifar10': (0.4914, 0.4822, 0.4465), 'cifar100': (0.5071, 0.4867, 0.4408), 'cifar100super': (0.2675, 0.2565, 0.2761), 'imagenet':(0.485, 0.456, 0.406), 'imagenetv2': (0.485, 0.456, 0.406), 'imagenetv2coarse': (0.485, 0.456, 0.406)}
    std  = {'cifar10': (0.2023, 0.1994, 0.2010), 'cifar100': (0.2675, 0.2565, 0.2761), 'cifar100super': (0.2675, 0.2565, 0.2761), 'imagenet':(0.229, 0.224, 0.225), 'imagenetv2': (0.229, 0.224, 0.225), 'imagenetv2coarse': (0.229, 0.224, 0.225)}
    
    if embedding_type == "onehot":
        embedding_dim = dataset_label_sizes[dataset_trained]
    else:
        embedding_dim = 768
        
    zslmodel, input_size = initialize_model(load_model_name, embedding_dim, False, use_pretrained=True)
    zslmodel.load_state_dict(torch.load(modelpath))
    
    # Detect if we have a GPU available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using GPU? {torch.cuda.is_available()}")
    zslmodel = zslmodel.to(device)
    zslmodel.eval()
    
#     datasets = ["imagenetv2", "cifar10", "cifar100", "cifar100super"]
    datasets = ["cifar100"]

    for dataset_name in datasets:
#         if dataset_name = dataset_trained:
#             continue
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean[dataset_name], std[dataset_name])
            ]),
            'val': transforms.Compose([
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean[dataset_name], std[dataset_name])
            ]),
        }
        
        for embed_t in ["clip", "label", "onehot"]:
#         for embed_t in ["onehot"]:
            if loss == "xeloss" and embed_t != "onehot":
                continue
            if embed_t == "onehot" and loss != "xeloss":
                continue
            dataloader_dict = create_dataloaders(dataset_name, data_transforms, input_size, batch_size)

            label_embeddings = generate_label_embeddings(embed_t, embedding_dim, dataset_name)

            if label_embeddings is None:
                continue

            accuracy = zsl_test(zslmodel, dataloader_dict['val'], label_embeddings=label_embeddings, embedding_to_label=closest_labelembedding, morph_labels=morph_labels_cifar100)
            
            with open(csvfile, 'a') as f:
                line = str(modelpath) + ", " + str(loss) + ", " + str(embedding_type)  + ", " + str(load_model_name) + ", " + str(dataset_trained) + ", " + str(dataset_name) + ", " + str(embed_t) + ", " + str(accuracy.cpu().numpy()) + "\n"
                f.writelines(line)
            print(dataset_name, embed_t, accuracy.cpu().numpy())
        
    del zslmodel


In [10]:
files = os.listdir('./models/')

In [11]:
# csvfile = 'zslresults2.csv'
csvfile = 'zsl_cifar100coarse.csv'
with open(csvfile, 'w') as f:
    f.writelines("modelpath, loss, train_embedding, model_type, train_dataset, zsldataset, zslembedding, accuracy\n")
for i in files:
    props = i.split("_")
    loss, embed, model_name, dataset = props[:4]
    modelpath = './models/' + str(i)
    print(i)
#     try:
    zslmodel = evaluate_zsl(modelpath, loss, embed, model_name, dataset, csvfile)
#     except:
#         print("Model failing to load")
    print()
    print()

cliplossv2_wiki_verybeefyresnet_imagenetv2_256_adam_0.3355.mdl
Trained on cifar10, imagenetv2 ignoring


cosineloss_label_verybeefyresnet_cifar10_256_adam.mdl
Trained on cifar10, imagenetv2 ignoring


cliploss_clip_verybeefyresnet_imagenetv2cifar100_256_adam_final.mdl
Using GPU? True
Initializing Datasets and Dataloaders...
Files already downloaded and verified
Files already downloaded and verified
cifar100 clip 0.7491
Initializing Datasets and Dataloaders...
Files already downloaded and verified
Files already downloaded and verified
cifar100 label 0.7355


cosineloss_clip_beefyresnet_cifar10_256_adam.mdl
Trained on cifar10, imagenetv2 ignoring


cliploss_onehot_resnet_imagenetv2_256_adam_0.3915.mdl
Trained on cifar10, imagenetv2 ignoring


cliploss_onehot_resnet_imagenetv2_256_adam.mdl
Trained on cifar10, imagenetv2 ignoring


cliploss_label_resnet_imagenetv2_256_adam_0.1975.mdl
Trained on cifar10, imagenetv2 ignoring


cosineloss_wiki_resnet_imagenetv2_256_adam_0.166.mdl
Trained on c