In [1]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torch import nn
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook can still run (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from utils import predict_on_batch, measure_accuracy
from dataset import Dataset, EmbDataset, pad_tensor, Padder
from train import train, train_multitask
from recformer import Transformer, MiltitaskTransformer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read train.csv one raw line per recipe. pd.read_fwf was used here before, but it
# infers fixed-width column boundaries from the first 100 rows only, so any line
# longer than those rows gets truncated (and its last field misread as the cuisine).
with open('train.csv', encoding='utf-8') as f:
    train_lines = [line.strip() for line in f if line.strip()]

# Split on commas; expand=True right-pads shorter rows with None, giving rows of
# the form [ingredient ids..., cuisine, None, ...].
data = pd.Series(train_lines).str.split(',', expand=True).values.tolist()
len(data)

23547

In [3]:
from collections import Counter

# Count how often each cuisine occurs. Each row of `data` is
# [ingredient ids..., cuisine] right-padded with None, so the cuisine is the field
# just before the first None (or the very last field when there is no padding).
cusine_counter = Counter()
cusine_vocab = {}

for row in data:
    n_fields = row.index(None) if None in row else len(row)
    cusine_counter[row[n_fields - 1]] += 1

# Give consecutive label ids, in first-seen order, to every cuisine that has
# more than 100 training recipes.
cusine_num = 0
for cuisine, count in cusine_counter.items():
    if count > 100:
        cusine_vocab[cuisine] = cusine_num
        cusine_num += 1

## Derrick's data loading: parse the CSV files into ingredient-id lists and cuisine labels

In [4]:
# Re-read train.csv as a table with 61 columns: ingredient ids followed by the
# cuisine name, with short rows padded with NaN and then filled with 0.
# sep=',' replaces the original '\,', which is an invalid escape sequence in a
# normal string literal (DeprecationWarning / SyntaxWarning) that pandas only
# happened to interpret as the regex for a comma.
# NOTE(review): once NaN becomes 0, padding is indistinguishable from an
# ingredient id 0 in purely numeric columns — keep that in mind downstream.
df = pd.read_csv('train.csv', engine='python', sep=',', names=list(range(61)))
df1 = df.fillna(0)
df_2 = df1.values.tolist()

In [5]:
# Separate the cuisine label from the ingredient ids of each training recipe.
# Each row of df_2 is [ingredient ids..., cuisine, 0, 0, ...], where the trailing
# zeros are the NaN padding filled in by fillna(0). Only that trailing padding is
# stripped: filtering out every `v != 0` would also silently drop a genuine
# ingredient id 0 inside the recipe.
train_ingredients = []
train_labels = []
for row in df_2:
    fields = list(row)
    while fields and fields[-1] == 0:
        fields.pop()
    train_ingredients.append([int(v) for v in fields[:-1]])
    train_labels.append(cusine_vocab[fields[-1]])

In [6]:
# Load the classification validation questions and answers (60 columns each,
# NaN padding filled with 0). sep=',' replaces the invalid escape '\,'.
df = pd.read_csv('validation_classification_question.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_x = df1.values.tolist()

df = pd.read_csv('validation_classification_answer.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_y = df1.values.tolist()

In [7]:
# Classification validation set: ingredient ids from each question row and the
# cuisine (first column of the matching answer row) mapped to its label id.
assert len(val_x) == len(val_y), "question/answer files have different row counts"

val_ingredients_c = []
val_labels_c = []

for question_row, answer_row in zip(val_x, val_y):
    fields = list(question_row)
    # Strip only the trailing 0s that fillna(0) used as padding, so an ingredient
    # id 0 earlier in the row is kept (a trailing id 0 is indistinguishable from
    # padding once fillna has run).
    while fields and fields[-1] == 0:
        fields.pop()
    # NOTE(review): [:-1] discards the last remaining field, mirroring the
    # training parsing where that field is the cuisine. Confirm the question file
    # really carries an extra trailing field; otherwise a real ingredient is lost.
    val_ingredients_c.append([int(v) for v in fields[:-1]])
    val_labels_c.append(cusine_vocab[answer_row[0]])

print(len(val_ingredients_c), len(val_labels_c))

7848 7848


In [8]:
# --- training loop ---
epochs = 32
batch_size = 64

# --- output / vocabulary sizes ---
num_labels = len(cusine_vocab)  # number of cuisine classes kept in cusine_vocab
num_tokens = 6714               # NOTE(review): presumably the ingredient-id vocabulary size; confirm against node_ingredient.csv

# --- transformer architecture ---
dim_model = 128
num_heads = 4
num_encoder_layers = 3
num_decoder_layers = 1
dropout_p = 0.3

In [34]:
# Classification loaders built from the project Dataset/Padder helpers; the padder
# presumably pads each batch's ingredient lists with -1 (TODO confirm in dataset.py).
# DataLoader is already imported in the first cell.
padder = Padder(dim=1, pad_symbol=-1)
train_dataset = Dataset(train_ingredients, train_labels)
validation_dataset = Dataset(val_ingredients_c, val_labels_c)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    collate_fn=padder,
    shuffle=True,
)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=batch_size,
    collate_fn=padder,
)

In [10]:
# Cuisine classifier built from the hyperparameters above (second constructor
# argument is presumably the number of output classes: num_labels cuisines).
model = Transformer(
    num_tokens, num_labels,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [11]:
# Train the cuisine classifier (the recorded run below was stopped by hand).
train(
    model, criterion, optimizer,
    train_loader, validation_loader,
    epochs,
    device=device,
)

 56%|█████████████████████████████████████████████                                   | 207/368 [00:08<00:06, 25.53it/s]


KeyboardInterrupt: 

In [457]:
#torch.save(model.state_dict(), "RecFormer_classification.pth")

In [12]:
# Evaluate the saved classification checkpoint on the train and validation loaders.
# The save cell above is commented out, so RecFormer_classification.pth must already
# exist from an earlier run. map_location remaps tensors saved on a GPU onto
# `device`, so the checkpoint itself also loads on a CPU-only machine.
model.load_state_dict(torch.load("RecFormer_classification.pth", map_location=device))
model.to(device)
print(measure_accuracy(model, train_loader))
print(measure_accuracy(model, validation_loader))

tensor(0.9047, device='cuda:0')
tensor(0.7429, device='cuda:0')


# Completion task

In [14]:
# Completion task: each training recipe with n ingredients yields n examples,
# holding out one ingredient at a time; the held-out ingredient id is the label.
completion_data = []
completion_labels = []

for recipe in train_ingredients:
    for held_out in range(len(recipe)):
        completion_data.append(recipe[:held_out] + recipe[held_out + 1:])
        completion_labels.append(recipe[held_out])

print(len(completion_data), len(completion_labels))

253453 253453


In [15]:
# Longest completion input; `max_len` holds the list itself, so print its length.
max_len = max(completion_data, key=len)
print(len(max_len))

58


In [16]:
# Load the completion validation questions and answers (60 columns each,
# NaN padding filled with 0). sep=',' replaces the invalid escape '\,'.
df = pd.read_csv('validation_completion_question.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_x = df1.values.tolist()

df = pd.read_csv('validation_completion_answer.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_y = df1.values.tolist()

print(len(val_x), len(val_y))

7848 7848


In [17]:
# Completion validation set: ingredient ids from each question row and the
# missing ingredient id (first column of the matching answer row) as the label.
assert len(val_x) == len(val_y), "question/answer files have different row counts"

val_ingredients = []
val_labels = []

for question_row, answer_row in zip(val_x, val_y):
    fields = list(question_row)
    # Strip only the trailing 0s that fillna(0) used as padding, so an ingredient
    # id 0 earlier in the row is kept (a trailing id 0 is indistinguishable from
    # padding once fillna has run).
    while fields and fields[-1] == 0:
        fields.pop()
    # NOTE(review): [:-1] discards the last remaining field, mirroring the
    # training parsing where that field is the cuisine. Confirm the question file
    # really carries an extra trailing field; otherwise a real ingredient is lost.
    val_ingredients.append([int(v) for v in fields[:-1]])
    val_labels.append(int(answer_row[0]))

print(len(val_ingredients), len(val_labels))

7848 7848


In [18]:
# Completion loaders (DataLoader is already imported in the first cell).
padder = Padder(dim=1, pad_symbol=-1)
train_dataset = Dataset(completion_data, completion_labels)
validation_dataset = Dataset(val_ingredients, val_labels)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    collate_fn=padder,
    shuffle=True,
)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=batch_size,
    collate_fn=padder,
)

In [19]:
# Completion model: same architecture, but the second constructor argument is
# num_tokens, presumably so the output ranges over ingredient ids.
model = Transformer(
    num_tokens, num_tokens,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [507]:
# Train the completion model. Pass the device explicitly, as the classification
# training call does, instead of relying on train()'s default.
train(model, criterion, optimizer, train_loader, validation_loader, epochs, device=device)

100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.08it/s]


Epoch: 0, Training Loss: 6.517052804662314, Validation Loss: 6.4315637728063075
tensor(0.0401, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:25<00:00, 27.14it/s]


Epoch: 1, Training Loss: 6.324401523572753, Validation Loss: 6.235665728406208
tensor(0.0404, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:28<00:00, 26.62it/s]


Epoch: 2, Training Loss: 6.168433147137648, Validation Loss: 6.129349650406256
tensor(0.0505, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:28<00:00, 26.67it/s]


Epoch: 3, Training Loss: 6.054775495416015, Validation Loss: 6.038973052327226
tensor(0.0552, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:30<00:00, 26.25it/s]


Epoch: 4, Training Loss: 5.98998115520771, Validation Loss: 6.01111094932246
tensor(0.0566, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:31<00:00, 26.12it/s]


Epoch: 5, Training Loss: 5.938325452551761, Validation Loss: 5.9401054033418985
tensor(0.0604, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:31<00:00, 26.08it/s]


Epoch: 6, Training Loss: 5.887189463153388, Validation Loss: 5.894096583854862
tensor(0.0672, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:28<00:00, 26.61it/s]


Epoch: 7, Training Loss: 5.834484847318461, Validation Loss: 5.900716227244556
tensor(0.0717, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:28<00:00, 26.62it/s]


Epoch: 8, Training Loss: 5.789391932479061, Validation Loss: 5.856374360681549
tensor(0.0736, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:32<00:00, 25.95it/s]


Epoch: 9, Training Loss: 5.752725931166158, Validation Loss: 5.830272205476838
tensor(0.0767, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:29<00:00, 26.46it/s]


Epoch: 10, Training Loss: 5.717971167820808, Validation Loss: 5.842933616017907
tensor(0.0765, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 26.99it/s]


Epoch: 11, Training Loss: 5.681107663711012, Validation Loss: 5.792962190581531
tensor(0.0794, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.09it/s]


Epoch: 12, Training Loss: 5.641954576207724, Validation Loss: 5.734560493531266
tensor(0.0821, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.93it/s]


Epoch: 13, Training Loss: 5.59894215028354, Validation Loss: 5.669055469636994
tensor(0.0831, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 26.95it/s]


Epoch: 14, Training Loss: 5.5551937014669575, Validation Loss: 5.659804576780738
tensor(0.0886, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.01it/s]


Epoch: 15, Training Loss: 5.511983973183616, Validation Loss: 5.646342409335501
tensor(0.0943, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.05it/s]


Epoch: 16, Training Loss: 5.4714249796531504, Validation Loss: 5.604461518729606
tensor(0.0924, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.83it/s]


Epoch: 17, Training Loss: 5.434056706032468, Validation Loss: 5.594846558764698
tensor(0.0996, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.90it/s]


Epoch: 18, Training Loss: 5.401743520482927, Validation Loss: 5.568623876183983
tensor(0.1009, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.93it/s]


Epoch: 19, Training Loss: 5.373979302314579, Validation Loss: 5.543551964488456
tensor(0.1016, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.00it/s]


Epoch: 20, Training Loss: 5.347286399642917, Validation Loss: 5.497306629894226
tensor(0.1060, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.01it/s]


Epoch: 21, Training Loss: 5.326656092600641, Validation Loss: 5.515064510872693
tensor(0.1074, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.05it/s]


Epoch: 22, Training Loss: 5.303040287164208, Validation Loss: 5.509409272573827
tensor(0.1083, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.02it/s]


Epoch: 23, Training Loss: 5.283109058186551, Validation Loss: 5.529116347553284
tensor(0.1086, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.04it/s]


Epoch: 24, Training Loss: 5.263341152977985, Validation Loss: 5.483751576121261
tensor(0.1058, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.01it/s]


Epoch: 25, Training Loss: 5.246926651732905, Validation Loss: 5.496711122311227
tensor(0.1107, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.92it/s]


Epoch: 26, Training Loss: 5.228971532245262, Validation Loss: 5.4721035181991455
tensor(0.1074, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 26.99it/s]


Epoch: 27, Training Loss: 5.214912814463187, Validation Loss: 5.4714219085569304
tensor(0.1121, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:27<00:00, 26.94it/s]


Epoch: 28, Training Loss: 5.200572883622329, Validation Loss: 5.469997580458478
tensor(0.1133, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 26.95it/s]


Epoch: 29, Training Loss: 5.1875198564575165, Validation Loss: 5.44681383536114
tensor(0.1112, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 26.96it/s]


Epoch: 30, Training Loss: 5.177347969449069, Validation Loss: 5.420845543465963
tensor(0.1130, device='cuda:0')


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:26<00:00, 27.04it/s]


Epoch: 31, Training Loss: 5.163116875696411, Validation Loss: 5.448413976808873
tensor(0.1138, device='cuda:0')


In [508]:
#torch.save(model.state_dict(), "RecFormer_completion.pth")

In [20]:
# Evaluate the saved completion checkpoint (the save cell above is commented out,
# so the file must already exist). map_location remaps GPU-saved tensors onto
# `device`, so the checkpoint itself also loads on a CPU-only machine.
model.load_state_dict(torch.load("RecFormer_completion.pth", map_location=device))
model.to(device)

print(measure_accuracy(model, train_loader))
print(measure_accuracy(model, validation_loader))

tensor(0.1535, device='cuda:0')
tensor(0.1138, device='cuda:0')


# Multi-task experiments

In [21]:
# Multitask data: the same leave-one-ingredient-out expansion as the completion
# task, but each label is the pair [cuisine label id, held-out ingredient id].
multitask_data = []
multitask_labels = []

for recipe, cusine in zip(train_ingredients, train_labels):
    for held_out in range(len(recipe)):
        multitask_data.append(recipe[:held_out] + recipe[held_out + 1:])
        multitask_labels.append([cusine, recipe[held_out]])

print(len(multitask_data), len(multitask_labels))

253453 253453


In [22]:
# Multitask training loader plus one validation loader per task
# (cuisine classification and ingredient completion).
padder = Padder(dim=1, pad_symbol=-1)

train_dataset = Dataset(multitask_data, multitask_labels)
validation_dataset_cuisine = Dataset(val_ingredients_c, val_labels_c)
validation_dataset_ingredients = Dataset(val_ingredients, val_labels)

train_loader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=padder, shuffle=True
)
validation_loader_cuisine = DataLoader(
    dataset=validation_dataset_cuisine, batch_size=batch_size, collate_fn=padder
)
validation_loader_ingredients = DataLoader(
    dataset=validation_dataset_ingredients, batch_size=batch_size, collate_fn=padder
)

