In [1]:
import sys

import json
import metal
import os
import numpy as np
# Import other dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
# sys.path.append('/dfs/scratch0/zzweng/metal/metal/mmtl/birds')
os.environ['METALHOME'] = '/dfs/scratch0/zzweng/metal'
# Random seed for the notebook. Assigning SEED alone has no effect, so apply it
# to the NumPy and PyTorch global RNGs to make notebook-level randomness reproducible.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
from skimage import io, transform
import torchvision.transforms as transforms
import numpy as np
from PIL import Image, ImageDraw
DATASET_DIR = '/dfs/scratch0/chami/maskrcnn-benchmark/datasets/traffic_light_data/'
# Video frames and the per-split annotation JSON files both live under DATASET_DIR.
IMAGES_DIR = os.path.join(DATASET_DIR, 'frames')
save_dir = os.path.join(DATASET_DIR, 'annotations/')
# torchvision ToTensor converter, applied to every image in Dataset.__getitem__.
tt = transforms.ToTensor()

### Create train/val/test dataset

We want to create a dataset with train/val/test splits such that:
- label = 1 if the image contains a red or yellow traffic light
- label = 2 otherwise (no traffic light, or only green traffic lights)

In [1]:
import torch.utils.data as data

# Video clips assigned to each split (kept for reference; not used in this notebook).
splits = dict(
    train=[
        'dayClip1', 'dayClip5', 'dayClip6', 'dayClip7',
        'dayClip8', 'dayClip9', 'dayClip10', 'dayClip11',
        'dayClip12', 'dayClip13',
        'nightClip2', 'nightClip3', 'nightClip4', 'nightClip5',
    ],
    val=['dayClip4'],
    test=['dayClip2', 'dayClip3', 'nightClip1'],
    test_night=['nightClip1'],
    test_day=['dayClip2', 'dayClip3'],
)

