In [1]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torch import nn
# Run on the GPU when one is available, otherwise fall back to the CPU so the notebook still executes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from recformer.utils import predict_on_batch, measure_accuracy
from recformer.dataset import Dataset, EmbDataset, pad_tensor, Padder
from recformer.train import train, train_multitask
from recformer.recformer import RecFormer, MiltitaskRecFormer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read train.csv one line per row: column 0 is expected to hold the whole comma-separated
# line (ingredient ids followed by the cuisine name), which is split into fields here.
df = pd.read_fwf('train.csv', header=None)
data = (
    df[0]
    .str.split(',', expand=True)  # shorter rows are right-padded with None
    .values
    .tolist()
)
len(data)

23547

In [3]:
from collections import Counter

# Count how often each cuisine label occurs. The label is the last non-None field of a row,
# because str.split(expand=True) right-pads shorter rows with None.
cusine_counter = Counter()
for recipe in data:
    label_pos = recipe.index(None) - 1 if None in recipe else len(recipe) - 1
    cusine_counter[recipe[label_pos]] += 1

# Give a class id to every cuisine seen more than 100 times, in first-seen order.
frequent_cuisines = [name for name, count in cusine_counter.items() if count > 100]
cusine_vocab = {name: idx for idx, name in enumerate(frequent_cuisines)}
cusine_num = len(frequent_cuisines)

In [4]:
# Cuisine name -> class id mapping built in the previous cell.
print(f"{cusine_vocab}")

{'greek': 0, 'filipino': 1, 'indian': 2, 'jamaican': 3, 'spanish': 4, 'italian': 5, 'mexican': 6, 'vietnamese': 7, 'thai': 8, 'southern_us': 9, 'chinese': 10, 'cajun_creole': 11, 'brazilian': 12, 'french': 13, 'japanese': 14, 'irish': 15, 'moroccan': 16, 'korean': 17, 'british': 18, 'russian': 19}


Derrick's reading: load the CSV files with `pd.read_csv` and split each row into ingredient ids and labels

In [5]:
# Read train.csv as up to 61 comma-separated fields per row (ingredient ids followed by the
# cuisine name); rows with fewer fields get NaN in the remaining columns.
# sep=',' replaces the former '\,', which is an invalid string escape (SyntaxWarning on
# Python 3.12+) and, as a regex, matched exactly the same single comma.
df = pd.read_csv('train.csv', engine='python', sep=',', names=list(range(61)))
# NOTE(review): NaN padding becomes 0 here and is later removed with `v != 0`. Ingredient ids
# look 0-based (len(node_ingredient) == num_tokens == 6714), so a genuine ingredient 0 would be
# dropped as well; a non-id sentinel or per-row dropna would be safer.
df1 = df.fillna(0)
df_2 = df1.values.tolist()

In [6]:
# Separate the cuisines from the recipes: each row is ingredient ids followed by the cuisine
# name, with 0s where fillna(0) padded the shorter rows.
train_ingredients = []
train_labels = []
for row in df_2:
    fields = [v for v in row if v != 0]
    # NOTE(review): `v != 0` would also discard a genuine ingredient id 0.
    train_ingredients.append([int(v) for v in fields[:-1]])
    train_labels.append(cusine_vocab[fields[-1]])