In [23]:
# Multitask model: one shared encoder with cuisine and ingredient heads
# (presumably; see recformer.MiltitaskTransformer).
model = MiltitaskTransformer(
    num_tokens, num_labels,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [553]:
# Multitask training; loss_weights presumably scales the [cuisine, ingredient]
# losses (the recorded run below was stopped by hand after epoch 19).
train_multitask(
    model, criterion, optimizer,
    train_loader,
    validation_loader_cuisine, validation_loader_ingredients,
    epochs,
    loss_weights=[0.5, 2],
)

100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.87it/s]


Epoch: 0, Training Loss: 10.466503519868645, Validation Loss Cuisine: 0.9088633939987276, Validation Loss Ingredients: 5.370745445654644
Validation classification accuracy: 0.7566258311271667
Validation completion accuracy: 0.1209225207567215
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.83it/s]


Epoch: 1, Training Loss: 10.392543811745128, Validation Loss Cuisine: 0.9113027982837786, Validation Loss Ingredients: 5.343912008332043
Validation classification accuracy: 0.7571355700492859
Validation completion accuracy: 0.12028542160987854


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:35<00:00, 25.54it/s]


Epoch: 2, Training Loss: 10.342701013387378, Validation Loss Cuisine: 0.9087610543985677, Validation Loss Ingredients: 5.3287865049470735
Validation classification accuracy: 0.7552242279052734
Validation completion accuracy: 0.11964831501245499


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.81it/s]


