In [1]:
%load_ext autoreload
%autoreload 2

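# Introduction and Objective
## Training with Multi-Task Learning

This notebook does the following:

- builds one `DataLoader` per task (input image reconstruction, edge detection, vertebral landmark detection and facial landmark detection) and sanity-checks the tensor shapes each one produces;
- instantiates the multi-task UNet;
- smoke-tests a single-task PyTorch Lightning training run.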

In [2]:
import h5py
import numpy as np
import yaml
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd

In [3]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [4]:
from utils import (
    HDF5MultitaskDataset,
    ResizeTransform, 
    MultitaskCollator,
    MultiTaskLandmarkUNetCustom,
    nested_dict_to_easydict,
    Coord2HeatmapTransform,
    CustomToTensor,
)

  from .autonotebook import tqdm as notebook_tqdm


# Load parameters

In [5]:
# Project-wide hyper-parameters and paths live in a single YAML file; the path is
# relative to this notebook's working directory.
PARAMS_PATH = "../../code_configs/params.yaml"

# Explicit UTF-8 so the config parses identically regardless of the OS locale.
with open(PARAMS_PATH, encoding="utf-8") as f:
    raw_params = yaml.safe_load(f)

# Wrap the nested dict so values can be read as attributes
# (e.g. PARAMS.TRAIN.METADATA_TABLE_NAME), as the cells below do.
PARAMS = nested_dict_to_easydict(raw_params)

# Load metadata table

In [6]:
# Read the metadata DataFrame stored under the HDF5 key "df"; its location is
# built from the data directory and table name in params.yaml.
metadata_table = pd.read_hdf(
    path_or_buf=os.path.join(
        PARAMS.PRIMARY_DATA_DIRECTORY,
        PARAMS.TRAIN.METADATA_TABLE_NAME,
    ),
    key="df",
)

In [7]:
# Preview the first five metadata rows (columns used below: split, *_present, harmonized_id).
metadata_table.head(n=5)

Unnamed: 0,v_annots_present,f_annots_present,edges_present,f_annots_rows,f_annots_cols,harmonized_id,v_annots_2_rows,v_annots_2_cols,v_annots_3_rows,v_annots_3_cols,v_annots_4_rows,v_annots_4_cols,source_image_filename,dataset,dev_set,valid,split
0,True,False,True,,,041281ee7fb89f6835a71c309b3b503e3d5a68fc46a608...,3.0,2.0,5.0,2.0,5.0,2.0,45.jpg,dataset_1,,True,undefined
1,True,False,True,,,2cfa37a69916c8a45a51bb8beeb04425e07d2a22f694e0...,3.0,2.0,5.0,2.0,5.0,2.0,92.jpg,dataset_1,,True,undefined
2,True,False,True,,,7201dc2be0b97f59a7901004d6496bbe84c440530776db...,3.0,2.0,5.0,2.0,5.0,2.0,43.jpg,dataset_1,,True,undefined
3,True,False,True,,,2cd4487c03c72d1016ea0a72d1b21eb987878c90ae9eff...,3.0,2.0,5.0,2.0,5.0,2.0,7.jpg,dataset_1,,True,undefined
4,True,False,True,,,27624a6eb37bbc8aafabe2075f423d573b189eae6f23fb...,3.0,2.0,5.0,2.0,5.0,2.0,121.jpg,dataset_1,,True,undefined


# DataLoader for task one: Input Image Reconstruction

In [8]:
# Task 1 (input image reconstruction) only needs the image itself, so every
# training sample is eligible.
task_id = 1

# Build the list of HDF5 paths, one file per harmonized_id.
task_mask = metadata_table['split'] == 'train'
train_file_list = [
    os.path.join(PARAMS.PRIMARY_DATA_DIRECTORY, harmonized_id + '.hdf5')
    for harmonized_id in metadata_table.loc[task_mask, 'harmonized_id'].tolist()
]
# Fail early with a readable message; an empty dataset would otherwise surface
# as RandomSampler's "num_samples should be a positive integer value" error.
assert train_file_list, f"No training files found for task_id={task_id}"

# Resize to the target size, turn landmark coordinates into Gaussian heatmaps
# (NOTE(review): presumably a pass-through when the sample has no landmarks —
# confirm in utils), then convert to tensors.
my_transforms = transforms.Compose([
    ResizeTransform(tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE)),
    Coord2HeatmapTransform(
        tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE),
        PARAMS.TRAIN.GAUSSIAN_COORD2HEATMAP_STD
    ),
    CustomToTensor(),
])

# Instantiate the dataset, the task-specific collator, and the dataloader.
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_one = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [9]:
# Sanity check: print the tensor shapes of one dataset sample and of one
# dataloader batch for task one.
# The dataset is taken from `dataloader_one` itself rather than from the global
# `train_dataset`, which every task cell rebinds; otherwise re-running this cell
# after a later task cell would inspect the wrong task's dataset.
task_dataset = dataloader_one.dataset

# dataset: a single (unbatched) sample
print("-- Sanity check dataset object!")
if len(task_dataset) == 0:
    print("The dataset is empty!")
else:
    first_sample = task_dataset[0]
    print(first_sample.keys())
    for key, value in first_sample.items():
        print()
        print(key)
        print(value.shape)

print()

# data loader: the first collated batch
print("-- Sanity check dataloader object!")
first_batch = next(iter(dataloader_one), None)
if first_batch is None:
    print("The dataloader is empty!")
else:
    print("batch_ndx ", 0)
    for key, value in first_batch.items():
        print()
        print(key)
        print(value.shape)

-- Sanity check dataset object!
dict_keys(['image'])

image
torch.Size([1, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])


# DataLoader for task two: Edge Detection

In [10]:
# Task 2 (edge detection) needs an edge map, so keep only training samples
# whose metadata flags edges as present.
task_id = 2

# Build the list of HDF5 paths, one file per harmonized_id.
# `.eq(True)` treats missing (NaN) flags as absent, exactly like `== True`.
task_mask = (metadata_table['split'] == 'train') & metadata_table['edges_present'].eq(True)
train_file_list = [
    os.path.join(PARAMS.PRIMARY_DATA_DIRECTORY, harmonized_id + '.hdf5')
    for harmonized_id in metadata_table.loc[task_mask, 'harmonized_id'].tolist()
]
# Fail early with a readable message; an empty dataset would otherwise surface
# as RandomSampler's "num_samples should be a positive integer value" error.
assert train_file_list, f"No training files found for task_id={task_id}"

# Resize to the target size, turn landmark coordinates into Gaussian heatmaps
# (NOTE(review): presumably a pass-through when the sample has no landmarks —
# confirm in utils), then convert to tensors.
my_transforms = transforms.Compose([
    ResizeTransform(tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE)),
    Coord2HeatmapTransform(
        tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE),
        PARAMS.TRAIN.GAUSSIAN_COORD2HEATMAP_STD
    ),
    CustomToTensor(),
])

# Instantiate the dataset, the task-specific collator, and the dataloader.
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_two = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [11]:
# Sanity check: print the tensor shapes of one dataset sample and of one
# dataloader batch for task two.
# The dataset is taken from `dataloader_two` itself rather than from the global
# `train_dataset`, which every task cell rebinds; otherwise re-running this cell
# after a later task cell would inspect the wrong task's dataset.
task_dataset = dataloader_two.dataset

# dataset: a single (unbatched) sample
print("-- Sanity check dataset object!")
if len(task_dataset) == 0:
    print("The dataset is empty!")
else:
    first_sample = task_dataset[0]
    print(first_sample.keys())
    for key, value in first_sample.items():
        print()
        print(key)
        print(value.shape)

print()

# data loader: the first collated batch
print("-- Sanity check dataloader object!")
first_batch = next(iter(dataloader_two), None)
if first_batch is None:
    print("The dataloader is empty!")
else:
    print("batch_ndx ", 0)
    for key, value in first_batch.items():
        print()
        print(key)
        print(value.shape)

-- Sanity check dataset object!
dict_keys(['image', 'edges'])

image
torch.Size([1, 256, 256])

edges
torch.Size([1, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

edges
torch.Size([4, 1, 256, 256])


# DataLoader for task three: Vertebral Landmark Detection

In [12]:
# Task 3 (vertebral landmark detection) needs vertebral annotations, so keep
# only training samples whose metadata flags them as present.
task_id = 3

# Build the list of HDF5 paths, one file per harmonized_id.
# `.eq(True)` treats missing (NaN) flags as absent, exactly like `== True`.
task_mask = (metadata_table['split'] == 'train') & metadata_table['v_annots_present'].eq(True)
train_file_list = [
    os.path.join(PARAMS.PRIMARY_DATA_DIRECTORY, harmonized_id + '.hdf5')
    for harmonized_id in metadata_table.loc[task_mask, 'harmonized_id'].tolist()
]
# Fail early with a readable message; an empty dataset would otherwise surface
# as RandomSampler's "num_samples should be a positive integer value" error.
assert train_file_list, f"No training files found for task_id={task_id}"

# Resize to the target size, turn landmark coordinates into Gaussian heatmaps
# (one channel per landmark, judging by the sanity-check output), then convert
# to tensors.
my_transforms = transforms.Compose([
    ResizeTransform(tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE)),
    Coord2HeatmapTransform(
        tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE),
        PARAMS.TRAIN.GAUSSIAN_COORD2HEATMAP_STD
    ),
    CustomToTensor(),
])

# Instantiate the dataset, the task-specific collator, and the dataloader.
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_three = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [13]:
# Sanity check: print the tensor shapes of one dataset sample and of one
# dataloader batch for task three.
# The dataset is taken from `dataloader_three` itself rather than from the global
# `train_dataset`, which every task cell rebinds; otherwise re-running this cell
# after a later task cell would inspect the wrong task's dataset.
task_dataset = dataloader_three.dataset

# dataset: a single (unbatched) sample
print("-- Sanity check dataset object!")
if len(task_dataset) == 0:
    print("The dataset is empty!")
else:
    first_sample = task_dataset[0]
    print(first_sample.keys())
    for key, value in first_sample.items():
        print()
        print(key)
        print(value.shape)

print()

# data loader: the first collated batch
print("-- Sanity check dataloader object!")
first_batch = next(iter(dataloader_three), None)
if first_batch is None:
    print("The dataloader is empty!")
else:
    print("batch_ndx ", 0)
    for key, value in first_batch.items():
        print()
        print(key)
        print(value.shape)

-- Sanity check dataset object!
dict_keys(['image', 'v_landmarks'])

image
torch.Size([1, 256, 256])

v_landmarks
torch.Size([13, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

v_landmarks
torch.Size([4, 13, 256, 256])


# DataLoader for task four: Facial Landmark Detection

In [14]:
# Task 4 (facial landmark detection) needs facial annotations, so keep only
# training samples whose metadata flags them as present.
task_id = 4

# Build the list of HDF5 paths, one file per harmonized_id.
# `.eq(True)` treats missing (NaN) flags as absent, exactly like `== True`.
task_mask = (metadata_table['split'] == 'train') & metadata_table['f_annots_present'].eq(True)
train_file_list = [
    os.path.join(PARAMS.PRIMARY_DATA_DIRECTORY, harmonized_id + '.hdf5')
    for harmonized_id in metadata_table.loc[task_mask, 'harmonized_id'].tolist()
]
# Fail early with a readable message; an empty dataset would otherwise surface
# as RandomSampler's "num_samples should be a positive integer value" error.
assert train_file_list, f"No training files found for task_id={task_id}"

# Resize to the target size, turn landmark coordinates into Gaussian heatmaps
# (one channel per landmark, judging by the sanity-check output), then convert
# to tensors.
my_transforms = transforms.Compose([
    ResizeTransform(tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE)),
    Coord2HeatmapTransform(
        tuple(PARAMS.TRAIN.TARGET_IMAGE_SIZE),
        PARAMS.TRAIN.GAUSSIAN_COORD2HEATMAP_STD
    ),
    CustomToTensor(),
])

# Instantiate the dataset, the task-specific collator, and the dataloader.
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_four = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [15]:
# Sanity check: print the tensor shapes of one dataset sample and of one
# dataloader batch for task four.
# The dataset is taken from `dataloader_four` itself rather than from the global
# `train_dataset`, which every task cell rebinds; otherwise re-running this cell
# out of order would inspect the wrong task's dataset.
task_dataset = dataloader_four.dataset

# dataset: a single (unbatched) sample
print("-- Sanity check dataset object!")
if len(task_dataset) == 0:
    print("The dataset is empty!")
else:
    first_sample = task_dataset[0]
    print(first_sample.keys())
    for key, value in first_sample.items():
        print()
        print(key)
        print(value.shape)

print()

# data loader: the first collated batch
print("-- Sanity check dataloader object!")
first_batch = next(iter(dataloader_four), None)
if first_batch is None:
    print("The dataloader is empty!")
else:
    print("batch_ndx ", 0)
    for key, value in first_batch.items():
        print()
        print(key)
        print(value.shape)

-- Sanity check dataset object!
dict_keys(['image', 'f_landmarks'])

image
torch.Size([1, 256, 256])

f_landmarks
torch.Size([19, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

f_landmarks
torch.Size([4, 19, 256, 256])


# Model 

In [16]:
# Multi-task UNet built from explicit arguments: one input channel and one output
# head per task. The head widths (1, 1, 13, 19) match the per-task target channels
# reported by the sanity checks above (image, edges, v_landmarks, f_landmarks).
# NOTE(review): the next cell rebuilds `model` from PARAMS.MODEL.PARAMS, so this
# instance is discarded; it mainly documents the architecture choices.
model = MultiTaskLandmarkUNetCustom(**{
    "in_channels": 1,
    "out_channels1": 1,
    "out_channels2": 1,
    "out_channels3": 13,
    "out_channels4": 19,
    "enc_chan_multiplier": 1,
    "dec_chan_multiplier": 1,
    "backbone_encoder": "efficientnet-b4",
    "backbone_weights": "imagenet",
    "freeze_backbone": True,
})

In [17]:
# Build the model from the keyword arguments stored under MODEL.PARAMS in
# params.yaml, so the architecture is driven by the config file.
model_params = PARAMS.MODEL.PARAMS
model = MultiTaskLandmarkUNetCustom(
    **model_params
)

In [18]:
# Count only parameters with requires_grad=True; frozen weights (presumably the
# backbone, given `freeze_backbone`) are excluded from the total.
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print("Total number of trainable parameters: ", num_params)

Total number of trainable parameters:  450266


In [19]:
# Random smoke-test input shaped (batch=1, channels=1, 256, 256), matching the
# single-channel 256x256 images produced by the dataloaders above.
# A dedicated, seeded generator makes the input reproducible without touching
# torch's global RNG (which unseeded DataLoader shuffling also draws from).
SMOKE_TEST_SEED = 0
smoke_rng = torch.Generator().manual_seed(SMOKE_TEST_SEED)
image = torch.randn(1, 1, 256, 256, generator=smoke_rng)
# Scale so the maximum value is exactly 1 (values below 0 remain negative).
image /= image.max()

In [20]:
# Smoke-test the vertebral-landmark head (task_id=3) on the random input.
# Run in eval mode under no_grad, so no autograd graph is kept and normalization
# layers (if any) do not update their running statistics from random noise.
# The model's previous train/eval mode is restored afterwards, even on error.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(image, task_id=3)
finally:
    model.train(was_training)

# The UNet output should keep the input's spatial size.
assert out.shape[-2:] == image.shape[-2:], out.shape
print(out.shape)

torch.Size([1, 13, 256, 256])


# Test PyTorch Lightning

In [21]:
from utils import (
    trainer_v_landmarks_single_task,
)
import pytorch_lightning as pl
import torch

In [22]:
# Launch the single-task vertebral-landmark training run defined in `utils`
# (it logs through PyTorch Lightning). In the recorded output below, no GPU was
# available, and the run was stopped manually (KeyboardInterrupt) during the
# first validation pass.
# NOTE(review): this is a long-running cell; consider guarding it behind a flag
# so "Restart & Run All" does not start a full training run.
trainer_v_landmarks_single_task()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                        | Params
----------------------------------------------------------
0 | model     | MultiTaskLandmarkUNetCustom | 18.0 M
1 | train_mse | MeanSquaredError            | 0     
2 | val_mse   | MeanSquaredError            | 0     
----------------------------------------------------------
450 K     Trainable params
17.5 M    Non-trainable params
18.0 M    Total params
71.992    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                                                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████████████████████████████████████████████████████████████| 17/17 [01:44<00:00,  6.12s/it, v_num=11]                                 
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                   | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|██████████▎                                                   | 1/6 [00:05<00:27,  5.54s/it][A
Validation DataLoader 0:  33%|████████████████████▋                                         | 2/6 [00:10<00:21,  5.49s/it][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 3/6 [00:16<00:16,  5.61s/it][A
Validation DataLoader 0:  67%|█████████████████████████████████████████▎                    | 4/6 [00:21<00:10,  5.35s/it][A
Validation DataLoader 0:  83%|████████████████████████

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [25]:
from utils import create_dataloader

# Task id and batch size for the single-task run come from params.yaml.
task_config = PARAMS.TRAIN.SINGLE_TASK
task_id, batch_size = task_config.TASK_ID, task_config.BATCH_SIZE

# NOTE(review): shuffle=False keeps batch order deterministic for the
# inspection loop in the next cell; a real training loader would usually shuffle.
train_dataloader = create_dataloader(
    split='train',
    task_id=task_id,
    batch_size=batch_size,
    shuffle=False,
)

In [26]:
# Walk the whole training loader (confirming every batch loads) and print each
# tensor's size. Sizes are taken from whatever keys the batch contains, so this
# works for any TASK_ID in params.yaml, not only the vertebral-landmark task
# (a hardcoded 'v_landmarks' key would raise KeyError for other tasks).
for i_batch, sample_batched in enumerate(train_dataloader):
    print(i_batch, *(tensor.size() for tensor in sample_batched.values()))

0 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
1 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
2 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
3 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
4 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
5 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
6 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
7 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
8 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
9 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
10 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
11 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
12 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
13 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
14 torch.Size([16, 1, 256, 256]) torch.Size([16, 13, 256, 256])
15 torch.Size([16, 1, 256, 256]) torch.Size([16, 1