In [7]:
# Classification validation set: question rows hold ingredient ids, and column 0 of each
# answer row holds the cuisine name. sep=',' replaces the invalid string escape '\,'
# (same single-comma separator, no SyntaxWarning on Python 3.12+).
df = pd.read_csv('validation_classification_question.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_x = df1.values.tolist()

df = pd.read_csv('validation_classification_answer.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_y = df1.values.tolist()

In [8]:
# Turn the raw classification rows into (ingredient-id list, cuisine id) pairs.
val_ingredients_c = []
val_labels_c = []

for i, question_row in enumerate(val_x):
    kept = [v for v in question_row if v != 0]  # strip the 0 padding added by fillna(0)
    # NOTE(review): [:-1] discards the last non-zero field of every question row, as the
    # training parser does for the cuisine column. Confirm the question file carries a
    # trailing non-ingredient field; otherwise one real ingredient per recipe is lost.
    val_ingredients_c.append([int(v) for v in kept[:-1]])
    val_labels_c.append(cusine_vocab[val_y[i][0]])

print(len(val_ingredients_c), len(val_labels_c))

7848 7848


In [9]:
# Training schedule
batch_size, epochs = 64, 32

# Output/input vocabulary sizes
num_labels = len(cusine_vocab)   # number of cuisine classes
num_tokens = 6714                # number of ingredient ids; equals len(node_ingredient)

# Transformer shape and regularisation
dim_model, num_heads = 128, 4
num_encoder_layers, num_decoder_layers = 3, 1
dropout_p = 0.3

In [11]:
from torch.utils.data import DataLoader

# Classification data: ingredient-id lists -> cuisine ids.
# Padder (recformer.dataset) is the collate function; presumably it pads each batch to a
# common length with -1 along dim 1 — TODO confirm.
padder = Padder(dim=1, pad_symbol=-1)
train_dataset = Dataset(train_ingredients, train_labels)
validation_dataset = Dataset(val_ingredients_c, val_labels_c)

loader_kwargs = dict(batch_size=batch_size, collate_fn=padder)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
validation_loader = DataLoader(dataset=validation_dataset, **loader_kwargs)

In [12]:
# Single-task RecFormer for cuisine classification (num_labels output classes, presumably).
model = RecFormer(
    num_tokens, num_labels,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the cuisine classifier on the configured device.
train(
    model, criterion, optimizer,
    train_loader, validation_loader,
    epochs,
    device=device,
)

In [None]:
# Uncomment to save freshly trained weights. The evaluation cell loads
# "weights/RecFormer_classification.pth", so save to that path.
#torch.save(model.state_dict(), "weights/RecFormer_classification.pth")

In [13]:
# Load the saved classification weights and report train / validation accuracy.
# map_location lets the checkpoint load on whatever `device` is, regardless of where it was saved.
model.load_state_dict(torch.load("weights/RecFormer_classification.pth", map_location=device))
model.to(device)
print(measure_accuracy(model, train_loader))
print(measure_accuracy(model, validation_loader))

tensor(0.9040, device='cuda:0')
tensor(0.7429, device='cuda:0')


# Completion task

In [10]:
completion_data = []
completion_labels = []

# Leave-one-out expansion: every ingredient of every training recipe yields one example whose
# input is the recipe without that ingredient and whose label is the removed ingredient.
for recipe in train_ingredients:
    for pos, held_out in enumerate(recipe):
        completion_data.append(recipe[:pos] + recipe[pos + 1:])
        completion_labels.append(held_out)

print(len(completion_data), len(completion_labels))

253453 253453


In [11]:
# Completion validation set: question rows presumably hold a recipe with one ingredient
# removed, and column 0 of each answer row is used as the target ingredient id.
# sep=',' replaces the invalid string escape '\,' (same single-comma separator).
df = pd.read_csv('validation_completion_question.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_x = df1.values.tolist()

df = pd.read_csv('validation_completion_answer.csv', engine='python', sep=',', names=list(range(60)))
df1 = df.fillna(0)
val_y = df1.values.tolist()

print(len(val_x), len(val_y))

7848 7848


In [12]:
# Turn the raw completion rows into (ingredient-id list, missing ingredient id) pairs.
val_ingredients = []
val_labels = []

for i, question_row in enumerate(val_x):
    kept = [v for v in question_row if v != 0]  # strip the 0 padding added by fillna(0)
    # NOTE(review): [:-1] drops the last non-zero field of each question row; confirm it is
    # not an ingredient.
    val_ingredients.append([int(v) for v in kept[:-1]])
    val_labels.append(int(val_y[i][0]))

print(len(val_ingredients), len(val_labels))

7848 7848


In [13]:
from torch.utils.data import DataLoader

# Completion data: recipes with one ingredient removed -> id of the removed ingredient.
padder = Padder(dim=1, pad_symbol=-1)
train_dataset = Dataset(completion_data, completion_labels)
validation_dataset = Dataset(val_ingredients, val_labels)

loader_kwargs = dict(batch_size=batch_size, collate_fn=padder)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
validation_loader = DataLoader(dataset=validation_dataset, **loader_kwargs)

In [18]:
# Completion model: same settings as the classifier, but the second argument (the output
# space) is the ingredient vocabulary num_tokens instead of the cuisine count.
model = RecFormer(
    num_tokens, num_tokens,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Pass `device` explicitly, as the classification training cell does, instead of relying on
# train()'s default device.
train(model, criterion, optimizer, train_loader, validation_loader, epochs, device=device)

In [None]:
# Uncomment to save freshly trained weights. The evaluation cell loads
# "weights/RecFormer_completion.pth", so save to that path.
#torch.save(model.state_dict(), "weights/RecFormer_completion.pth")

In [19]:
# Load the saved completion weights and report train / validation accuracy.
# map_location lets the checkpoint load on whatever `device` is, regardless of where it was saved.
model.load_state_dict(torch.load("weights/RecFormer_completion.pth", map_location=device))
model.to(device)

print(measure_accuracy(model, train_loader))
print(measure_accuracy(model, validation_loader))

tensor(0.1541, device='cuda:0')
tensor(0.1138, device='cuda:0')


# Multi-task experiments

In [13]:
multitask_data = []
multitask_labels = []

# Same leave-one-out expansion as the completion task, but each example keeps both targets:
# [cuisine id, held-out ingredient id].
for recipe, cusine in zip(train_ingredients, train_labels):
    for pos, held_out in enumerate(recipe):
        multitask_data.append(recipe[:pos] + recipe[pos + 1:])
        multitask_labels.append([cusine, held_out])

print(len(multitask_data), len(multitask_labels))

253453 253453


In [14]:
# Multitask training set plus one validation loader per task.
padder = Padder(dim=1, pad_symbol=-1)
train_dataset = Dataset(multitask_data, multitask_labels)
validation_dataset_cuisine = Dataset(val_ingredients_c, val_labels_c)
validation_dataset_ingredients = Dataset(val_ingredients, val_labels)

loader_kwargs = dict(batch_size=batch_size, collate_fn=padder)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
validation_loader_cuisine = DataLoader(dataset=validation_dataset_cuisine, **loader_kwargs)
validation_loader_ingredients = DataLoader(dataset=validation_dataset_ingredients, **loader_kwargs)

In [15]:
# Multitask model predicting both the cuisine and the missing ingredient
# (the class is spelled `MiltitaskRecFormer` in recformer.recformer).
model = MiltitaskRecFormer(
    num_tokens, num_labels,
    dim_model, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Joint training, with validation reported separately for each task.
# NOTE(review): loss_weights=[0.5, 2] presumably weights the cuisine and ingredient losses in
# that order (matching the [cuisine, ingredient] label pairs) — confirm in recformer.train.
train_multitask(
    model, criterion, optimizer,
    train_loader,
    validation_loader_cuisine, validation_loader_ingredients,
    epochs,
    loss_weights=[0.5, 2],
)

In [None]:
# Uncomment to save freshly trained weights. The evaluation cell loads
# "weights/RecFormer_multitask.pth", so save to that path.
#torch.save(model.state_dict(), "weights/RecFormer_multitask.pth")

In [16]:
# Load the saved multitask weights and report validation accuracy for each head.
# map_location lets the checkpoint load on whatever `device` is, regardless of where it was saved.
model.load_state_dict(torch.load("weights/RecFormer_multitask.pth", map_location=device))
model.to(device)
print(measure_accuracy(model, validation_loader_cuisine, multitask_switch="cuisine"))
print(measure_accuracy(model, validation_loader_ingredients, multitask_switch="ingredients"))

tensor(0.7608, device='cuda:0')
tensor(0.1244, device='cuda:0')


In [17]:
# Accuracy of both heads on the multitask training loader.
for task in ("cuisine", "ingredients"):
    print(measure_accuracy(model, train_loader, multitask_switch=task))

tensor(0.8828, device='cuda:0')
tensor(0.2002, device='cuda:0')


In [25]:
# Cuisine predictions of the multitask model on the classification validation set, collected
# in loader order (validation_loader_cuisine is not shuffled) so they align with val_labels_c.
model.eval()
classification_preds = []
with torch.no_grad():
    for batch in validation_loader_cuisine:
        preds = predict_on_batch(model, batch, "cuisine")
        # Only argmax predictions are kept; ground truth comes from val_labels_c in the report,
        # so the batch targets are not copied to the device.
        classification_preds.extend(torch.argmax(preds, dim=1).tolist())
print(len(classification_preds))

7848


In [26]:
# Inverse of cusine_vocab: class id -> cuisine name.
id_to_cus = dict(zip(cusine_vocab.values(), cusine_vocab.keys()))

In [27]:
from sklearn.metrics import classification_report, f1_score, confusion_matrix

# Per-cuisine precision / recall / F1 on the validation set, using cuisine names as labels.
true_names = [id_to_cus[label] for label in val_labels_c]
pred_names = [id_to_cus[label] for label in classification_preds]
print(classification_report(true_names, pred_names))

              precision    recall  f1-score   support

   brazilian       0.70      0.61      0.65        85
     british       0.47      0.43      0.45       161
cajun_creole       0.76      0.66      0.71       295
     chinese       0.78      0.85      0.81       516
    filipino       0.71      0.60      0.65       141
      french       0.55      0.62      0.58       538
       greek       0.68      0.65      0.67       222
      indian       0.85      0.88      0.86       624
       irish       0.50      0.53      0.52       122
     italian       0.83      0.84      0.83      1558
    jamaican       0.69      0.61      0.65       113
    japanese       0.80      0.65      0.71       290
      korean       0.78      0.69      0.74       167
     mexican       0.90      0.90      0.90      1273
    moroccan       0.72      0.76      0.74       160
     russian       0.55      0.51      0.53        92
 southern_us       0.69      0.74      0.71       839
     spanish       0.51    

## Word embeddings multi-task experiment

In [None]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove*.zip

In [14]:
import numpy as np

# Load GloVe vectors of size emb_length into a word -> float32 vector dict.
emb_length = 300
glove_vocab = {}
glove_path = 'glove.6B/glove.6B.{}d.txt'.format(emb_length)
with open(glove_path, encoding='utf-8') as glove_file:
    for raw_line in glove_file:
        fields = raw_line.split()   # token followed by its vector components
        glove_vocab[fields[0]] = np.asarray(fields[1:], dtype='float32')

In [15]:
# All-zero vector, used below as the Padder's pad symbol.
PAD_embedding = torch.zeros(emb_length)
# Mean of all GloVe vectors; presumably EmbDataset uses it for names missing from GloVe.
# NOTE(review): this is a numpy array while PAD_embedding is a torch tensor — confirm
# EmbDataset converts both consistently.
UNK_embedding = np.stack(list(glove_vocab.values())).mean(axis=0)

In [16]:
# One ingredient name per line; the row position is the ingredient id.
df = pd.read_fwf('node_ingredient.csv', header=None)
node_ingredient = df[0].tolist()
print(len(node_ingredient))
ing_id_to_str = dict(enumerate(node_ingredient))

6714


In [17]:
# Pretrained-embedding variant: each EmbDataset receives the GloVe table, the UNK vector and the
# id -> ingredient-name map; batches are padded with PAD_embedding along dim 0.
padder = Padder(dim=0, pad_symbol=PAD_embedding)

emb_args = (glove_vocab, UNK_embedding, ing_id_to_str)
train_dataset = EmbDataset(multitask_data, multitask_labels, *emb_args)
validation_dataset_cuisine = EmbDataset(val_ingredients_c, val_labels_c, *emb_args)
validation_dataset_ingredients = EmbDataset(val_ingredients, val_labels, *emb_args)

loader_kwargs = dict(batch_size=batch_size, collate_fn=padder)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
validation_loader_cuisine = DataLoader(dataset=validation_dataset_cuisine, **loader_kwargs)
validation_loader_ingredients = DataLoader(dataset=validation_dataset_ingredients, **loader_kwargs)

In [18]:
# Multitask model on GloVe inputs: the model width is emb_length (presumably so the
# pretrained vectors can be fed in without projection).
model = MiltitaskRecFormer(
    num_tokens, num_labels,
    emb_length, num_heads,
    num_encoder_layers, num_decoder_layers,
    dropout_p,
    use_pretrained_embeddings=True,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [23]:
# Joint training of the GloVe-input multitask model, with the same loss weighting as before.
train_multitask(
    model, criterion, optimizer,
    train_loader,
    validation_loader_cuisine, validation_loader_ingredients,
    epochs,
    loss_weights=[0.5, 2],
)

100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:07<00:00, 12.90it/s]


Epoch: 0, Training Loss: 13.92051126330346, Validation Loss Cuisine: 1.5136070425917463, Validation Loss Ingredients: 6.1752993382089505
Validation classification accuracy: 0.5465086698532104
Validation completion accuracy: 0.03669724613428116
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:08<00:00, 12.84it/s]


Epoch: 1, Training Loss: 12.345848593197818, Validation Loss Cuisine: 1.026041844026829, Validation Loss Ingredients: 5.688578853762246
Validation classification accuracy: 0.7029816508293152
Validation completion accuracy: 0.08600916713476181
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.64it/s]


Epoch: 2, Training Loss: 11.468343423189944, Validation Loss Cuisine: 0.9545198901882016, Validation Loss Ingredients: 5.496529358189281
Validation classification accuracy: 0.73050457239151
Validation completion accuracy: 0.10244648158550262
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.64it/s]


Epoch: 3, Training Loss: 11.065888933086901, Validation Loss Cuisine: 0.8963074502421589, Validation Loss Ingredients: 5.4145425974838135
Validation classification accuracy: 0.7488532066345215
Validation completion accuracy: 0.10996431857347488
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:20<00:00, 12.37it/s]


Epoch: 4, Training Loss: 10.81814929314016, Validation Loss Cuisine: 0.9113690751354869, Validation Loss Ingredients: 5.37940559542276
Validation classification accuracy: 0.7496176958084106
Validation completion accuracy: 0.11760957539081573
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.52it/s]


Epoch: 5, Training Loss: 10.648928140034974, Validation Loss Cuisine: 0.891353951237066, Validation Loss Ingredients: 5.345465861684907
Validation classification accuracy: 0.7517838478088379
Validation completion accuracy: 0.12117736786603928
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.51it/s]