Epoch: 3, Training Loss: 10.31235755317532, Validation Loss Cuisine: 0.9158440182606379, Validation Loss Ingredients: 5.351227081888091
Validation classification accuracy: 0.757900059223175
Validation completion accuracy: 0.12296126037836075
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.75it/s]


Epoch: 4, Training Loss: 10.28036433920298, Validation Loss Cuisine: 0.9144579639764336, Validation Loss Ingredients: 5.310914287722207
Validation classification accuracy: 0.7572629451751709
Validation completion accuracy: 0.12104994058609009


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:35<00:00, 25.44it/s]


Epoch: 5, Training Loss: 10.254644072016692, Validation Loss Cuisine: 0.9135345256910091, Validation Loss Ingredients: 5.331077560176694
Validation classification accuracy: 0.7593017220497131
Validation completion accuracy: 0.1179918423295021


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.88it/s]


Epoch: 6, Training Loss: 10.228397238889567, Validation Loss Cuisine: 0.9157689993943625, Validation Loss Ingredients: 5.335360728628267
Validation classification accuracy: 0.7564984560012817
Validation completion accuracy: 0.12270641326904297


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:37<00:00, 25.11it/s]


Epoch: 7, Training Loss: 10.20786089874283, Validation Loss Cuisine: 0.9098929203138119, Validation Loss Ingredients: 5.332745269062073
Validation classification accuracy: 0.7600662112236023
Validation completion accuracy: 0.12321610003709793
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:34<00:00, 25.69it/s]


