In [None]:
import os
# When the kernel starts inside notebooks/, step up one level so the relative
# paths used below ('configs/...', './data') resolve.
current_dir_name = os.path.basename(os.getcwd())
if current_dir_name == 'notebooks':
    os.chdir(os.pardir)
print(os.getcwd())

In [None]:
%matplotlib inline
import torch
#import CLIP.clip as clip
import clip

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression


import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms
import PIL
from omegaconf import OmegaConf
from tqdm import tqdm
import pdb

import torchvision 

from utils import general_utils as gu

In [None]:
# Preferred accelerator. A hard-coded GPU index fails on machines with fewer
# visible GPUs (or none), so fall back to the first GPU and then to the CPU.
preferred_device = "cuda:2"
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > int(preferred_device.split(":")[1]):
    device = preferred_device
else:
    device = "cuda:0"
if device != preferred_device:
    print('Device {} unavailable, using {} instead'.format(preferred_device, device))

#clip_type = 'ViT-B/32'
clip_type = 'RN50'

# `model` and `preprocess` are used as notebook-level globals by the cells below.
model, preprocess = clip.load(clip_type, device, jit=False)



In [None]:


# Select one dataset config; it is merged on top of configs/base.yaml.
#cfg_file  = 'configs/waterbirds_generic.yaml'
#cfg_file  = 'configs/coco_generic.yaml'
#cfg_file  = 'configs/coco_device.yaml'
#cfg_file  = 'configs/planes.yaml'
cfg_file  = 'configs/food_subset_generic.yaml'
base_cfg  = OmegaConf.load('configs/base.yaml')
cfg       = OmegaConf.load(cfg_file)
args = OmegaConf.merge(base_cfg, cfg)


# Bind the dataset class for the selected dataset and apply per-dataset overrides.
if args.DATA.DATASET == 'waterbirds':
    from datasets.waterbirds import Waterbirds as Dataset
    # Overrides must be written to the merged `args`: OmegaConf.merge returns a
    # new config, so assigning to `cfg` after the merge never reaches the
    # object that is handed to the dataset constructors below.
    args.DATA.WATERBIRDS_DIR = 'waterbird_1.0_forest2water2'
    args.DATA.CONFOUNDING_FACTOR = 1.0
elif args.DATA.DATASET == 'coco_gender':
    from datasets.coco import COCOGender as Dataset
elif args.DATA.DATASET == 'coco_device':
    from datasets.coco_device import COCODevice as Dataset
elif args.DATA.DATASET == 'planes':
    from datasets.planes import Planes as Dataset
    args.DATA.BIAS_TYPE = 'bias_A'
elif args.DATA.DATASET == 'food':
    from datasets.food import Food as Dataset
elif args.DATA.DATASET == 'food_subset':
    from datasets.food import FoodSubset as Dataset
else:
    raise NotImplementedError('Unsupported dataset: {}'.format(args.DATA.DATASET))


# No dataset-level transform is passed: features are extracted later straight
# from the image files with the CLIP preprocessing pipeline.
transform = None
train_dataset, val_dataset, test_dataset = (
    Dataset(root='./data', cfg=args, transform=transform, split=split_name)
    for split_name in ('train', 'val', 'test')
)



In [None]:


# Loaders for the three splits. Only the training loader shuffles; the test
# loader yields one sample per batch.
loader_batch_size = args.DATA.BATCH_SIZE
loader_num_workers = args.DATA.NUM_WORKERS

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=loader_batch_size,
    num_workers=loader_num_workers,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=loader_batch_size,
    num_workers=loader_num_workers,
    shuffle=False,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=loader_num_workers,
    shuffle=False,
)


In [None]:


def get_features(image_filenames, preprocess):
    """Encode every image file with the notebook-level `model`.

    Args:
        image_filenames: sized iterable of image paths readable by PIL.
        preprocess: callable mapping a PIL image to a tensor accepted by
            `model.encode_image` once a batch dimension has been added.

    Returns:
        A tensor holding one row of `model.encode_image` output per input
        file, concatenated along dim 0 and left on `device`.

    Raises:
        RuntimeError: from `torch.cat` when `image_filenames` is empty.
    """
    all_features = []

    # A context manager restores the previous grad mode even if an image fails
    # to load, instead of unconditionally switching gradients back on.
    with torch.no_grad():
        for image_path in tqdm(image_filenames, total=len(image_filenames)):
            # Close each file handle as soon as the image has been preprocessed.
            with Image.open(image_path) as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to(device)
            all_features.append(model.encode_image(image))

    return torch.cat(all_features)#.cpu().numpy()