Epoch: 6, Training Loss: 10.519253636152868, Validation Loss Cuisine: 0.878387838602066, Validation Loss Ingredients: 5.332745257431899
Validation classification accuracy: 0.7557339072227478
Validation completion accuracy: 0.12041284143924713


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:17<00:00, 12.46it/s]


Epoch: 7, Training Loss: 10.417542762510767, Validation Loss Cuisine: 0.8736150058788982, Validation Loss Ingredients: 5.343840114469451
Validation classification accuracy: 0.7575178146362305
Validation completion accuracy: 0.1179918423295021


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.50it/s]


Epoch: 8, Training Loss: 10.31701119636472, Validation Loss Cuisine: 0.8708018500630449, Validation Loss Ingredients: 5.347934505803798
Validation classification accuracy: 0.7604485154151917
Validation completion accuracy: 0.12143220752477646
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:18<00:00, 12.44it/s]


Epoch: 9, Training Loss: 10.248327545131346, Validation Loss Cuisine: 0.8608376189945189, Validation Loss Ingredients: 5.348286799299038
Validation classification accuracy: 0.7605758905410767
Validation completion accuracy: 0.12436288595199585
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.53it/s]


Epoch: 10, Training Loss: 10.179545236518665, Validation Loss Cuisine: 0.865738942128856, Validation Loss Ingredients: 5.321166306007199
Validation classification accuracy: 0.7617226839065552
Validation completion accuracy: 0.12296126037836075


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:19<00:00, 12.39it/s]