Epoch: 8, Training Loss: 10.189754524606553, Validation Loss Cuisine: 0.9131830763526079, Validation Loss Ingredients: 5.3616586157946085
Validation classification accuracy: 0.7594290971755981
Validation completion accuracy: 0.12181446701288223


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:36<00:00, 25.26it/s]


Epoch: 9, Training Loss: 10.178535730602222, Validation Loss Cuisine: 0.9072680132902735, Validation Loss Ingredients: 5.3367896002482595
Validation classification accuracy: 0.7601936459541321
Validation completion accuracy: 0.12206931412220001


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.82it/s]


Epoch: 10, Training Loss: 10.166918260044174, Validation Loss Cuisine: 0.9132259386341747, Validation Loss Ingredients: 5.327346642812093
Validation classification accuracy: 0.7613404393196106
Validation completion accuracy: 0.12066768109798431


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:32<00:00, 26.00it/s]


Epoch: 11, Training Loss: 10.144780685311163, Validation Loss Cuisine: 0.9199899932960185, Validation Loss Ingredients: 5.336091669594369
Validation classification accuracy: 0.7552242279052734
Validation completion accuracy: 0.1237257868051529


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:32<00:00, 25.92it/s]


Epoch: 12, Training Loss: 10.144432779094, Validation Loss Cuisine: 0.9125344101491013, Validation Loss Ingredients: 5.331351605857291
Validation classification accuracy: 0.7570081353187561
Validation completion accuracy: 0.12130478769540787


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.82it/s]