def group_accuracy(group_labels, class_labels, predictions, num_groups):
    """Print and return the accuracy (in percent) within each group.

    Args:
        group_labels: array of group ids, one per sample.
        class_labels: array of ground-truth class ids, one per sample.
        predictions: array of predicted class ids, one per sample.
        num_groups: groups `0 .. num_groups - 1` are evaluated.

    Returns:
        np.ndarray of unrounded per-group accuracies; a group with no samples
        yields NaN. The printed values are rounded to two decimals.
    """
    per_group = []
    for group_id in range(num_groups):
        members = np.nonzero(group_labels == group_id)[0]
        hits = predictions[members] == class_labels[members]
        per_group.append(np.mean(hits.astype(float)) * 100.)

    group_accs = np.array(per_group)
    rounded = [np.round(value, 2) for value in group_accs]
    print('Group accs: {}'.format(rounded))
    return group_accs

def class_accuracy(labels, predictions, num_classes):
    """Print and return the per-class accuracy (in percent).

    Args:
        labels: array of ground-truth class ids, one per sample.
        predictions: array of predicted class ids, one per sample.
        num_classes: classes `0 .. num_classes - 1` are evaluated.

    Returns:
        np.ndarray of per-class accuracies rounded to two decimals. The
        printed mean is taken over these rounded values.
    """
    rounded_accs = []
    for class_id in range(num_classes):
        members = np.nonzero(labels == class_id)[0]
        hits = labels[members] == predictions[members]
        class_acc = np.mean(hits.astype(float)) * 100.
        rounded_accs.append(np.round(class_acc, 2))

    print('Class accs: {}'.format(rounded_accs))
    print('Mean class acc: {}'.format(np.array(rounded_accs).mean()))
    return np.array(rounded_accs)
        

def predict_zero_shot(model, all_image_features, text_features):
    """Predict, for each image, the text prompt with the highest cosine similarity.

    Args:
        model: unused; kept so existing callers keep working.
        all_image_features: tensor with one image embedding per row
            (expected 2-D: images x embedding dim).
        text_features: tensor with one text embedding per row
            (expected 2-D: prompts x embedding dim).

    Returns:
        np.ndarray of predicted row indices into `text_features`, one per image.

    Neither input tensor is modified. The previous implementation normalized
    both in place, which silently altered the caller's features.
    """
    # The context manager restores whatever grad mode was active before the call.
    with torch.no_grad():
        # Out-of-place normalization; the text side is loop-invariant, so it is
        # computed once and all images are scored in a single matrix product.
        image_unit = all_image_features / all_image_features.norm(dim=-1, keepdim=True)
        text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
        probs = (100.0 * image_unit @ text_unit.T).softmax(dim=-1)
        return probs.argmax(dim=-1).cpu().numpy()


def evaluate(preds, labels, group_labels=None, num_groups=None, num_classes=2):
    """Print per-instance, per-class and (optionally) per-group accuracy.

    Args:
        preds: array of predicted class ids.
        labels: array of ground-truth class ids.
        group_labels: optional array of group ids; group accuracy is skipped
            when this is None.
        num_groups: number of groups, forwarded to `group_accuracy`.
        num_classes: number of classes, forwarded to `class_accuracy`.
    """
    is_correct = (preds == labels).astype(float)
    overall_acc = np.mean(is_correct)
    print('Per instance acc: {}'.format(np.round(overall_acc*100., 2)))
    class_accuracy(labels, preds, num_classes)

    if group_labels is None:
        return
    group_accuracy(group_labels, labels, preds, num_groups)


def get_filenames_labels(dataset, dataset_type):
    """Pull image paths and class labels out of a dataset object.

    Args:
        dataset: dataset instance; the attributes read depend on `dataset_type`.
        dataset_type: 'waterbirds' reads `image_filenames` / `labels_split`,
            'food_subset' reads the (path, label) entries of `imgs`, and any
            other value reads `filenames` / `labels`.

    Returns:
        Tuple `(filenames, labels)` with `labels` converted to an np.ndarray.
    """
    if dataset_type == 'waterbirds':
        filenames, labels = dataset.image_filenames, dataset.labels_split
    elif dataset_type == 'food_subset':
        entries = dataset.imgs
        filenames = np.array([entry[0] for entry in entries])
        labels = np.array([entry[1] for entry in entries])
    else:
        # 'planes' and every remaining dataset expose the same two attributes.
        filenames, labels = dataset.filenames, dataset.labels
    return filenames, np.array(labels)