Epoch: 11, Training Loss: 10.122222281501498, Validation Loss Cuisine: 0.8647640851455006, Validation Loss Ingredients: 5.334899875206676
Validation classification accuracy: 0.7636340260505676
Validation completion accuracy: 0.12614677846431732
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.50it/s]


Epoch: 12, Training Loss: 10.07096917145904, Validation Loss Cuisine: 0.8705057889465394, Validation Loss Ingredients: 5.338251435659765
Validation classification accuracy: 0.7628694772720337
Validation completion accuracy: 0.12296126037836075


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:14<00:00, 12.58it/s]


Epoch: 13, Training Loss: 10.01758675790742, Validation Loss Cuisine: 0.8603105663768644, Validation Loss Ingredients: 5.320168906111058
Validation classification accuracy: 0.7652904987335205
Validation completion accuracy: 0.12308868020772934


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.63it/s]


Epoch: 14, Training Loss: 9.978060943374306, Validation Loss Cuisine: 0.8719300450832863, Validation Loss Ingredients: 5.328543852984421
Validation classification accuracy: 0.7590468525886536
Validation completion accuracy: 0.1265290528535843


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.65it/s]


Epoch: 15, Training Loss: 9.93521052976655, Validation Loss Cuisine: 0.8613416024339877, Validation Loss Ingredients: 5.330942444685029
Validation classification accuracy: 0.7640162706375122
Validation completion accuracy: 0.12576451897621155


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:14<00:00, 12.59it/s]