Epoch: 13, Training Loss: 10.122241872995978, Validation Loss Cuisine: 0.9150942259930014, Validation Loss Ingredients: 5.3429514954729775
Validation classification accuracy: 0.7580274939537048
Validation completion accuracy: 0.12385320663452148


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:34<00:00, 25.58it/s]


Epoch: 14, Training Loss: 10.114115823612343, Validation Loss Cuisine: 0.9189781134690695, Validation Loss Ingredients: 5.352570774109383
Validation classification accuracy: 0.7591742873191833
Validation completion accuracy: 0.12206931412220001


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.78it/s]


Epoch: 15, Training Loss: 10.100394081268849, Validation Loss Cuisine: 0.9168576283183524, Validation Loss Ingredients: 5.34590309809863
Validation classification accuracy: 0.7552242279052734
Validation completion accuracy: 0.12296126037836075


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:33<00:00, 25.77it/s]


Epoch: 16, Training Loss: 10.094328881513889, Validation Loss Cuisine: 0.9125438578487411, Validation Loss Ingredients: 5.34003535324965
Validation classification accuracy: 0.7591742873191833
Validation completion accuracy: 0.1221967339515686


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:36<00:00, 25.24it/s]


Epoch: 17, Training Loss: 10.08704171677907, Validation Loss Cuisine: 0.9142723044728845, Validation Loss Ingredients: 5.3401793076740045
Validation classification accuracy: 0.7608307600021362
Validation completion accuracy: 0.12436288595199585
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:34<00:00, 25.63it/s]


