In [14]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Introduction and Objective
## Training with MultiTask Learning

In [15]:
import h5py
import numpy as np
import yaml
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd

In [16]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [17]:
from utils import (
    HDF5MultitaskDataset,
    ResizeTransform, 
    MultitaskCollator,
    MultiTaskLandmarkUNet1,
)

# load parameters

In [18]:
with open("../../code_configs/params.yaml") as f:
    params = yaml.safe_load(f)

# Load metadata table

In [19]:
metadata_table = pd.read_hdf(
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], params['METADATA_TABLE_NAME']),
    key='df',
)

In [20]:
metadata_table.head()

Unnamed: 0,source_image_filename,harmonized_id,dataset,dev_set,v_annots_present,f_annots_present,edges_present,split
0,45.jpg,041281ee7fb89f6835a71c309b3b503e3d5a68fc46a608...,dataset_1,,True,False,True,undefined
1,92.jpg,2cfa37a69916c8a45a51bb8beeb04425e07d2a22f694e0...,dataset_1,,True,False,True,undefined
2,43.jpg,7201dc2be0b97f59a7901004d6496bbe84c440530776db...,dataset_1,,True,False,True,undefined
3,7.jpg,2cd4487c03c72d1016ea0a72d1b21eb987878c90ae9eff...,dataset_1,,True,False,True,undefined
4,121.jpg,27624a6eb37bbc8aafabe2075f423d573b189eae6f23fb...,dataset_1,,True,False,True,undefined


# DataLoader for task one: Input Image Reconstruction

In [48]:
# define the task id
task_id = 1

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') , ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_one = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [53]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_one):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

-- Sanity check dataset object!
dict_keys(['image'])

image
torch.Size([1, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])


# DataLoader for task two: Edge Detection