Epoch: 16, Training Loss: 9.895238581765513, Validation Loss Cuisine: 0.8694959321642309, Validation Loss Ingredients: 5.344790342377453
Validation classification accuracy: 0.7617226839065552
Validation completion accuracy: 0.12589193880558014


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:14<00:00, 12.59it/s]


Epoch: 17, Training Loss: 9.865817737290426, Validation Loss Cuisine: 0.8676860155613442, Validation Loss Ingredients: 5.357095226039731
Validation classification accuracy: 0.7626146674156189
Validation completion accuracy: 0.12449031323194504


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:15<00:00, 12.57it/s]


Epoch: 18, Training Loss: 9.848148062445244, Validation Loss Cuisine: 0.8740922857106217, Validation Loss Ingredients: 5.321006623710074
Validation classification accuracy: 0.7595565319061279
Validation completion accuracy: 0.12423546612262726


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:14<00:00, 12.59it/s]


Epoch: 19, Training Loss: 9.822962659562787, Validation Loss Cuisine: 0.8689897670978453, Validation Loss Ingredients: 5.344538413412202
Validation classification accuracy: 0.7626146674156189
Validation completion accuracy: 0.12869520485401154


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:15<00:00, 12.57it/s]


Epoch: 20, Training Loss: 9.790589877674899, Validation Loss Cuisine: 0.8683346618966359, Validation Loss Ingredients: 5.355940113222696
Validation classification accuracy: 0.7656727433204651
Validation completion accuracy: 0.1264016330242157
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.64it/s]


