# Multi-input model

## Dataset

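The Omniglot dataset is a collection of images of hand-written characters. The split used here contains 964 characters from 30 alphabets, with 20 hand-drawn samples per character (19,280 images in total).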

In [28]:
import numpy as np
import pandas as pd

from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import Omniglot
from torchvision import transforms

In [122]:
import os

def find_classes(root_dir):
    """Collect every PNG image below ``root_dir`` together with its class names.

    The two innermost directory names of each image are used as class names,
    i.e. the tree is expected to look like
    ``.../<alphabet>/<character>/<image>.png``.

    Args:
        root_dir: Directory that is walked recursively.

    Returns:
        A list of ``(alphabet, "alphabet/character", image_path)`` tuples,
        one per file whose name ends with ``"png"``, in ``os.walk`` order.
        The number of items found is also printed.
    """
    found = []
    for current_dir, _subdirs, filenames in os.walk(root_dir):
        # Use os.path instead of splitting on a literal "/" so the directory
        # names are extracted correctly whatever separator the platform uses.
        parent_dir, character = os.path.split(current_dir)
        alphabet = os.path.basename(parent_dir)
        for filename in filenames:
            if filename.endswith("png"):
                found.append((
                    alphabet,
                    # The label keeps "/" as a platform-independent delimiter.
                    alphabet + "/" + character,
                    os.path.join(current_dir, filename),
                ))
    print("== Found %d items " % len(found))
    return found

In [123]:
# Index every PNG found below the local Omniglot data directory.
img_samples = find_classes('./data/omniglot-py/')

== Found 19280 items 


In [125]:
img_samples[0]

('Japanese_(hiragana)',
 'Japanese_(hiragana)/character05',
 './data/omniglot-py/images_background/Japanese_(hiragana)/character05/0492_07.png')

In [126]:
# Integer-encode the alphabet name of every sample: `alphabet_codes` holds one
# code per sample, `alphabet` the unique names indexed by code.
alphabet_codes, alphabet = pd.Series(
    [sample[0] for sample in img_samples]
).factorize()

In [127]:
alphabet

Index(['Japanese_(hiragana)', 'Inuktitut_(Canadian_Aboriginal_Syllabics)',
       'Malay_(Jawi_-_Arabic)', 'Ojibwe_(Canadian_Aboriginal_Syllabics)',
       'N_Ko', 'Korean', 'Futurama', 'Arcadian', 'Sanskrit', 'Grantha',
       'Burmese_(Myanmar)', 'Early_Aramaic', 'Greek', 'Cyrillic', 'Tifinagh',
       'Latin', 'Bengali', 'Balinese', 'Braille', 'Tagalog', 'Gujarati',
       'Japanese_(katakana)', 'Anglo-Saxon_Futhorc', 'Asomtavruli_(Georgian)',
       'Mkhedruli_(Georgian)', 'Hebrew', 'Alphabet_of_the_Magi',
       'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Armenian',
       'Syriac_(Estrangelo)'],
      dtype='object')

In [131]:
len(alphabet)

30

In [116]:
# One row per sample, one indicator column per alphabet code.
# The result is an integer array; convert it to float before feeding it to
# floating-point layers such as nn.Linear.
one_hot = pd.get_dummies(pd.Series(alphabet_codes)).astype(int).to_numpy()

In [128]:
# Integer-encode the "alphabet/character" label of every sample:
# `labels_codes` holds one code per sample, `labels` the unique label strings.
labels_codes, labels = pd.Series(
    [sample[1] for sample in img_samples]
).factorize()

In [130]:
len(labels)

964

In [135]:
# Assemble one [image path, one-hot alphabet vector, character code] record
# per image.
samples = [
    [sample[2], alphabet_vec, label_code]
    for sample, alphabet_vec, label_code in zip(img_samples, one_hot, labels_codes)
]

In [136]:
samples[0]

['./data/omniglot-py/images_background/Japanese_(hiragana)/character05/0492_07.png',
 array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0]),
 0]

In [138]:
# custom Dataset class
class OmniglotDataset(Dataset):
    """Dataset yielding ``(image, alphabet, label)`` triples.

    Args:
        transform: Callable applied to each loaded grayscale image.
        samples: Sequence of ``(image_path, alphabet, label)`` records.
    """

    def __init__(self, transform, samples):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return it with its two labels."""
        path, alphabet_vec, target = self.samples[idx]
        # 'L' converts the image to a single 8-bit grayscale channel.
        grayscale = Image.open(path).convert('L')
        return self.transform(grayscale), alphabet_vec, target

In [139]:
# Every image is converted to a tensor and then resized to 64x64.
dataset_train = OmniglotDataset(
    samples=samples,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Resize((64, 64))]
    ),
)

# Shuffled mini-batches of three samples each.
dataloader_train = DataLoader(dataset_train, batch_size=3, shuffle=True)

### Torch concatenation

In [17]:
x = torch.tensor([[1,2,3]])
y = torch.tensor([[4,5,6]])
# dim=0 concatenates along the first dimension: two (1, 3) tensors -> (2, 3).
torch.cat((x,y), dim=0)

tensor([[1, 2, 3],
        [4, 5, 6]])

In [18]:
# dim=1 concatenates along the second dimension: two (1, 3) tensors -> (1, 6).
torch.cat((x,y), dim=1)

tensor([[1, 2, 3, 4, 5, 6]])

# Model

Building a two-input model: the inputs are the image of the character and the alphabet label encoded as a one-hot vector.

<img src="./img/multi_input.png" alt="multi_input" style="width: 600px;"/>

In [20]:
class Net(nn.Module):
    """Two-input classifier: a character image plus a one-hot alphabet vector.

    The image branch and the alphabet branch are embedded separately, their
    outputs are concatenated along the feature dimension, and a linear
    classifier maps the result to class logits.

    Args:
        num_alphabets: Length of the one-hot alphabet vector.
        num_classes: Number of output classes (character labels).
    """

    def __init__(self, num_alphabets=30, num_classes=964):
        super().__init__()
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            # 16 channels x 32 x 32: a 64x64 input halved once by the pooling.
            nn.Linear(16 * 32 * 32, 128),
        )
        self.alphabet_layer = nn.Sequential(
            nn.Linear(num_alphabets, 8),
            nn.ELU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 + 8, num_classes),
        )

    def forward(self, x_image, x_alphabet):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x_image: Tensor of shape ``(batch, 1, 64, 64)``.
            x_alphabet: One-hot tensor of shape ``(batch, num_alphabets)``;
                integer dtypes are accepted and cast to floating point.
        """
        x_image = self.image_layer(x_image)
        # One-hot vectors built with integer dtype reach the model as integer
        # tensors, which nn.Linear rejects; match the image features' dtype.
        x_alphabet = x_alphabet.to(dtype=x_image.dtype)
        x_alphabet = self.alphabet_layer(x_alphabet)
        # concatenate outputs
        x = torch.cat((x_image, x_alphabet), dim=1)
        return self.classifier(x)
        

## Training Loop

In [23]:
import torch.optim as optim

In [21]:
# Instantiate the two-input model with its default layer sizes.
net = Net()

In [114]:
# Cross-entropy over the classifier outputs (raw logits) and the integer
# character codes.
loss_function = nn.CrossEntropyLoss()

In [24]:
# Plain stochastic gradient descent over all model parameters.
optimizer = optim.SGD(net.parameters(), lr=0.01)

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    # Iterate the DataLoader created earlier in this notebook
    # (`dataloader_train`). The loop variable is called `targets` so it does
    # not overwrite the module-level `labels` index of class names.
    for img, alph, targets in dataloader_train:
        optimizer.zero_grad()
        # The one-hot alphabet vectors are stored as integers; the linear
        # layers need floating-point input.
        outputs = net(img, alph.float())
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()