In [4]:
import sys

import os
import random

# Import other dependencies
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Set METALHOME before importing metal, in case metal reads it at import time.
os.environ['METALHOME'] = '/dfs/scratch1/saelig/slicing/metal/'
import metal

# Set random seed for notebook. Defining SEED by itself seeds nothing, so apply
# it to each RNG the notebook may draw from (Python, NumPy, PyTorch).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Reload imported modules before executing each cell, presumably so edits to
# local packages (e.g. the metal checkout under METALHOME) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline (no plots appear in the cells shown here).
%matplotlib inline

# Load Data


In [6]:
from skimage import io, transform
import torchvision.transforms as transforms
import numpy as np

opj = os.path.join
# Root of the slicing workspace. Set the SLICING_HOME environment variable to
# run from another location; the default is the original cluster path.
HOME_DIR = os.environ.get('SLICING_HOME', '/dfs/scratch1/saelig/slicing/')
DATASET_DIR = opj(HOME_DIR, 'CUB_200_2011')   # CUB-200-2011 dataset files
IMAGES_DIR = opj(DATASET_DIR, 'images')
TENSORS_DIR = opj(HOME_DIR, 'birds_data')     # pre-split tensors (*.pt)
MODELS_DIR = opj(HOME_DIR, 'birds_models')    # saved model checkpoints

# The train/valid/test splits are read from tensors in TENSORS_DIR rather than
# rebuilt here from the CUB metadata files in DATASET_DIR (images.txt,
# train_test_split.txt, image_class_labels.txt).


def load_split_tensor(name):
    """Return ``torch.load`` of the file ``<TENSORS_DIR>/<name>.pt``."""
    return torch.load(opj(TENSORS_DIR, f'{name}.pt'))


train_image_ids = load_split_tensor('train_image_ids')
valid_image_ids = load_split_tensor('valid_image_ids')
test_image_ids = load_split_tensor('test_image_ids')
X_train = load_split_tensor('X_train')
X_valid = load_split_tensor('X_valid')
X_test = load_split_tensor('X_test')
Y_train = load_split_tensor('Y_train')
Y_valid = load_split_tensor('Y_valid')
Y_test = load_split_tensor('Y_test')



Let's create the payloads. First, we need to put the attribute information into an easy-to-use data structure.

In [7]:
# Parse the whitespace-separated attribute label file as integers, keeping
# only the first three columns of each row:
# <image_id> <attribute_id> <is_present>.
attrs_array = np.loadtxt(
    os.path.join(DATASET_DIR, 'attributes/image_attribute_labels.txt'),
    dtype=int,
    usecols=(0, 1, 2),
)

Let's create a dictionary that maps each attribute to the set of images that have it, so we can easily look up which samples have which attributes.

In [8]:
NUM_ATTRIBUTES = 312

# attrs_array row format: <image_id>, <attribute_id>, <is_present>

# attrs_dict maps each attribute id to the set of image ids that have that
# attribute. Every id in 1..NUM_ATTRIBUTES gets an entry (possibly an empty
# set), so attrs_dict[attr_id] never raises KeyError for an attribute that no
# image has.
attrs_dict = {attr_id: set() for attr_id in range(1, NUM_ATTRIBUTES + 1)}

# Vectorized filter instead of a Python loop over every (image, attribute)
# row; .tolist() yields plain Python ints for keys and members.
present_rows = attrs_array[attrs_array[:, 2] == 1]
for image_id, attr_id in present_rows[:, :2].tolist():
    # setdefault keeps any attribute id outside 1..NUM_ATTRIBUTES as well.
    attrs_dict.setdefault(attr_id, set()).add(image_id)

Create a Metal payload for each split (train/valid/test), with slice labelsets derived from the binary attributes.

In [13]:
from metal.mmtl.payload import Payload
from metal.mmtl.data import MmtlDataLoader, MmtlDataset
from pprint import pprint

payloads = []
splits = ["train", "valid", "test"]
X_splits = X_train, X_valid, X_test
Y_splits = Y_train, Y_valid, Y_test
image_id_splits = train_image_ids, valid_image_ids, test_image_ids
BATCH_SIZE = 32

task_name = 'BirdClassificationTask'

# The labelset -> task routing does not depend on the split, so build it once
# (same keys, values and insertion order as building it inside the loop):
#   * "labelset_gold" and every "labelset:<attr>:pred" go to the base task,
#     so the base task head is scored on each attribute slice;
#   * "labelset:<attr>:ind" slice-indicator labelsets go to None.
labels_to_tasks = {"labelset_gold": task_name}
for attr_id in range(1, NUM_ATTRIBUTES + 1):
    labels_to_tasks[f"labelset:{attr_id}:pred"] = task_name
    labels_to_tasks[f"labelset:{attr_id}:ind"] = None