Epoch: 21, Training Loss: 9.766331337281352, Validation Loss Cuisine: 0.8765718936920166, Validation Loss Ingredients: 5.394569679973571
Validation classification accuracy: 0.764143705368042
Validation completion accuracy: 0.12703873217105865


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.63it/s]


Epoch: 22, Training Loss: 9.738021863105772, Validation Loss Cuisine: 0.8641276391056495, Validation Loss Ingredients: 5.366755904220954
Validation classification accuracy: 0.7640162706375122
Validation completion accuracy: 0.12754841148853302


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:12<00:00, 12.66it/s]


Epoch: 23, Training Loss: 9.72851652462234, Validation Loss Cuisine: 0.8638701954992806, Validation Loss Ingredients: 5.335395561001165
Validation classification accuracy: 0.7650356292724609
Validation completion accuracy: 0.12194189429283142


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:13<00:00, 12.62it/s]


Epoch: 24, Training Loss: 9.695678875380711, Validation Loss Cuisine: 0.8791335782384485, Validation Loss Ingredients: 5.362986452211209
Validation classification accuracy: 0.7666921019554138
Validation completion accuracy: 0.12563709914684296


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:12<00:00, 12.66it/s]


Epoch: 25, Training Loss: 9.686133372907053, Validation Loss Cuisine: 0.8662978946193447, Validation Loss Ingredients: 5.3912596043532455
Validation classification accuracy: 0.7688582539558411
Validation completion accuracy: 0.1249999925494194


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:14<00:00, 12.61it/s]


