<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Imports-&amp;-Inits" data-toc-modified-id="Imports-&amp;-Inits-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Imports &amp; Inits</a></span></li><li><span><a href="#Data-Preparation" data-toc-modified-id="Data-Preparation-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Data Preparation</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Training</a></span></li><li><span><a href="#Testing" data-toc-modified-id="Testing-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Testing</a></span><ul class="toc-item"><li><span><a href="#Ignite-Testing" data-toc-modified-id="Ignite-Testing-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Ignite Testing</a></span></li><li><span><a href="#NLPBook-Testing" data-toc-modified-id="NLPBook-Testing-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>NLPBook Testing</a></span></li></ul></li><li><span><a href="#Inference" data-toc-modified-id="Inference-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Inference</a></span><ul class="toc-item"><li><span><a href="#Single-Inference" data-toc-modified-id="Single-Inference-6.1"><span class="toc-item-num">6.1&nbsp;&nbsp;</span>Single Inference</a></span></li><li><span><a href="#TopK-Inference" data-toc-modified-id="TopK-Inference-6.2"><span class="toc-item-num">6.2&nbsp;&nbsp;</span>TopK Inference</a></span></li></ul></li></ul></div>

# Surname Classifier with MLP

Classifying surnames based on national origin.

## Imports & Inits

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import torch

from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from pathlib import Path

In [3]:
from ignite.engine import Events, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar

In [4]:
from surname.dataset import ProjectDataset
from surname.mlp_model import Classifier
from surname.trainer import Trainer
from surname.args import args

In [5]:
path = Path('../data/surnames')

## Data Preparation

In [6]:
work_dir = path/'work_dir'
surnames_csv = path/args.proc_dataset_csv
vectorizer_path = work_dir/args.vectorizer_fname
cw_file = work_dir/args.cw_file

args

Namespace(batch_size=64, checkpointer_name='classifier', checkpointer_prefix='surname', cw_file='class_weights.pt', device='cuda:3', early_stopping_criteria=5, hidden_dim=300, learning_rate=0.001, model_dir='models', num_epochs=100, proc_dataset_csv='surnames_with_splits.csv', raw_dataset_csv='surnames.csv', save_every=2, save_total=5, train_proportion=0.7, vectorizer_fname='vectorizer.json')

In [7]:
df = pd.read_csv(surnames_csv)
len(df)

10980

In [10]:
is_load = True

In [9]:
if not is_load:
  train_ds = ProjectDataset.load_data_and_create_vectorizer(df.loc[df['split'] == 'train'])
  train_ds.save_vectorizer(vectorizer_path)
  vectorizer = train_ds.get_vectorizer()
  class_counts = df['nationality'].value_counts().to_dict()
  sorted_counts = sorted(class_counts.items(), key=lambda x: vectorizer.nationality_vocab.lookup_token(x[0]))
  freq = [count for _, count in sorted_counts]
  class_weights = 1.0/torch.tensor(freq, dtype=torch.float32)
  torch.save(class_weights, cw_file)

In [11]:
train_df = df.loc[df['split'] == 'train']
train_ds = ProjectDataset.load_data_and_vectorizer(train_df, vectorizer_path)
vectorizer = train_ds.get_vectorizer()
class_weights = torch.load(cw_file)
train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, drop_last=True)

In [12]:
val_df = df.loc[df['split'] == 'val']
val_ds = ProjectDataset.load_data_and_vectorizer(val_df, vectorizer_path)
val_dl = DataLoader(val_ds, args.batch_size, shuffle=True, drop_last=True)

In [13]:
test_df = df.loc[df['split'] == 'test']
test_ds = ProjectDataset.load_data_and_vectorizer(test_df, vectorizer_path)
test_dl = DataLoader(test_ds, args.batch_size, shuffle=True, drop_last=True)

In [14]:
len(train_dl.dataset), len(val_dl.dataset), len(test_dl.dataset)

