In [1]:
import os
import random
import sys

import metal
import numpy as np
# Import other dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F

# Random seed for the notebook. Defining the constant alone seeds nothing,
# so apply it to every RNG the notebook may draw from (the train/valid split
# below also receives it explicitly).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [75]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Data

The train/test split comes with the dataset itself (`train_test_split.txt`). After loading, we further split the training set into training and validation sets (see the cell after the loading cell).

In [41]:
from skimage import io, transform
import numpy as np

# Root of the extracted CUB-200-2011 dataset. Set the CUB_DATASET_DIR
# environment variable to point elsewhere instead of editing this path.
DATASET_DIR = os.environ.get('CUB_DATASET_DIR', '/lfs/1/saelig/CUB_200_2011/')
IMAGES_DIR = os.path.join(DATASET_DIR, 'images')
# Every image is resized to this (height, width, channels) shape so that all
# of them can be stacked into a single array.
IMAGE_SHAPE = (32, 32, 3)

# Metadata files, one row per image, image id in the first column:
#   images.txt             -> <image_id> <image path relative to IMAGES_DIR>
#   train_test_split.txt   -> <image_id> <1 = train, otherwise test>
#   image_class_labels.txt -> <image_id> <class label>
image_list = np.loadtxt(os.path.join(DATASET_DIR, 'images.txt'), dtype=str)
train_test_split = np.loadtxt(os.path.join(DATASET_DIR, 'train_test_split.txt'), dtype=int)
labels = np.loadtxt(os.path.join(DATASET_DIR, 'image_class_labels.txt'), dtype=int)

# Look labels and split flags up by image id rather than by row position, so
# correctness does not depend on all three files listing images in the same order.
label_by_id = {int(img_id): int(cls) for img_id, cls in labels}
is_train_by_id = {int(img_id): int(flag) == 1 for img_id, flag in train_test_split}

X = []
Y = []
X_test = []
Y_test = []

for image_id, image_file in image_list:
    image_id = int(image_id)
    image_data = io.imread(os.path.join(IMAGES_DIR, image_file))
    # NOTE(review): grayscale images load as 2-D arrays; confirm resizing them
    # to 3 channels produces the intended RGB replication.
    image_data = transform.resize(image_data, IMAGE_SHAPE)
    label = label_by_id[image_id]
    if is_train_by_id[image_id]:
        X.append(image_data)
        Y.append(label)
    else:
        X_test.append(image_data)
        Y_test.append(label)

X_train = np.stack(X)
Y_train = np.array(Y)
X_test = np.stack(X_test)
Y_test = np.array(Y_test)

Now let's create a validation set by holding out 20% of the training data (split with the notebook's `SEED` for reproducibility).

In [67]:
from metal.utils import split_data

# Rebuild the full training arrays from the raw lists so this cell is
# idempotent: re-running it re-splits the original training data instead of
# splitting an already-split training set. (X_test / Y_test are already
# arrays from the loading cell and need no conversion.)
X_train_full = np.stack(X)
Y_train_full = np.array(Y)

(X_train, X_valid), (Y_train, Y_valid) = split_data(
    X_train_full, Y_train_full, splits=[0.8, 0.2], seed=SEED
)

# Sanity check: the split neither drops nor duplicates examples.
assert len(X_train) + len(X_valid) == len(X_train_full)
assert len(Y_train) + len(Y_valid) == len(Y_train_full)

Create a task. For now we use a ResNet-50 as the input module. Since the ResNet already ends in a fully connected layer (here with 200 outputs, one per class), we don't specify a `head_module`, which defaults to the identity.

In [87]:
from torchvision.models.resnet import resnet50
from metal.mmtl.slicing.tasks import MultiClassificationTask
from metal.mmtl.metal_model import MetalModel

# Number of bird classes; sizes the ResNet's final fully connected layer.
NUM_CLASSES = 200

# The task needs an instantiated network. Passing the `resnet50` factory
# function itself would hand the task a function instead of an nn.Module.
input_net = resnet50(num_classes=NUM_CLASSES)

task0 = MultiClassificationTask(
    # NOTE(review): "Taask" looks like a typo; kept as-is in case other cells
    # or saved results refer to the task by this exact name.
    name='BirdClassificationTaask',
    input_module=input_net,
)
tasks = [task0]
model = MetalModel(tasks, verbose=False)

Create a `Payload` for each of the train, valid, and test splits.

In [96]:
from metal.mmtl.payload import Payload

BATCH_SIZE = 32


def to_image_tensor(images):
    """Convert stacked channels-last images into a channels-first float32 tensor.

    Args:
        images: array-like of shape (N, H, W, C), e.g. the stacked, resized
            images produced by the loading cell.

    Returns:
        torch.Tensor of shape (N, C, H, W) and dtype float32. torch.nn.Conv2d
        layers (such as the ResNet input layer) expect channels-first input
        whose dtype matches their float32 weights.
    """
    images = np.asarray(images, dtype=np.float32)
    return torch.from_numpy(np.ascontiguousarray(images.transpose(0, 3, 1, 2)))


payloads = []
splits = ["train", "valid", "test"]
X_splits = X_train, X_valid, X_test
Y_splits = Y_train, Y_valid, Y_test

for i, split in enumerate(splits):
    payload_name = f"Payload{i}_{split}"
    task_name = task0.name
    payload = Payload.from_tensors(
        payload_name,
        to_image_tensor(X_splits[i]),
        # Integer class labels as a long tensor, as classification losses expect.
        torch.as_tensor(np.asarray(Y_splits[i]), dtype=torch.long),
        task_name,
        split,
        batch_size=BATCH_SIZE,
    )
    payloads.append(payload)