# Per-channel normalization (the commonly used ImageNet mean/std values).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

    
class Dataset(data.Dataset):
    """Traffic-light image dataset built from one split's JSON annotation file.

    Annotations are grouped per ``image_id``. An image gets label 1 if ANY of
    its annotations is tagged ``'stop'`` (red) or ``'warning'`` (yellow), and
    label 2 otherwise. Slice indicators are derived per image: night/day from
    the filename, yellow if ANY annotation is tagged ``'warning'``.

    Args:
        split: annotation file stem; ``<annotations_dir>/<split>.json`` is read.
        annotations_dir: directory holding the annotation files. Defaults to
            the notebook-level ``save_dir`` when omitted.
    """

    def __init__(self, split, annotations_dir=None):
        if annotations_dir is None:
            annotations_dir = save_dir
        with open(os.path.join(annotations_dir, f'{split}.json')) as f:
            annotations = json.load(f)['annotations']
        self.filenames = {}
        self.labels = {}
        self.bbox = {}
        self.is_night = {}
        self.is_day = {}
        self.is_yellow = {}
        for a in annotations:
            img_id = a['image_id']
            fname = a['filename']
            tag = a['tag']
            self.filenames[img_id] = fname
            # label = 1 if red (stop) or yellow (warning); once an image is 1 it stays 1.
            label = 2 - int(tag in ['stop', 'warning'])
            self.labels[img_id] = min(self.labels.get(img_id, 2), label)
            self.bbox.setdefault(img_id, []).append(a['bbox'])
            self.is_night[img_id] = int("night" in fname)  # night -> 1
            self.is_day[img_id] = int("day" in fname)  # day -> 1
            # OR-accumulate so a later non-yellow annotation cannot erase an earlier yellow one.
            self.is_yellow[img_id] = max(self.is_yellow.get(img_id, 0), int(tag == 'warning'))
        # Stable 0-based position -> image_id mapping used by __getitem__.
        # For ids 1..N this maps position i to image_id i + 1.
        self.image_ids = sorted(self.labels)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(x_dict, y_dict)`` for the image at 0-based position ``idx``.

        ``x_dict["data"]`` is the resized, normalized image tensor. ``y_dict``
        holds the gold label plus, per slice, a ``:pred`` target (gold label
        inside the slice, 0 outside) and an ``:ind`` indicator (1 inside, 2 outside).
        """
        img_id = self.image_ids[idx]
        image_data = io.imread(os.path.join(IMAGES_DIR, self.filenames[img_id]))
        image_data = transform.resize(image_data, (224, 224, 3))  # resize all images to 224x224
        image_data = normalize(tt(image_data).type(torch.float32))

        label = self.labels[img_id]
        is_night_mask = self.is_night[img_id]
        is_day_mask = self.is_day[img_id]
        is_yellow_mask = self.is_yellow[img_id]
        x_dict = {"data": image_data}
        y_dict = {
            "labelset_gold": torch.tensor([label]),

            "is_night_slice:pred": torch.tensor([label if is_night_mask else 0]),
            "is_day_slice:pred": torch.tensor([label if is_day_mask else 0]),
            "is_yellow_slice:pred": torch.tensor([label if is_yellow_mask else 0]),

            "is_night_slice:ind": torch.tensor([1 if is_night_mask else 2]),
            "is_day_slice:ind": torch.tensor([1 if is_day_mask else 2]),
            "is_yellow_slice:ind": torch.tensor([1 if is_yellow_mask else 2]),
        }
        return x_dict, y_dict

    def draw(self, image_id):
        """Open the frame for ``image_id`` and outline its annotated boxes in red."""
        fname = self.filenames[image_id]
        print('Visualizing: ', fname)
        img = Image.open(os.path.join(IMAGES_DIR, fname))
        canvas = ImageDraw.Draw(img)
        # NOTE(review): assumes each bbox is a 4-tuple of corner coordinates
        # accepted by ImageDraw.rectangle — confirm against the annotation format.
        for (left, bottom, right, top) in self.bbox[image_id]:
            canvas.rectangle([left, bottom, right, top], outline='red', width=3)
        return img

NameError: name 'transforms' is not defined

In [4]:
# Build one Dataset per annotation split (constructed in train, val, test order).
train_data, val_data, test_data = (Dataset(split_name) for split_name in ('train', 'val', 'test'))
print(len(train_data), len(val_data), len(test_data))
# train_data.draw(2)

15752 263 1998


In [5]:
# Fraction of images whose label / slice indicator equals 1, for train, val and test.
# Each section reads one per-image dict attribute of the Dataset objects.
proportion_sections = [
    ('gt=1 proportions in datasets', 'labels'),
    ('\nis_yellow proportions in datasets', 'is_yellow'),
    ('\nis_night proportions in datasets', 'is_night'),
    ('\nis_day proportions in datasets', 'is_day'),
]
for header, attr in proportion_sections:
    print(header)
    for ds in (train_data, val_data, test_data):
        print(sum(v == 1 for v in getattr(ds, attr).values()) / len(ds))

gt=1 proportions in datasets
0.5262823768410361
0.0
0.44744744744744747

is_yellow proportions in datasets
0.016315388522092432
0.0
0.023023023023023025

is_night proportions in datasets
0.3009776536312849
0.0
0.24874874874874875

is_night proportions in datasets
0.6990223463687151
1.0
0.7512512512512513


### Initialize the MetalModel and the Payloads

In [16]:
from metal.mmtl.payload import Payload
from metal.mmtl.data import MmtlDataLoader
from pprint import pprint
from metal.mmtl.slicing.tasks import MultiClassificationTask
from metal.mmtl.metal_model import MetalModel 
from metal.mmtl.birds.resnet import *

# ResNet-18 is the task's input module; its `fc` layer is reused as the task head.
resnet_model = resnet18(num_classes=2, use_as_feature_extractor=True)
resnet_model = resnet_model.float().cuda()

task_name = 'TrafficLightClassificationTask'
task0 = MultiClassificationTask(
    name=task_name,
    input_module=resnet_model,
    head_module=resnet_model.fc,
)
tasks = [task0]
model = MetalModel(tasks, verbose=False)

In [17]:

dl_kwargs = {
    "batch_size": 4
}
# Every label set (gold + per-slice pred/ind targets) is routed to the single task.
label_names = [
    "labelset_gold",
    "is_night_slice:pred",
    "is_day_slice:pred",
    "is_yellow_slice:pred",
    "is_night_slice:ind",
    "is_day_slice:ind",
    "is_yellow_slice:ind",
]
# Separate name so the clip-level `splits` dict defined earlier is not clobbered.
payload_splits = ["train", "valid", "test"]
datasets = [train_data, val_data, test_data]
payloads = []
for i, (split, dataset) in enumerate(zip(payload_splits, datasets)):
    payload_name = f"Payload{i}_{split}"
    # Fresh dict per payload so payloads never share one mutable mapping.
    labels_to_tasks = {label: task_name for label in label_names}
    # Only the training payload is shuffled.
    loader = MmtlDataLoader(dataset, shuffle=(split == "train"), **dl_kwargs)
    payloads.append(Payload(payload_name, loader, labels_to_tasks, split))

pprint(payloads)

[Payload(Payload0_train: labels_to_tasks=[{'labelset_gold': 'TrafficLightClassificationTask', 'is_night_slice:pred': 'TrafficLightClassificationTask', 'is_day_slice:pred': 'TrafficLightClassificationTask', 'is_yellow_slice:pred': 'TrafficLightClassificationTask', 'is_night_slice:ind': 'TrafficLightClassificationTask', 'is_day_slice:ind': 'TrafficLightClassificationTask', 'is_yellow_slice:ind': 'TrafficLightClassificationTask'}], split=train),
 Payload(Payload1_valid: labels_to_tasks=[{'labelset_gold': 'TrafficLightClassificationTask', 'is_night_slice:pred': 'TrafficLightClassificationTask', 'is_day_slice:pred': 'TrafficLightClassificationTask', 'is_yellow_slice:pred': 'TrafficLightClassificationTask', 'is_night_slice:ind': 'TrafficLightClassificationTask', 'is_day_slice:ind': 'TrafficLightClassificationTask', 'is_yellow_slice:ind': 'TrafficLightClassificationTask'}], split=valid),
 Payload(Payload2_test: labels_to_tasks=[{'labelset_gold': 'TrafficLightClassificationTask', 'is_night_sli

In [7]:
# Model trained using train + val set; load its LAST checkpoint.
# Best-checkpoint alternative:
#   /dfs/scratch0/zzweng/metal/logs/2019_05_20/00_08_06/best_model.pth
last_checkpoint_path = r'/dfs/scratch0/zzweng/metal/logs/2019_05_20/12_19_36/model_checkpoint_4.000253936008126.pth'
model.load_weights(last_checkpoint_path)

### Fine Tuning

In [8]:
# Fine-tune on the is_day slice only: each payload routes just that label set.
for payload in payloads:
    payload.labels_to_tasks = {"is_day_slice:pred": task_name}
pprint(payloads)

[Payload(Payload0_train: labels_to_tasks=[{'is_day_slice:pred': 'TrafficLightClassificationTask'}], split=train),
 Payload(Payload1_valid: labels_to_tasks=[{'is_day_slice:pred': 'TrafficLightClassificationTask'}], split=valid),
 Payload(Payload2_test: labels_to_tasks=[{'is_day_slice:pred': 'TrafficLightClassificationTask'}], split=test)]


In [24]:
# Scores on all three payloads before fine-tuning. Collected into one list so
# every result is displayed; a bare mid-cell call is computed and then discarded.
pre_finetune_scores = [
    model.score(payloads[0]),  # train
    model.score(payloads[1]),  # valid; naive: 1.0
    model.score(payloads[2]),  # test; naive: 0.4090606262491672
]
pre_finetune_scores

KeyboardInterrupt: 

In [9]:
from metal.mmtl.trainer import MultitaskTrainer
trainer = MultitaskTrainer()
# Fine-tuning hyperparameters (passed to train_model in this order).
train_kwargs = dict(
    n_epochs=1,
    log_every=0.1,
    lr=0.001,
    progress_bar=True,
    writer="tensorboard",
    lr_scheduler='linear',
    patience=10,
    checkpoint_every=1,
    checkpoint_best=False,
)
scores = trainer.train_model(model, payloads, **train_kwargs)
# Fine-tuning run directories:
#   is_day_slice: 19_24_10, 09_39_33
#   is_yellow_slice: 23_09_52

CONFIG:  {'verbose': True, 'seed': 487084, 'commit_hash': None, 'ami': None, 'progress_bar': True, 'n_epochs': 1, 'l2': 0.0, 'grad_clip': 1.0, 'optimizer_config': {'optimizer': 'adam', 'optimizer_common': {'lr': 0.001}, 'sgd_config': {'momentum': 0.9}, 'adam_config': {'betas': (0.9, 0.999)}, 'rmsprop_config': {}}, 'lr_scheduler': 'linear', 'lr_scheduler_config': {'warmup_steps': 0.0, 'warmup_unit': 'epochs', 'min_lr': 1e-06, 'exponential_config': {'gamma': 0.999}, 'plateau_config': {'factor': 0.5, 'patience': 10, 'threshold': 0.0001}}, 'metrics_config': {'task_metrics': [], 'trainer_metrics': ['model/valid/all/loss'], 'aggregate_metric_fns': [], 'max_valid_examples': 0, 'valid_split': 'valid', 'test_split': 'test'}, 'task_scheduler': 'proportional', 'logger': True, 'logger_config': {'log_unit': 'epochs', 'log_every': 0.1, 'score_every': -1.0, 'log_lr': True}, 'writer': 'tensorboard', 'writer_config': {'log_dir': '/dfs/scratch0/zzweng/metal/logs', 'run_dir': None, 'run_name': None, 'wri

HBox(children=(IntProgress(value=0, max=3938), HTML(value='')))

[0.10 epo]: TrafficLightClassificationTask:[Payload0_train/is_day_slice:pred/loss=4.25e-01, Payload1_valid/is_day_slice:pred/accuracy=9.70e-01] model:[train/all/loss=4.25e-01, train/all/lr=9.00e-04, valid/all/loss=4.55e-02]
[0.20 epo]: TrafficLightClassificationTask:[Payload0_train/is_day_slice:pred/loss=1.84e-01, Payload1_valid/is_day_slice:pred/accuracy=3.27e-01] model:[train/all/loss=1.84e-01, train/all/lr=8.00e-04, valid/all/loss=3.82e+00]
[0.30 epo]: TrafficLightClassificationTask:[Payload0_train/is_day_slice:pred/loss=1.78e-01, Payload1_valid/is_day_slice:pred/accuracy=9.85e-01] model:[train/all/loss=1.78e-01, train/all/lr=7.00e-04, valid/all/loss=5.08e-02]
[0.40 epo]: TrafficLightClassificationTask:[Payload0_train/is_day_slice:pred/loss=8.80e-02, Payload1_valid/is_day_slice:pred/accuracy=1.00e+00] model:[train/all/loss=8.80e-02, train/all/lr=6.00e-04, valid/all/loss=3.43e-03]
[0.50 epo]: TrafficLightClassificationTask:[Payload0_train/is_day_slice:pred/loss=6.95e-02, Payload1_val



KeyboardInterrupt: 

In [22]:
import torch
import dill
# Full pickled model after fine-tuning on is_day_slice.
# Alternative checkpoint after 1 epoch:
#   /dfs/scratch0/zzweng/metal/logs/2019_05_21/09_39_33/model_checkpoint_0.9916201117318436.pth
# Checkpoint after 2 epochs (loaded below):
full_model_path = (
    r'/dfs/scratch0/zzweng/metal/logs/2019_05_20/19_24_10/checkpoints/model.pkl'
)

# torch.load unpickles a whole model object (not just weights): only load trusted files.
model = torch.load(full_model_path)
# To re-save: torch.save(model, full_model_path, pickle_module=dill)

In [24]:
# Score the reloaded model on the test payload (payloads[2]).
model.score(payloads[2])

{'TrafficLightClassificationTask/Payload2_test/labelset_gold/accuracy': 0.5770770770770771,
 'TrafficLightClassificationTask/Payload2_test/is_night_slice:pred/accuracy': 0.8309859154929577,
 'TrafficLightClassificationTask/Payload2_test/is_day_slice:pred/accuracy': 0.49300466355762823,
 'TrafficLightClassificationTask/Payload2_test/is_yellow_slice:pred/accuracy': 0.21739130434782608,
 'TrafficLightClassificationTask/Payload2_test/is_night_slice:ind/accuracy': 0.5615615615615616,
 'TrafficLightClassificationTask/Payload2_test/is_day_slice:ind/accuracy': 0.43843843843843844,
 'TrafficLightClassificationTask/Payload2_test/is_yellow_slice:ind/accuracy': 0.7132132132132132}

In [15]:
eval_payload = payloads[2]  # test set
# Predict for every task and label set this payload routes.
target_tasks = {*eval_payload.labels_to_tasks.values()}
target_labels = set(eval_payload.labels_to_tasks)
Ys, Ys_probs, Ys_preds = model.predict_with_gold(
    eval_payload, target_tasks, target_labels, return_preds=True
)

In [14]:
# task_name = 'TrafficLightClassificationTask'
# task_metrics_dict = task0.scorer.score(
#     Ys['is_day_slice:pred'],
#     Ys_probs[task_name],
#     Ys_preds[task_name]
# )
# task_metrics_dict

In [None]:
# Validation payload (payloads[1]) scores.
model.score(payloads[1])  # score on val set

In [19]:
# Test payload (payloads[2]) scores after fine-tuning.
model.score(payloads[2])  # after fine_tuning

{'TrafficLightClassificationTask/Payload2_test/labelset_gold/accuracy': 0.44394394394394393,
 'TrafficLightClassificationTask/Payload2_test/is_night_slice:pred/accuracy': 0.8329979879275654,
 'TrafficLightClassificationTask/Payload2_test/is_day_slice:pred/accuracy': 0.3151232511658894,
 'TrafficLightClassificationTask/Payload2_test/is_yellow_slice:pred/accuracy': 0.06521739130434782,
 'TrafficLightClassificationTask/Payload2_test/is_night_slice:ind/accuracy': 0.4894894894894895,
 'TrafficLightClassificationTask/Payload2_test/is_day_slice:ind/accuracy': 0.5105105105105106,
 'TrafficLightClassificationTask/Payload2_test/is_yellow_slice:ind/accuracy': 0.6351351351351351}

## Error Analysis

In [17]:
preds_flat = np.array(Ys_preds[task_name]).flatten()
gold_flat = np.array(Ys['labelset_gold']).flatten()
mismatch_idx = np.where(preds_flat != gold_flat)[0]

# Keep only mismatches inside the is_day slice: is_day_slice:pred is 0 (abstain) outside it.
bad_idx = [pos for pos in mismatch_idx if Ys['is_day_slice:pred'][pos][0] != 0.]

# is_night_slice:pred gives empty bad_idx. makes sense.

In [None]:
np.count_nonzero(np.array(Ys['is_day_slice:pred']).flatten() != 0)  # number of in-slice (non-abstain) day examples

In [None]:
night_idx = np.flatnonzero(np.array(Ys['is_night_slice:pred']).flatten() != 0)  # positions inside the night slice

In [None]:
# Inspect one misclassified is_day example. bad_idx holds 0-based positions in the
# test payload (assumed to follow test_data order, since its loader uses shuffle=False).
i = bad_idx[44]
# i = 1043
print('idx', i)

print('Model predicted:', np.array(Ys_preds[task_name]).flatten()[i])
print('      with probs', Ys_probs[task_name][i])
for label_name in (
    'is_day_slice:ind', 'is_night_slice:ind', 'is_yellow_slice:ind',
    'labelset_gold', 'is_day_slice:pred', 'is_night_slice:pred', 'is_yellow_slice:pred',
):
    print(f'{label_name} label:', np.array(Ys[label_name]).flatten()[i])

# Dataset.__getitem__ maps position i to image_id i + 1, and draw() takes an image_id.
test_data.draw(i + 1)

In [None]:
# trained using train + test set, best_checkpoint
# model.load_weights(r'/dfs/scratch0/zzweng/metal/logs/2019_05_20/00_35_50/best_model.pth')

In [None]:
# model.score(payloads[2]) # score on test set

In [None]:
# model = torch.load(model_name)
# predictions = torch.tensor(model.predict(payloads[2], task_name=task_name))
# gold_labels = torch.tensor(test_data.labels)

# print((predictions == gold_labels).sum())
# incorrect_predictions = (predictions != gold_labels).nonzero().flatten().tolist()  # get indices of incorrect predictions
# print(len(incorrect_predictions))


In [13]:
from collections import defaultdict
from metal.mmtl.task_scheduler import ProportionalScheduler

# Debug check: accumulate per-task validation loss over the first batch(es) only.
# Loop variables are named batch_Ys / loss_task_name so they do not overwrite the
# notebook-level `Ys` (from predict_with_gold) and `task_name` used in error analysis.
task_losses = defaultdict(float)
task_examples = defaultdict(float)
total_examples = 0
split = "valid"
max_batches = 1  # stop after this many batches
task_scheduler = ProportionalScheduler(model, payloads, split)
for n_batches, (batch, payload_name, labels_to_tasks) in enumerate(
    task_scheduler.get_batches(payloads, split), start=1
):
    _, batch_Ys = batch
    batch_size = len(next(iter(batch_Ys.values())))
    loss_dict, count_dict = model.calculate_loss(
        *batch, payload_name, labels_to_tasks
    )
    for loss_task_name, loss in loss_dict.items():
        if count_dict[loss_task_name]:
            task_losses[loss_task_name] += loss.item() * count_dict[loss_task_name]
            task_examples[loss_task_name] += count_dict[loss_task_name]
    total_examples += batch_size
    if n_batches >= max_batches:
        break