Epoch: 26, Training Loss: 9.661829697065539, Validation Loss Cuisine: 0.8709461827103685, Validation Loss Ingredients: 5.394427477828855
Validation classification accuracy: 0.7672017812728882
Validation completion accuracy: 0.12181446701288223


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:16<00:00, 12.53it/s]


Epoch: 27, Training Loss: 9.642868107480552, Validation Loss Cuisine: 0.8592591385046641, Validation Loss Ingredients: 5.360887368520101
Validation classification accuracy: 0.7663098573684692
Validation completion accuracy: 0.12703873217105865
Best model saved


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:26<00:00, 12.12it/s]


Epoch: 28, Training Loss: 9.627448730714331, Validation Loss Cuisine: 0.860310306151708, Validation Loss Ingredients: 5.3534478831097365
Validation classification accuracy: 0.7632517218589783
Validation completion accuracy: 0.12487257272005081


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:20<00:00, 12.35it/s]


Epoch: 29, Training Loss: 9.621944503813072, Validation Loss Cuisine: 0.8734334305049928, Validation Loss Ingredients: 5.364220227652449
Validation classification accuracy: 0.761850118637085
Validation completion accuracy: 0.12461773306131363


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:12<00:00, 12.67it/s]


Epoch: 30, Training Loss: 9.600025518648975, Validation Loss Cuisine: 0.8635326008486554, Validation Loss Ingredients: 5.382072708471035
Validation classification accuracy: 0.7637614607810974
Validation completion accuracy: 0.12614677846431732


100%|██████████████████████████████████████████████████████████████████████████████| 3961/3961 [05:12<00:00, 12.66it/s]


Epoch: 31, Training Loss: 9.584971200754714, Validation Loss Cuisine: 0.8720227569583955, Validation Loss Ingredients: 5.401036855651111
Validation classification accuracy: 0.7643985152244568
Validation completion accuracy: 0.12194189429283142


In [19]:
# Load the saved GloVe-input multitask weights and report validation accuracy for each head.
# map_location lets the checkpoint load on whatever `device` is, regardless of where it was saved.
model.load_state_dict(torch.load("weights/RecFormer_multitask_emb.pth", map_location=device))
model.to(device)
print(measure_accuracy(model, validation_loader_cuisine, multitask_switch="cuisine"))
print(measure_accuracy(model, validation_loader_ingredients, multitask_switch="ingredients"))

tensor(0.7663, device='cuda:0')
tensor(0.1270, device='cuda:0')