Epoch: 18, Training Loss: 10.074923640999701, Validation Loss Cuisine: 0.9187797469094516, Validation Loss Ingredients: 5.338139313023265
Validation classification accuracy: 0.7591742873191833
Validation completion accuracy: 0.12385320663452148


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [02:36<00:00, 25.33it/s]


Epoch: 19, Training Loss: 10.065065961268116, Validation Loss Cuisine: 0.9190681135751367, Validation Loss Ingredients: 5.357593625541625
Validation classification accuracy: 0.7573903799057007
Validation completion accuracy: 0.12181446701288223


  4%|███▍                                                                           | 175/3961 [00:06<02:28, 25.46it/s]


KeyboardInterrupt: 

In [526]:
#torch.save(model.state_dict(), "RecFormer_multitask.pth")

In [35]:
# Evaluate the saved multitask checkpoint on each task's validation loader.
# map_location remaps GPU-saved tensors onto `device`, so the checkpoint itself
# also loads on a CPU-only machine.
model.load_state_dict(torch.load("RecFormer_multitask.pth", map_location=device))
model.to(device)
print(measure_accuracy(model, validation_loader_cuisine, multitask_switch="cuisine"))
print(measure_accuracy(model, validation_loader_ingredients, multitask_switch="ingredients"))

tensor(0.7608, device='cuda:0')
tensor(0.1244, device='cuda:0')


In [43]:
# Predict a cuisine for every training recipe (one row per recipe) with the
# multitask model's cuisine head, for the per-cuisine report below.
# The loader is built explicitly: at this point in top-to-bottom order,
# `train_dataset` is the multitask dataset (one row per held-out ingredient), so
# reusing it would produce predictions that do not align with `train_labels`.
classification_dataset = Dataset(train_ingredients, train_labels)
classification_loader = DataLoader(
    dataset=classification_dataset,
    batch_size=batch_size,
    collate_fn=Padder(dim=1, pad_symbol=-1),
    shuffle=False,  # keep prediction order aligned with train_labels
)

model.eval()
classification_preds = []
with torch.no_grad():
    for batch in classification_loader:
        preds = predict_on_batch(model, batch, "cuisine")
        classification_preds.extend(torch.argmax(preds, dim=1).tolist())
print(len(classification_preds))

23547


In [44]:
# Inverse of cusine_vocab: label id -> cuisine name.
id_to_cus = {label_id: name for name, label_id in cusine_vocab.items()}

In [45]:
from sklearn.metrics import classification_report, f1_score, confusion_matrix

# Per-cuisine precision/recall/F1 on the training recipes, with label ids mapped
# back to cuisine names.
true_names = [id_to_cus[label] for label in train_labels]
pred_names = [id_to_cus[label] for label in classification_preds]
print(classification_report(true_names, pred_names))

              precision    recall  f1-score   support

   brazilian       0.86      0.74      0.80       283
     british       0.77      0.69      0.73       485
cajun_creole       0.88      0.86      0.87       920
     chinese       0.89      0.91      0.90      1599
    filipino       0.86      0.78      0.81       452
      french       0.72      0.80      0.76      1543
       greek       0.84      0.85      0.85       714
      indian       0.91      0.95      0.93      1748
       irish       0.80      0.78      0.79       404
     italian       0.90      0.92      0.91      4678
    jamaican       0.91      0.89      0.90       280
    japanese       0.91      0.81      0.86       840
      korean       0.88      0.91      0.89       474
     mexican       0.95      0.95      0.95      3836
    moroccan       0.89      0.88      0.88       496
     russian       0.79      0.70      0.74       300
 southern_us       0.84      0.87      0.86      2515
     spanish       0.81    

## Word embeddings multi-task experiment

In [None]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove*.zip

In [None]:
# Load GloVe vectors of size emb_length into a {word: float32 vector} dict.
# Each line of the file is "<word> <v1> <v2> ...". numpy is imported in the first cell.
emb_length = 100
glove_path = 'glove.6B/glove.6B.{}d.txt'.format(emb_length)

glove_vocab = {}
with open(glove_path, encoding='utf-8') as glove_file:
    for line in glove_file:
        tokens = line.split()
        glove_vocab[tokens[0]] = np.asarray(tokens[1:], dtype='float32')