(7680, 1640, 1660)

## Model

In [15]:
classifier = Classifier(input_dim=len(vectorizer.surname_vocab), hidden_dim=args.hidden_dim, output_dim=len(vectorizer.nationality_vocab))
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)
class_weights = class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(class_weights)

pbar = ProgressBar(persist=True)
metrics = {'accuracy': Accuracy(), 'loss': Loss(loss_func)}

## Training

In [16]:
surname_trainer = Trainer(classifier, optimizer, loss_func, train_dl, val_dl, work_dir, args, pbar, metrics)
surname_trainer.run()

Epoch [1/100]: [120/120] 100%|██████████, loss=2.70e+00 [00:01<00:00]
Epoch [2/100]: [15/120]  12%|█▎        , loss=2.66e+00 [00:00<00:01]

Epoch: 1
Training - Loss: 2.535, Accuracy: 0.426
Validation - Loss: 2.569, Accuracy: 0.422


Epoch [2/100]: [120/120] 100%|██████████, loss=2.35e+00 [00:01<00:00]
Epoch [3/100]: [15/120]  12%|█▎        , loss=2.27e+00 [00:00<00:01]

Epoch: 2
Training - Loss: 2.144, Accuracy: 0.408
Validation - Loss: 2.231, Accuracy: 0.383


Epoch [3/100]: [120/120] 100%|██████████, loss=2.06e+00 [00:01<00:00]
Epoch [4/100]: [16/120]  13%|█▎        , loss=1.65e+00 [00:00<00:01]

Epoch: 3
Training - Loss: 1.903, Accuracy: 0.449
Validation - Loss: 2.037, Accuracy: 0.433


Epoch [4/100]: [120/120] 100%|██████████, loss=1.86e+00 [00:01<00:00]
Epoch [5/100]: [15/120]  12%|█▎        , loss=1.80e+00 [00:00<00:01]

Epoch: 4
Training - Loss: 1.758, Accuracy: 0.456
Validation - Loss: 1.950, Accuracy: 0.427


Epoch [5/100]: [120/120] 100%|██████████, loss=1.80e+00 [00:01<00:00]
Epoch [6/100]: [16/120]  13%|█▎        , loss=1.50e+00 [00:00<00:01]

Epoch: 5
Training - Loss: 1.658, Accuracy: 0.465
Validation - Loss: 1.905, Accuracy: 0.429


Epoch [6/100]: [120/120] 100%|██████████, loss=1.69e+00 [00:01<00:00]
Epoch [7/100]: [15/120]  12%|█▎        , loss=1.78e+00 [00:00<00:01]

Epoch: 6
Training - Loss: 1.596, Accuracy: 0.483
Validation - Loss: 1.865, Accuracy: 0.439


Epoch [7/100]: [120/120] 100%|██████████, loss=1.68e+00 [00:01<00:00]
Epoch [8/100]: [16/120]  13%|█▎        , loss=1.42e+00 [00:00<00:01]

Epoch: 7
Training - Loss: 1.549, Accuracy: 0.490
Validation - Loss: 1.839, Accuracy: 0.448


Epoch [8/100]: [120/120] 100%|██████████, loss=1.60e+00 [00:01<00:00]
Epoch [9/100]: [16/120]  13%|█▎        , loss=1.73e+00 [00:00<00:01]

Epoch: 8
Training - Loss: 1.504, Accuracy: 0.510
Validation - Loss: 1.808, Accuracy: 0.459


Epoch [9/100]: [120/120] 100%|██████████, loss=1.61e+00 [00:01<00:00]
Epoch [10/100]: [16/120]  13%|█▎        , loss=1.38e+00 [00:00<00:01]

Epoch: 9
Training - Loss: 1.466, Accuracy: 0.485
Validation - Loss: 1.807, Accuracy: 0.441


Epoch [10/100]: [120/120] 100%|██████████, loss=1.55e+00 [00:01<00:00]
Epoch [11/100]: [16/120]  13%|█▎        , loss=1.43e+00 [00:00<00:01]