In [20]:
# Accuracy of both heads on the multitask training loader.
for task in ("cuisine", "ingredients"):
    print(measure_accuracy(model, train_loader, multitask_switch=task))

tensor(0.8516, device='cuda:0')
tensor(0.2076, device='cuda:0')


In [21]:
# Inverse of cusine_vocab: class id -> cuisine name.
id_to_cus = dict(zip(cusine_vocab.values(), cusine_vocab.keys()))

In [22]:
# Cuisine predictions on the classification validation set, collected in loader order
# (validation_loader_cuisine is not shuffled) so they align with val_labels_c.
model.eval()
classification_preds = []
with torch.no_grad():
    for batch in validation_loader_cuisine:
        preds = predict_on_batch(model, batch, "cuisine")
        # Only argmax predictions are kept; ground truth comes from val_labels_c in the report,
        # so the batch targets are not copied to the device.
        classification_preds.extend(torch.argmax(preds, dim=1).tolist())
print(len(classification_preds))

7848


In [23]:
from sklearn.metrics import classification_report, f1_score, confusion_matrix

# Per-cuisine precision / recall / F1 on the validation set, using cuisine names as labels.
true_names = [id_to_cus[label] for label in val_labels_c]
pred_names = [id_to_cus[label] for label in classification_preds]
print(classification_report(true_names, pred_names))

              precision    recall  f1-score   support

   brazilian       0.68      0.54      0.60        85
     british       0.50      0.48      0.49       161
cajun_creole       0.73      0.71      0.72       295
     chinese       0.78      0.82      0.80       516
    filipino       0.63      0.64      0.63       141
      french       0.58      0.58      0.58       538
       greek       0.71      0.66      0.69       222
      indian       0.86      0.88      0.87       624
       irish       0.53      0.54      0.54       122
     italian       0.81      0.86      0.83      1558
    jamaican       0.76      0.66      0.71       113
    japanese       0.76      0.67      0.71       290
      korean       0.76      0.72      0.74       167
     mexican       0.90      0.91      0.90      1273
    moroccan       0.80      0.72      0.76       160
     russian       0.59      0.50      0.54        92
 southern_us       0.70      0.75      0.72       839
     spanish       0.60    

In [31]:
# Cuisine predictions for every multitask training example, in dataset order (shuffle=False)
# so they line up with multitask_labels in the report that follows.
model.eval()
classification_preds = []
ordered_train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, collate_fn=padder, shuffle=False)
with torch.no_grad():
    for batch in ordered_train_loader:
        preds = predict_on_batch(model, batch, "cuisine")
        # Only predictions are needed; targets are read from multitask_labels afterwards.
        classification_preds.extend(torch.argmax(preds, dim=1).tolist())
print(len(classification_preds))

253453


In [32]:
# Per-cuisine report on the training examples; each multitask label is [cuisine id, ingredient id].
train_true_names = [id_to_cus[pair[0]] for pair in multitask_labels]
train_pred_names = [id_to_cus[pred] for pred in classification_preds]
print(classification_report(train_true_names, train_pred_names))

              precision    recall  f1-score   support

   brazilian       0.82      0.63      0.71      2768
     british       0.61      0.65      0.63      4632
cajun_creole       0.81      0.84      0.82     11565
     chinese       0.88      0.89      0.88     19120
    filipino       0.71      0.75      0.73      4468
      french       0.66      0.73      0.69     15205
       greek       0.79      0.82      0.80      7330
      indian       0.91      0.93      0.92     22384
       irish       0.71      0.63      0.67      3649
     italian       0.87      0.88      0.88     46378
    jamaican       0.85      0.89      0.87      3471
    japanese       0.85      0.74      0.79      8202
      korean       0.84      0.89      0.86      5297
     mexican       0.95      0.93      0.94     41581
    moroccan       0.83      0.91      0.87      6466
     russian       0.73      0.59      0.65      3025
 southern_us       0.81      0.78      0.79     24268
     spanish       0.69    