In [None]:
# All-zero vector used as the padding symbol, and the mean of every GloVe vector,
# passed to EmbDataset below (presumably as the fallback for unknown words).
PAD_embedding = torch.zeros(emb_length)
UNK_embedding = np.stack(list(glove_vocab.values())).mean(axis=0)

In [None]:
# One ingredient name per line of node_ingredient.csv; ing_id_to_str maps each
# 0-based line position to its name. The file is read line by line instead of with
# pd.read_fwf: read_fwf infers whitespace-separated fixed-width columns from the
# first 100 rows, so multi-word or long names could be cut and df[0] would hold
# only a fragment of each name.
with open('node_ingredient.csv', encoding='utf-8') as f:
    node_ingredient = [line.strip() for line in f]
print(len(node_ingredient))
ing_id_to_str = dict(enumerate(node_ingredient))

In [428]:
# Same multitask/validation split as before, but EmbDataset turns ingredient ids
# into GloVe vectors (via ing_id_to_str and glove_vocab), so batches are padded
# with the zero embedding along dim 0.
padder = Padder(dim=0, pad_symbol=PAD_embedding)

emb_args = (glove_vocab, UNK_embedding, ing_id_to_str)
train_dataset = EmbDataset(multitask_data, multitask_labels, *emb_args)
validation_dataset_cuisine = EmbDataset(val_ingredients_c, val_labels_c, *emb_args)
validation_dataset_ingredients = EmbDataset(val_ingredients, val_labels, *emb_args)

train_loader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=padder, shuffle=True
)
validation_loader_cuisine = DataLoader(
    dataset=validation_dataset_cuisine, batch_size=batch_size, collate_fn=padder
)
validation_loader_ingredients = DataLoader(
    dataset=validation_dataset_ingredients, batch_size=batch_size, collate_fn=padder
)

In [429]:
# Multitask model fed with pretrained GloVe embeddings. The model dimension is
# tied to emb_length (it was the literal 100, equal to emb_length) so that changing
# the GloVe file size keeps the two in sync.
model = MiltitaskTransformer(num_tokens, num_labels, emb_length, num_heads, num_encoder_layers, num_decoder_layers, dropout_p, use_pretrained_embeddings=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [430]:
# Multitask training on GloVe inputs (the recorded run was stopped by hand after
# epoch 5, once validation loss started rising).
train_multitask(
    model, criterion, optimizer,
    train_loader,
    validation_loader_cuisine, validation_loader_ingredients,
    epochs,
    loss_weights=[1, 2],
)

100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:16<00:00, 20.19it/s]


Epoch: 0, Training Loss: 13.698796818816762, Validation Loss Cuisine: 1.0543987414216607, Validation Loss Ingredients: 5.829817884336642
Validation classification accuracy: 0.6845055818557739
Validation completion accuracy: 0.07097349315881729


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:15<00:00, 20.24it/s]


Epoch: 1, Training Loss: 12.384629779739592, Validation Loss Cuisine: 0.9593833212445422, Validation Loss Ingredients: 5.664339709088085
Validation classification accuracy: 0.7237512469291687
Validation completion accuracy: 0.08486238121986389


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:18<00:00, 20.00it/s]


Epoch: 2, Training Loss: 11.977050300680968, Validation Loss Cuisine: 0.9350827052825834, Validation Loss Ingredients: 5.542130543933651
Validation classification accuracy: 0.7287206649780273
Validation completion accuracy: 0.09212537854909897


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:17<00:00, 20.05it/s]


Epoch: 3, Training Loss: 11.756381512049382, Validation Loss Cuisine: 0.9571074867636208, Validation Loss Ingredients: 5.6684015359335795
Validation classification accuracy: 0.7320336103439331
Validation completion accuracy: 0.09110601246356964


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:18<00:00, 19.97it/s]


Epoch: 4, Training Loss: 11.610014121419882, Validation Loss Cuisine: 1.0390520670065067, Validation Loss Ingredients: 5.6687494293461
Validation classification accuracy: 0.7121559381484985
Validation completion accuracy: 0.08830274641513824


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [03:21<00:00, 19.65it/s]


Epoch: 5, Training Loss: 11.500401195448115, Validation Loss Cuisine: 1.1200186537533272, Validation Loss Ingredients: 5.863389790542727
Validation classification accuracy: 0.6980121731758118
Validation completion accuracy: 0.08103974908590317


  4%|███                                                                            | 154/3961 [00:07<03:11, 19.89it/s]


KeyboardInterrupt: 