for i, (split, X_split, Y_split, image_ids) in enumerate(
    zip(splits, X_splits, Y_splits, image_id_splits)
):
    payload_name = f"Payload{i}_{split}"
    X_dict = {'data': X_split}
    Y_dict = {'labelset_gold': Y_split}
    # Convert the id tensor to Python ints once per split, not once per attribute.
    image_id_list = image_ids.tolist()

    for attr_id in range(1, NUM_ATTRIBUTES + 1):
        # .get() tolerates an attribute id missing from attrs_dict.
        attr_images = attrs_dict.get(attr_id, set())
        mask_list = [1 if image_id in attr_images else 0 for image_id in image_id_list]
        # Guard against dividing by zero on an empty split.
        if split == 'train' and mask_list:
            print('Proportion of attribute {} in train set: {}'.format(
                attr_id, sum(mask_list) / len(mask_list)))
        mask = torch.tensor(mask_list)

        # Gold label inside the slice, 0 outside it (presumably Metal's
        # "abstain" label -- TODO confirm).
        # NOTE(review): assumes Y_split is 1-D (one label per image); an (n, 1)
        # label tensor would broadcast this product to (n, n).
        Y_dict[f"labelset:{attr_id}:pred"] = mask * Y_split

        # Slice indicator per Metal convention: 1 = in slice, 2 = not in slice.
        # mask holds only 0/1, so 2 - mask maps 1 -> 1 and 0 -> 2.
        Y_dict[f"labelset:{attr_id}:ind"] = 2 - mask

    dataset = MmtlDataset(X_dict, Y_dict)
    data_loader = MmtlDataLoader(dataset, batch_size=BATCH_SIZE)
    # Give each payload its own copy so no two payloads share one mutable dict.
    payload = Payload(payload_name, data_loader, dict(labels_to_tasks), split)
    payloads.append(payload)


Let's load our baseline model.

In [14]:
# Baseline model checkpoint (ResNet-18, lr 1e-3, patience 10, per its
# filename). The loaded object is used directly as a model (model.score), not
# as a state_dict. It achieves 17% accuracy on the test set.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which would
# reject a pickled model object; pass weights_only=False there.
model = torch.load(
    opj(MODELS_DIR, 'resnet18_lr_1e-3_patience10.pt')
)

In [16]:
# Score the baseline on the validation payload. payloads is built in
# train/valid/test order, so index 1 is the "valid" split.
accs_per_slice = model.score(
    payloads[1],
    metrics=[],
)

In [17]:
# Shorten each metric name that contains ':' to its second ':'-separated field.
# For slice labelsets ("labelset:<attr_id>:..."), presumably this is the
# attribute id. Other names are kept unchanged.
accs_per_slice_list = [
    (name.split(':')[1], acc) if ':' in name else (name, acc)
    for name, acc in accs_per_slice.items()
]
# Ascending by score; the sort is stable, so ties keep their original order.
print(sorted(accs_per_slice_list, key=lambda item: item[1]))

[('19', 0.0), ('34', 0.0), ('49', 0.0), ('115', 0.0), ('130', 0.0), ('138', 0.0), ('144', 0.0), ('145', 0.0), ('162', 0.0), ('171', 0.0), ('177', 0.0), ('267', 0.0), ('271', 0.0), ('281', 0.0), ('286', 0.0), ('287', 0.0), ('175', 0.05555555555555555), ('242', 0.05660377358490566), ('190', 0.06521739130434782), ('246', 0.06557377049180328), ('87', 0.06666666666666667), ('17', 0.07017543859649122), ('126', 0.07142857142857142), ('176', 0.07142857142857142), ('129', 0.07692307692307693), ('235', 0.08), ('32', 0.08064516129032258), ('140', 0.08333333333333333), ('279', 0.08333333333333333), ('301', 0.08333333333333333), ('66', 0.08620689655172414), ('56', 0.08641975308641975), ('256', 0.08823529411764706), ('128', 0.09090909090909091), ('192', 0.09090909090909091), ('252', 0.09090909090909091), ('94', 0.09166666666666666), ('81', 0.09574468085106383), ('47', 0.0967741935483871), ('160', 0.0967741935483871), ('274', 0.09917355371900827), ('89', 0.1), ('135', 0.1), ('161', 0.1), ('178', 0.1)