Epoch: 10
Training - Loss: 1.435, Accuracy: 0.499
Validation - Loss: 1.764, Accuracy: 0.453


Epoch [11/100]: [120/120] 100%|██████████, loss=1.50e+00 [00:01<00:00]
Epoch [12/100]: [16/120]  13%|█▎        , loss=1.95e+00 [00:00<00:01]

Epoch: 11
Training - Loss: 1.408, Accuracy: 0.509
Validation - Loss: 1.771, Accuracy: 0.455


Epoch [12/100]: [120/120] 100%|██████████, loss=1.55e+00 [00:01<00:00]
Epoch [13/100]: [16/120]  13%|█▎        , loss=1.95e+00 [00:00<00:01]

Epoch: 12
Training - Loss: 1.382, Accuracy: 0.503
Validation - Loss: 1.764, Accuracy: 0.446


Epoch [13/100]: [120/120] 100%|██████████, loss=1.57e+00 [00:01<00:00]
Epoch [14/100]: [16/120]  13%|█▎        , loss=1.47e+00 [00:00<00:01]

Epoch: 13
Training - Loss: 1.360, Accuracy: 0.495
Validation - Loss: 1.745, Accuracy: 0.451


Epoch [14/100]: [120/120] 100%|██████████, loss=1.49e+00 [00:01<00:00]
Epoch [15/100]: [15/120]  12%|█▎        , loss=1.16e+00 [00:00<00:01]

Epoch: 14
Training - Loss: 1.327, Accuracy: 0.520
Validation - Loss: 1.764, Accuracy: 0.463


Epoch [15/100]: [120/120] 100%|██████████, loss=1.41e+00 [00:01<00:00]
Epoch [16/100]: [16/120]  13%|█▎        , loss=1.42e+00 [00:00<00:01]

Epoch: 15
Training - Loss: 1.313, Accuracy: 0.515
Validation - Loss: 1.752, Accuracy: 0.456


Epoch [16/100]: [120/120] 100%|██████████, loss=1.43e+00 [00:01<00:00]
Epoch [17/100]: [15/120]  12%|█▎        , loss=1.25e+00 [00:00<00:01]

Epoch: 16
Training - Loss: 1.290, Accuracy: 0.532
Validation - Loss: 1.728, Accuracy: 0.477


Epoch [17/100]: [120/120] 100%|██████████, loss=1.40e+00 [00:01<00:00]
Epoch [18/100]: [16/120]  13%|█▎        , loss=1.49e+00 [00:00<00:01]

Epoch: 17
Training - Loss: 1.274, Accuracy: 0.537
Validation - Loss: 1.719, Accuracy: 0.479


Epoch [18/100]: [120/120] 100%|██████████, loss=1.41e+00 [00:01<00:00]
Epoch [19/100]: [15/120]  12%|█▎        , loss=1.14e+00 [00:00<00:01]

Epoch: 18
Training - Loss: 1.251, Accuracy: 0.542
Validation - Loss: 1.720, Accuracy: 0.480


Epoch [19/100]: [120/120] 100%|██████████, loss=1.35e+00 [00:01<00:00]
Epoch [20/100]: [16/120]  13%|█▎        , loss=1.51e+00 [00:00<00:01]

Epoch: 19
Training - Loss: 1.228, Accuracy: 0.531
Validation - Loss: 1.710, Accuracy: 0.479


Epoch [20/100]: [120/120] 100%|██████████, loss=1.39e+00 [00:01<00:00]
Epoch [21/100]: [15/120]  12%|█▎        , loss=1.21e+00 [00:00<00:01]

Epoch: 20
Training - Loss: 1.220, Accuracy: 0.539
Validation - Loss: 1.690, Accuracy: 0.482


Epoch [21/100]: [120/120] 100%|██████████, loss=1.33e+00 [00:01<00:00]
Epoch [22/100]: [17/120]  14%|█▍        , loss=1.30e+00 [00:00<00:01]

