# Multi-output model

<img src="./img/multi_task_vs_multi_label.png" alt="multi_task_vs_multi_label" style="width: 600px;"/>

Multi-task model vs Multi-label classification model.

## Dataset

Omniglot dataset - collection of images of 964 hand-written characters from 30 alphabets.

In [21]:
import numpy as np
import pandas as pd

from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import Omniglot
from torchvision import transforms

In [5]:
import os

def find_classes(root_dir):
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        for f in files:
            if (f.endswith("png")):
                r = root.split('/')
                lr = len(r)
                retour.append((r[lr - 2], r[lr - 2] + "/" + r[lr - 1], root + "/" + f))
    print("== Found %d items " % len(retour))
    return retour

In [6]:
img_samples = find_classes('./data/omniglot-py/')

== Found 19280 items 


In [7]:
img_samples[0]

('Japanese_(hiragana)',
 'Japanese_(hiragana)/character05',
 './data/omniglot-py/images_background/Japanese_(hiragana)/character05/0492_07.png')

In [8]:
alphabet_codes, alphabet = pd.Series([i[0] for i in img_samples]).factorize()

In [9]:
alphabet

Index(['Japanese_(hiragana)', 'Inuktitut_(Canadian_Aboriginal_Syllabics)',
       'Malay_(Jawi_-_Arabic)', 'Ojibwe_(Canadian_Aboriginal_Syllabics)',
       'N_Ko', 'Korean', 'Futurama', 'Arcadian', 'Sanskrit', 'Grantha',
       'Burmese_(Myanmar)', 'Early_Aramaic', 'Greek', 'Cyrillic', 'Tifinagh',
       'Latin', 'Bengali', 'Balinese', 'Braille', 'Tagalog', 'Gujarati',
       'Japanese_(katakana)', 'Anglo-Saxon_Futhorc', 'Asomtavruli_(Georgian)',
       'Mkhedruli_(Georgian)', 'Hebrew', 'Alphabet_of_the_Magi',
       'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Armenian',
       'Syriac_(Estrangelo)'],
      dtype='object')

In [10]:
labels_codes, labels = pd.Series([i[1] for i in img_samples]).factorize()

In [13]:
len(labels), len(alphabet)

(964, 30)

In [14]:
samples = [[img_path[2], alphabet, label] for img_path, alphabet, label in zip(img_samples, alphabet_codes, labels_codes)]

In [15]:
samples[0]

['./data/omniglot-py/images_background/Japanese_(hiragana)/character05/0492_07.png',
 0,
 0]

In [16]:
# custom Dataset class
class OmniglotDataset(Dataset):
    def __init__(self, transform, samples):
        self.transform = transform
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label_alph, label_char = self.samples[idx]
        img = Image.open(img_path).convert('L')
        img = self.transform(img)
        return img, label_alph, label_char

In [20]:
dataset_train = OmniglotDataset(
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
    ]),
    samples=samples,
)

dataloader_train = DataLoader(
    dataset_train, shuffle=True, batch_size=32,
)

# Model

Building two-output model to predict character and alphabet

<img src="./img/multi_output.png" alt="multi_output" style="width: 600px;"/>

In [18]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()        
        self.image_layer = nn.Sequential(            
            nn.Conv2d(1, 16, kernel_size=3, padding=1),            
            nn.MaxPool2d(kernel_size=2),            
            nn.ELU(),            
            nn.Flatten(),            
            nn.Linear(16*32*32, 128)        
        )              
        self.class_alphabet = nn.Linear(128, 30)
        self.class_character = nn.Linear(128, 964)
        
    def forward(self, x):
        x = self.image_layer(x)
        out_alphabet = self.class_alphabet(x)
        out_character = self.class_character(x)
        return out_alphabet,out_character
        

## Training Loop

In [23]:
import torch.optim as optim

In [21]:
net = Net()

In [114]:
loss_function = nn.CrossEntropyLoss()

In [24]:
optimizer = optim.SGD(net.parameters(), lr=0.05)

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    for img, lables_alph, labels_char in dataloader_train:
        optimizer.zero_grad()
        out_alph, out_char = net(img)
        loss_alph = loss_function(out_alph, lables_alph)
        loss_char = loss_function(out_char, labels_char)
        total_loss = 0.3 * loss_alph + 0.7 * loss_char # add weight to loss - optimization 
        total_loss.backward()
        optimizer.step()

## Loss weight

The loss funtions ideally should be on the same scale.

If scale is different, there will be penalization of smaller scale loss.

The solution: normalize losses before weighting and adding.

In [None]:
loss_price = loss_price / torch.max(loss_price) # MSE oss
loss_quality = loss_quality / torch.max(loss_quality) # CrossEntropy loss
total = 0.3 * loss_price + 0.7 * loss_quality

## Evaluation

In [23]:
from torchmetrics import Accuracy

In [25]:
# set metric for each output
acc_alph = Accuracy(
    task='multiclass', num_classes=30
)
acc_char = Accuracy(
    task='multiclass', num_classes=964
)

In [None]:
# evaluation loop
net.eval()
with torch.no_grad():
    for img, lables_alph, labels_char in dataloader_test:
        out_alph, out_char = net(img)
        _, pred_alph = torch.max(out_alph, 1)
        _, pred_char = torch.max(out_char, 1)
        acc_alph(pred_alph, lables_alph)
        acc_char(pred_char, labels_char)

acc_alph.compute()
acc_char.compute()  