def get_preprocess_no_crop(preprocess, size=(224, 224)):
    """Build a copy of a torchvision pipeline that resizes instead of cropping.

    Every `Resize` step is replaced by a bicubic resize to `size`, every
    `CenterCrop` step is dropped, and all other steps are kept in order.

    Args:
        preprocess: a `transforms.Compose`-like object exposing `.transforms`.
        size: target `(height, width)` for the replacement resize. Defaults to
            the previously hard-coded (224, 224); pass another resolution for
            models that expect a different input size.

    Returns:
        A new `transforms.Compose`; the input pipeline is not modified.
    """
    resize_cls = torchvision.transforms.transforms.Resize
    crop_cls = torchvision.transforms.transforms.CenterCrop

    kept = []
    for step in preprocess.transforms:
        # isinstance (not an exact type match) so subclasses of Resize /
        # CenterCrop are handled as well.
        if isinstance(step, resize_cls):
            kept.append(transforms.Resize(size, interpolation=PIL.Image.BICUBIC))
        elif not isinstance(step, crop_cls):
            kept.append(step)
    return transforms.Compose(kept)

In [None]:
# Swap CLIP's preprocessing for the variant that resizes the full image and
# skips the center crop. This rebinds the notebook-level `preprocess`.
preprocess = get_preprocess_no_crop(preprocess)

In [None]:


train_filenames, train_labels = get_filenames_labels(train_dataset, args.DATA.DATASET)
val_filenames, val_labels     = get_filenames_labels(val_dataset, args.DATA.DATASET)

# get_features returns a torch tensor that lives on `device`. scikit-learn
# needs host-side NumPy arrays, so convert here exactly as is done for the
# test features before fitting the classifier.
train_features = get_features(train_filenames, preprocess).cpu().numpy()
val_features = get_features(val_filenames, preprocess).cpu().numpy()


In [None]:
# Per-dataset evaluation settings: group labels (when the dataset has them)
# and the number of groups / classes passed to evaluate().
if args.DATA.DATASET == 'waterbirds':
    train_group_labels = train_dataset.group_labels_split
    val_group_labels   = val_dataset.group_labels_split
    test_group_labels  = test_dataset.group_labels_split
    num_groups = 4
    num_classes = 2
elif args.DATA.DATASET == 'planes':
    train_group_labels = train_dataset.groups
    val_group_labels   = val_dataset.groups
    test_group_labels  = test_dataset.groups
    num_groups = 4
    num_classes = 2
elif args.DATA.DATASET == 'food_subset':
    train_group_labels = None
    val_group_labels   = None
    test_group_labels  = None
    num_groups = None
    num_classes = 5
else:
    train_group_labels, val_group_labels, test_group_labels = None, None, None
    num_groups = None
    # num_classes was previously left undefined on this path, which made the
    # evaluate() calls below fail with NameError. Infer it from the training
    # labels. NOTE(review): assumes integer class ids starting at 0 — confirm
    # for the coco / food datasets.
    num_classes = int(np.max(train_labels)) + 1

In [None]:

# Sweep the inverse regularization strength and report train / val accuracy
# for every candidate value.
regularization_grid = [0.001, 0.1, 0.3, 0.5, 1, 5, 10, 100, 200, 500, 1000, 2000]

for C in regularization_grid:
    print('*************************************************')
    print('C: {}'.format(C))
    classifier = LogisticRegression(random_state=0, C=C, max_iter=5000, verbose=1)
    classifier.fit(train_features, train_labels)
    print()

    for split_header, split_features, split_labels, split_groups in (
        ('**** Train performance ****', train_features, train_labels, train_group_labels),
        ('**** Val performance ***', val_features, val_labels, val_group_labels),
    ):
        print(split_header)
        preds = classifier.predict(split_features)
        evaluate(preds, split_labels, split_groups, num_groups=num_groups, num_classes=num_classes)

    print()

In [None]:
# Encode the test split and move the features to a host-side NumPy array.
test_filenames, test_labels = get_filenames_labels(test_dataset, args.DATA.DATASET)
test_features = get_features(test_filenames, preprocess).cpu().numpy()


In [None]:
# Refit on the training features with a fixed C and report test accuracy.
C = 0.3
classifier = LogisticRegression(
    C=C,
    max_iter=5000,
    random_state=0,
    verbose=1,
).fit(train_features, train_labels)

print('**** Test performance ***')
preds = classifier.predict(test_features)
evaluate(preds, test_labels, test_group_labels, num_groups=num_groups, num_classes=num_classes)