Epoch: 21
Training - Loss: 1.200, Accuracy: 0.560
Validation - Loss: 1.698, Accuracy: 0.498


Epoch [22/100]: [120/120] 100%|██████████, loss=1.35e+00 [00:01<00:00]
Epoch [23/100]: [15/120]  12%|█▎        , loss=1.55e+00 [00:00<00:01]

Epoch: 22
Training - Loss: 1.183, Accuracy: 0.549
Validation - Loss: 1.727, Accuracy: 0.477


Epoch [23/100]: [120/120] 100%|██████████, loss=1.34e+00 [00:01<00:00]
Epoch [24/100]: [16/120]  13%|█▎        , loss=1.67e+00 [00:00<00:01]

Epoch: 23
Training - Loss: 1.163, Accuracy: 0.576
Validation - Loss: 1.706, Accuracy: 0.511


Epoch [24/100]: [120/120] 100%|██████████, loss=1.35e+00 [00:01<00:00]
Epoch [25/100]: [15/120]  12%|█▎        , loss=1.04e+00 [00:00<00:01]

Epoch: 24
Training - Loss: 1.146, Accuracy: 0.553
Validation - Loss: 1.687, Accuracy: 0.486


Epoch [25/100]: [120/120] 100%|██████████, loss=1.28e+00 [00:01<00:00]
Epoch [26/100]: [16/120]  13%|█▎        , loss=1.32e+00 [00:00<00:01]

Epoch: 25
Training - Loss: 1.129, Accuracy: 0.570
Validation - Loss: 1.691, Accuracy: 0.506


Epoch [26/100]: [120/120] 100%|██████████, loss=1.28e+00 [00:01<00:00]
Epoch [27/100]: [15/120]  12%|█▎        , loss=1.24e+00 [00:00<00:01]

Epoch: 26
Training - Loss: 1.115, Accuracy: 0.569
Validation - Loss: 1.702, Accuracy: 0.500


Epoch [27/100]: [120/120] 100%|██████████, loss=1.29e+00 [00:01<00:00]
Epoch [28/100]: [17/120]  14%|█▍        , loss=1.19e+00 [00:00<00:01]

Epoch: 27
Training - Loss: 1.101, Accuracy: 0.588
Validation - Loss: 1.681, Accuracy: 0.521


Epoch [28/100]: [120/120] 100%|██████████, loss=1.24e+00 [00:01<00:00]
Epoch [29/100]: [16/120]  13%|█▎        , loss=1.12e+00 [00:00<00:01]

Epoch: 28
Training - Loss: 1.083, Accuracy: 0.570
Validation - Loss: 1.675, Accuracy: 0.498


Epoch [29/100]: [120/120] 100%|██████████, loss=1.23e+00 [00:01<00:00]
Epoch [30/100]: [16/120]  13%|█▎        , loss=1.03e+00 [00:00<00:01]

Epoch: 29
Training - Loss: 1.073, Accuracy: 0.573
Validation - Loss: 1.671, Accuracy: 0.496


Epoch [30/100]: [120/120] 100%|██████████, loss=1.21e+00 [00:01<00:00]
Epoch [31/100]: [15/120]  12%|█▎        , loss=1.12e+00 [00:00<00:01]

Epoch: 30
Training - Loss: 1.065, Accuracy: 0.581
Validation - Loss: 1.678, Accuracy: 0.504


Epoch [31/100]: [120/120] 100%|██████████, loss=1.21e+00 [00:01<00:00]
Epoch [32/100]: [17/120]  14%|█▍        , loss=1.19e+00 [00:00<00:01]

Epoch: 31
Training - Loss: 1.045, Accuracy: 0.583
Validation - Loss: 1.670, Accuracy: 0.498


