In [1]:
%load_ext autoreload
%autoreload 2

# Introduction and Objective
## Training with MultiTask Learning

In [None]:
import h5py
import numpy as np
import yaml
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
from utils import (
    HDF5MultitaskDataset,
    ResizeTransform, 
    MultitaskCollator,
    MultiTaskLandmarkUNet1,
)

# load parameters

In [19]:
with open("../../code_configs/params.yaml") as f:
    params = yaml.safe_load(f)

# Load metadata table

In [None]:
metadata_table = pd.read_hdf(
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], params['METADATA_TABLE_NAME']),
    key='df',
)

In [None]:
metadata_table.head()

# DataLoader for task one: Input Image Reconstruction

In [None]:
# define the task id
task_id = 1

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') , ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_one = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [None]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_one):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

# DataLoader for task two: Edge Detection

In [None]:
# define the task id
task_id = 2

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['edges_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_two = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [None]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_two):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

# DataLoader for task three: Vertebral Landmark Detection

In [None]:
# define the task id
task_id = 3

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['v_annots_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_three = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [None]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_three):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

# DataLoader for task four: Facial Landmark Detection

In [None]:
# define the task id
task_id = 4

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['f_annots_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_four = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [None]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_four):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

# Model 

In [None]:
model = MultiTaskLandmarkUNet1(
    in_channels=3,
    out_channels1=1,
    out_channels2=1,
    out_channels3=13,
    out_channels4=19,
    enc_chan_multiplier=1,
    dec_chan_multiplier=1,
)

In [None]:
# count the number of trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of trainable parameters: ", num_params)

# Test Pytorch Lightning

In [14]:
from utils import (
    trainer_multitask_v_and_f_landmarks,
    trainer_v_landmarks,
    create_dataloader,
    SingletaskTrainLandmarks,
)
import pytorch_lightning as pl

In [16]:
trainer_v_landmarks()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                   | Params
-----------------------------------------------------
0 | model     | MultiTaskLandmarkUNet1 | 407 K 
1 | train_mse | MeanSquaredError       | 0     
2 | val_mse   | MeanSquaredError       | 0     
-----------------------------------------------------
407 K     Trainable params
0         Non-trainable params
407 K     Total params
1.629     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|                       | 0/2 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 192, 3, 3], expected input[16, 96, 32, 32] to have 192 channels, but got 96 channels instead

In [5]:
train_dataloader = create_dataloader(
    task_id=3,
    batch_size=16,
    split='train',
)

In [6]:
for i, batch in enumerate(train_dataloader):
    print(type(batch))
    x, y = batch
    break

<class 'dict'>


In [8]:
batch.keys()

dict_keys(['image', 'v_landmarks'])

In [9]:
x

'image'

In [12]:
x, y = batch.values()

In [13]:
x

tensor([[[[253.7871, 253.6554, 253.9746,  ..., 253.9746, 253.9746, 254.4160],
          [252.9747, 252.8835, 252.9747,  ..., 253.9746, 253.9746, 254.4160],
          [249.9750, 249.9750, 250.9749,  ..., 253.9746, 253.9746, 254.4160],
          ...,
          [253.9746, 253.9746, 253.9746,  ..., 197.6548, 199.6659, 205.4128],
          [253.9746, 253.9746, 253.9746,  ..., 201.7224, 201.3126, 203.6847],
          [253.7871, 253.7871, 253.7871,  ..., 200.9758, 202.0284, 208.5376]]],


        [[[ 93.9554, 121.1773, 113.0942,  ...,  57.0187,  58.1670,  59.7909],
          [129.1424, 140.2340, 132.4213,  ...,  69.9930,  70.9929,  71.9928],
          [181.6156, 178.7370, 177.6502,  ...,  69.9930,  69.9930,  72.9927],
          ...,
          [142.5150, 142.2934, 134.4211,  ...,  69.9930,  68.9931,  68.9931],
          [141.9721, 127.4462, 136.4912,  ...,  65.6917,  65.0999,  65.9299],
          [151.4224, 127.7372, 146.5195,  ...,  58.5635,  60.7273,  59.6454]]],


        [[[240.9759, 239.9