In [62]:
# define the task id
task_id = 2

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['edges_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_two = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [63]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_two):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

-- Sanity check dataset object!
dict_keys(['image', 'edges'])

image
torch.Size([1, 256, 256])

edges
torch.Size([1, 256, 256])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

edges
torch.Size([4, 1, 256, 256])


# DataLoader for task three: Vertebral Landmark Detection

In [64]:
# define the task id
task_id = 3

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['v_annots_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_three = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [65]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_three):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

-- Sanity check dataset object!
dict_keys(['image', 'v_landmarks'])

image
torch.Size([1, 256, 256])

v_landmarks
torch.Size([13, 2])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

v_landmarks
torch.Size([4, 13, 2])


# DataLoader for task four: Facial Landmark Detection

In [67]:
# define the task id
task_id = 4

# create the right list of paths
train_file_list = metadata_table.loc[
    (metadata_table['split']=='train') & (metadata_table['f_annots_present']==True), ['harmonized_id']
].to_numpy().ravel().tolist()

train_file_list = [
    os.path.join(params['PRIMARY_DATA_DIRECTORY'], file_path+'.hdf5') for file_path in train_file_list
]

# instantiate the transforms
my_transforms = transforms.Compose([
    ResizeTransform(tuple(params['TARGET_IMAGE_SIZE'])),
])

# instantiate the dataset and dataloader objects
train_dataset = HDF5MultitaskDataset(
    file_paths=train_file_list,
    task_id=task_id,
    transforms=my_transforms,
)
collator_task = MultitaskCollator(
    task_id=task_id,
)
dataloader_four = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator_task
)

In [68]:
# sanity check dataset and dataloader

# dataset
print("-- Sanity check dataset object!")
dataset_iter = iter(train_dataset)
for batch in dataset_iter:
    print(batch.keys())
    for k, v in batch.items():
        print()
        print(k,)
        print(v.shape)
    break

print()

# data loader
print("-- Sanity check dataloader object!")
for batch_ndx, sample in enumerate(dataloader_four):
    print("batch_ndx ", batch_ndx)
    for k, v in sample.items():
        print()
        print(k,)
        print(v.shape)
    break

-- Sanity check dataset object!
dict_keys(['image', 'f_landmarks'])

image
torch.Size([1, 256, 256])

f_landmarks
torch.Size([19, 2])

-- Sanity check dataloader object!
batch_ndx  0

image
torch.Size([4, 1, 256, 256])

f_landmarks
torch.Size([4, 19, 2])


# Model 

In [25]:
model = MultiTaskLandmarkUNet1(
    in_channels=3,
    out_channels1=1,
    out_channels2=1,
    out_channels3=13,
    out_channels4=19,
    enc_chan_multiplier=1,
    dec_chan_multiplier=1,
)

In [26]:
# count the number of trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of trainable parameters: ", num_params)

Total number of trainable parameters:  407442


# Test Pytorch Lightning

In [1]:
import torch
from torch import nn
import pytorch_lightning as pl


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_fc = nn.Linear(10, 20)
        self.output_layer1 = nn.Linear(20, 2)
        self.output_layer2 = nn.Linear(20, 3)

    def forward(self, x, task_id):
        x = torch.relu(self.shared_fc(x))
        if task_id == 1:
            print("task_id ", task_id)
            output1 = self.output_layer1(x)
            return output1
        elif task_id == 2:
            print("task_id ", task_id)
            output2 = self.output_layer2(x)
            return output2


class TwoLayerNetworkTask1(pl.LightningModule):
    def __init__(self, my_model, task_id):
        super().__init__()
        self.my_model = my_model
        self.task_id = task_id

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.to(torch.float32)
        output = self.my_model(x, task_id=self.task_id)
        loss = nn.MSELoss()(output, y)
        self.log(f'train_loss_{self.task_id}', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    
class TwoLayerNetworkTask2(pl.LightningModule):
    def __init__(self, my_model, task_id):
        super().__init__()
        self.my_model = my_model
        self.task_id = task_id

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.to(torch.float32)
        output = self.my_model(x, task_id=self.task_id)
        loss = nn.MSELoss()(output, y)
        self.log(f'train_loss_{self.task_id}', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [2]:
from torch.utils.data import Dataset, DataLoader

class DummyDataset(Dataset):
    def __init__(self, num_samples, input_size, task_id):
        self.num_samples = num_samples
        self.input_size = input_size
        self.task_id = task_id
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        x = torch.randn(self.input_size)
        if self.task_id == 1:
            y1 = torch.randint(0, 2, (2,))
            return x, y1
        elif self.task_id == 2:
            y2 = torch.randint(0, 3, (3,))
            return x, y2

In [3]:

# Create the dummy dataset
dataset1 = DummyDataset(num_samples=100, input_size=10, task_id=1)
dataset2 = DummyDataset(num_samples=100, input_size=10, task_id=2)

# Create the dataloader
batch_size1 = 8
batch_size2 = 3
dataloader1 = DataLoader(dataset1, batch_size=batch_size1, shuffle=True)
dataloader2 = DataLoader(dataset2, batch_size=batch_size2, shuffle=True)

In [6]:
model = MyModel()
# Create instances of the PyTorch Lightning modules with the same model
module1 = TwoLayerNetworkTask1(model, task_id=1)
module2 = TwoLayerNetworkTask2(model, task_id=2)

# Create the Trainer objects and train the modules
trainer1 = pl.Trainer(limit_train_batches=2, max_epochs=2, log_every_n_steps=1)
trainer1.fit(module1, dataloader1)

trainer2 = pl.Trainer(limit_train_batches=5, max_epochs=2, log_every_n_steps=1)
trainer2.fit(module2, dataloader2)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/samehr/Desktop/cephal/cvmt/notebooks/ml/lightning_logs

  | Name     | Type    | Params
-------------------------------------
0 | my_model | MyModel | 325   
-------------------------------------
325       Trainable params
0         Non-trainable params
325       Total params
0.001     Total estimated model params size (MB)


Epoch 0:   0%|                                                                                                                     | 0/2 [00:00<?, ?it/s]task_id  1
Epoch 0:  50%|██████████████████████████████████████████████████                                                  | 1/2 [00:00<00:00, 60.06it/s, v_num=0]task_id  1
Epoch 1:   0%|                                                                                                            | 0/2 [00:00<?, ?it/s, v_num=0]task_id  1
Epoch 1:  50%|██████████████████████████████████████████████████                                                  | 1/2 [00:00<00:00, 62.85it/s, v_num=0]task_id  1
Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 71.08it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 55.36it/s, v_num=0]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type    | Params
-------------------------------------
0 | my_model | MyModel | 325   
-------------------------------------
325       Trainable params
0         Non-trainable params
325       Total params
0.001     Total estimated model params size (MB)



Epoch 0:   0%|                                                                                                                     | 0/5 [00:00<?, ?it/s]task_id  2
Epoch 0:  20%|████████████████████                                                                                | 1/5 [00:00<00:00, 85.28it/s, v_num=1]task_id  2
Epoch 0:  40%|████████████████████████████████████████                                                            | 2/5 [00:00<00:00, 68.01it/s, v_num=1]task_id  2
Epoch 0:  60%|████████████████████████████████████████████████████████████                                        | 3/5 [00:00<00:00, 78.92it/s, v_num=1]task_id  2
Epoch 0:  80%|████████████████████████████████████████████████████████████████████████████████                    | 4/5 [00:00<00:00, 74.09it/s, v_num=1]task_id  2
Epoch 1:   0%|                                                                                                            | 0/5 [00:00<?, ?it/s, v_num=1]task_id  2
Epoch 1:  20%|█

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 62.19it/s, v_num=1]