Epoch [32/100]: [120/120] 100%|██████████, loss=1.19e+00 [00:01<00:00]
Epoch [33/100]: [15/120]  12%|█▎        , loss=1.15e+00 [00:00<00:01]

Epoch: 32
Training - Loss: 1.040, Accuracy: 0.591
Validation - Loss: 1.688, Accuracy: 0.509


Epoch [33/100]: [120/120] 100%|██████████, loss=1.21e+00 [00:01<00:00]
Epoch [34/100]: [16/120]  13%|█▎        , loss=1.21e+00 [00:00<00:01]

Epoch: 33
Training - Loss: 1.019, Accuracy: 0.600
Validation - Loss: 1.701, Accuracy: 0.517


Epoch [34/100]: [120/120] 100%|██████████, loss=1.18e+00 [00:01<00:00]
Epoch [35/100]: [16/120]  13%|█▎        , loss=1.18e+00 [00:00<00:01]

Epoch: 34
Training - Loss: 1.005, Accuracy: 0.610
Validation - Loss: 1.674, Accuracy: 0.522


Epoch [35/100]: [120/120] 100%|██████████, loss=1.18e+00 [00:01<00:00]
Epoch [36/100]: [17/120]  14%|█▍        , loss=1.12e+00 [00:00<00:01]

Epoch: 35
Training - Loss: 0.991, Accuracy: 0.598
Validation - Loss: 1.659, Accuracy: 0.513


Epoch [36/100]: [120/120] 100%|██████████, loss=1.16e+00 [00:01<00:00]
Epoch [37/100]: [16/120]  13%|█▎        , loss=1.09e+00 [00:00<00:01]

Epoch: 36
Training - Loss: 0.982, Accuracy: 0.608
Validation - Loss: 1.663, Accuracy: 0.526


Epoch [37/100]: [120/120] 100%|██████████, loss=1.13e+00 [00:01<00:00]
Epoch [38/100]: [16/120]  13%|█▎        , loss=1.15e+00 [00:00<00:01]

Epoch: 37
Training - Loss: 0.961, Accuracy: 0.612
Validation - Loss: 1.668, Accuracy: 0.524


Epoch [38/100]: [120/120] 100%|██████████, loss=1.13e+00 [00:01<00:00]
Epoch [39/100]: [15/120]  12%|█▎        , loss=1.12e+00 [00:00<00:01]

Epoch: 38
Training - Loss: 0.953, Accuracy: 0.611
Validation - Loss: 1.664, Accuracy: 0.524


Epoch [39/100]: [120/120] 100%|██████████, loss=1.14e+00 [00:01<00:00]
Epoch [40/100]: [17/120]  14%|█▍        , loss=8.80e-01 [00:00<00:01]

Epoch: 39
Training - Loss: 0.941, Accuracy: 0.619
Validation - Loss: 1.684, Accuracy: 0.526


Epoch [40/100]: [120/120] 100%|██████████, loss=1.09e+00 [00:01<00:00]


Epoch: 40
Training - Loss: 0.934, Accuracy: 0.625
Validation - Loss: 1.701, Accuracy: 0.538


## Testing

### Ignite Testing

In [17]:
args.device = 'cpu'
classifier = Classifier(input_dim=len(vectorizer.surname_vocab), hidden_dim=args.hidden_dim,\
                        output_dim=len(vectorizer.nationality_vocab))
state_dict = torch.load(work_dir/args.model_dir/'surname_classifier_40.pth')
# state_dict = torch.load(work_dir/'surname_classifier_model.pth')
classifier.load_state_dict(state_dict)

class_weights = class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(class_weights)
metrics = {'accuracy': Accuracy(), 'loss': Loss(loss_func)}

In [18]:
evaluator = create_supervised_evaluator(classifier, metrics=metrics)

@evaluator.on(Events.COMPLETED)
def log_testing_results(engine):
  metrics = engine.state.metrics
  print(f"Test loss: {metrics['loss']:0.3f}")
  print(f"Test accuracy: {metrics['accuracy']:0.3f}")

In [19]:
evaluator.run(test_dl)

Test loss: 1.717
Test accuracy: 0.549


<ignite.engine.engine.State at 0x7fc4de0520b8>

### NLPBook Testing

In [20]:
def compute_accuracy(y_pred, y_target):
  _, y_pred_indices = y_pred.max(dim=1)
  n_correct = torch.eq(y_pred_indices, y_target).sum().item()
  return n_correct / len(y_pred_indices) * 100

In [21]:
running_loss = 0.
running_acc = 0.

classifier.eval()
for i, batch in enumerate(test_dl):
  x,y = batch
  y_pred = classifier(x_in=x.float())
  
  loss = loss_func(y_pred, y)
  loss_t = loss.item()
  running_loss += (loss_t-running_loss)/(i+1)
  
  acc_t = compute_accuracy(y_pred, y)
  running_acc += (acc_t-running_acc)/(i+1)

In [22]:
print(f"Test loss: {running_loss:0.3f}")
print(f"Test acc: {running_acc:0.3f}")

Test loss: 1.717
Test acc: 54.562


## Inference

### Single Inference

In [23]:
def predict_natinoality(surname, classifier, vectorizer):
  """
    Predict the nationality from a new surname
    
    Args:
      surname: the surname to classify
      classifier: an instance of the classifier
      vectorizer: the corresponding vectorizer
      
    Returns:
      a dictionary with most likely natinoality and its probability
  """
  vectorized_surname = vectorizer.vectorize(surname)
  vectorized_surname = torch.tensor(vectorized_surname).view(1,-1)
  result = classifier(vectorized_surname, apply_softmax=True)
  
  probability_values, indices = result.max(dim=1)
  idx = indices.item()
  
  predicted_nationality = vectorizer.nationality_vocab.lookup_idx(idx)
  probability_value = probability_values.item()
  
  return {'nationality': predicted_nationality, 'probability': probability_value}

In [25]:
new_surname = input("Enter a surname to classify: ")
prediction = predict_natinoality(new_surname, classifier, vectorizer)
print(f"{new_surname} -> {prediction['nationality']} p={prediction['probability']:0.2f}")

Enter a surname to classify: Srinivasan
Srinivasan -> Italian p=0.36


### TopK Inference

In [26]:
def predict_topk_nationality(name, classifier, vectorizer, k=5):
  vectorized_name = vectorizer.vectorize(name)
  vectorized_name = torch.tensor(vectorized_name).view(1, -1)
  prediction_vector = classifier(vectorized_name, apply_softmax=True)
  probability_values, indices = torch.topk(prediction_vector, k=k)

  # returned size is 1,k
  probability_values = probability_values.detach().numpy()[0]
  indices = indices.detach().numpy()[0]

  results = []
  for prob_value, idx in zip(probability_values, indices):
      nationality = vectorizer.nationality_vocab.lookup_idx(idx)
      results.append({'nationality': nationality, 
                      'probability': prob_value})

  return results

In [27]:
new_surname = input("Enter a surname to classify: ")
classifier = classifier.to("cpu")

k = int(input("How many of the top predictions to see? "))
if k > len(vectorizer.nationality_vocab):
  print("Sorry! That's more than the # of nationalities we have.. defaulting you to max size :)")
  k = len(vectorizer.nationality_vocab)
    
predictions = predict_topk_nationality(new_surname, classifier, vectorizer, k=k)

print("Top {} predictions:".format(k))
print("===================")
for prediction in predictions:
  print(f"{new_surname} -> {prediction['nationality']} p={prediction['probability']:0.2f}")

Enter a surname to classify: Srinivasan
How many of the top predictions to see? 5
Top 5 predictions:
Srinivasan -> Italian p=0.36
Srinivasan -> Russian p=0.15
Srinivasan -> Greek p=0.13
Srinivasan -> French p=0.12
Srinivasan -> Czech p=0.08
