In [22]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace

In [23]:
# Notebook-wide settings: input/output CSV locations, the three split
# fractions, and the seed used for shuffling.
args = Namespace(**{
    "raw_dataset_csv": "../Data/surnames.csv",
    "train_proportion": 0.7,
    "val_proportion": 0.15,
    "test_proportion": 0.15,
    "output_munged_csv": "../Data/surnames_with_splits.csv",
    "seed": 1337,
})

In [24]:
# Load the raw surname table; row 0 of the CSV supplies the column names.
surnames = pd.read_csv(filepath_or_buffer=args.raw_dataset_csv, header=0)

In [25]:
# Preview the first five rows of the raw data.
surnames.head(n=5)

Unnamed: 0,surname,nationality
0,Woodford,English
1,Coté,French
2,Kore,English
3,Koury,Arabic
4,Lebzak,Russian


In [26]:
# Distinct nationality labels, followed by the number of surnames per label.
print(set(surnames["nationality"]))
surnames["nationality"].value_counts()

{'Czech', 'Japanese', 'Irish', 'Arabic', 'Polish', 'English', 'Vietnamese', 'Chinese', 'Dutch', 'Scottish', 'Greek', 'German', 'French', 'Russian', 'Portuguese', 'Spanish', 'Italian', 'Korean'}


English       2972
Russian       2373
Arabic        1603
Japanese       775
Italian        600
German         576
Czech          414
Spanish        258
Dutch          236
French         229
Chinese        220
Irish          183
Greek          156
Polish         120
Korean          77
Scottish        75
Vietnamese      58
Portuguese      55
Name: nationality, dtype: int64

In [27]:
# Group the surname rows by nationality.
# by_nationality maps a nationality label to the list of row dicts carrying
# that label; a defaultdict creates each list on first use.
by_nationality = collections.defaultdict(list)
# to_dict(orient="records") converts the whole frame in a single call, which
# avoids building a temporary Series for every row as iterrows() does.
for record in surnames.to_dict(orient="records"):
    by_nationality[record["nationality"]].append(record)

In [28]:
# Sanity check: print the first record stored under each nationality.
# Dictionary keys are unique, so every key is visited exactly once and no
# "previous key" bookkeeping is needed.
for key, value in by_nationality.items():
    if value:  # skip empty groups so value[0] cannot raise IndexError
        print(f"{key}:{value[0]}")
    

English:{'surname': 'Woodford', 'nationality': 'English'}
French:{'surname': 'Coté', 'nationality': 'French'}
Arabic:{'surname': 'Koury', 'nationality': 'Arabic'}
Russian:{'surname': 'Lebzak', 'nationality': 'Russian'}
Japanese:{'surname': 'Obinata', 'nationality': 'Japanese'}
Chinese:{'surname': 'Zhuan', 'nationality': 'Chinese'}
Italian:{'surname': 'Acconci', 'nationality': 'Italian'}
Czech:{'surname': 'Michalovicova', 'nationality': 'Czech'}
Irish:{'surname': 'Gallchobhar', 'nationality': 'Irish'}
German:{'surname': 'Strobel', 'nationality': 'German'}
Greek:{'surname': 'Paloumbas', 'nationality': 'Greek'}
Spanish:{'surname': 'Vargas', 'nationality': 'Spanish'}
Polish:{'surname': 'Kaczka', 'nationality': 'Polish'}
Dutch:{'surname': 'Rooijakkers', 'nationality': 'Dutch'}
Vietnamese:{'surname': 'Vinh', 'nationality': 'Vietnamese'}
Korean:{'surname': 'Ahn', 'nationality': 'Korean'}
Portuguese:{'surname': 'Pereira', 'nationality': 'Portuguese'}
Scottish:{'surname': 'Burns', 'nationality': 'Scottish'}

In [29]:
# Create split data
# Every nationality is split on its own, so each of train/val/test receives
# its share of every label.
total_proportion = args.train_proportion + args.val_proportion + args.test_proportion
if not np.isclose(total_proportion, 1.0):
    # The test split is "whatever is left", so test_proportion only holds
    # when the three fractions add up to one.
    raise ValueError(f"split proportions must sum to 1, got {total_proportion}")

final_list = []
np.random.seed(args.seed)
for nationality in sorted(by_nationality):
    # Shuffle a copy so by_nationality keeps its original order; re-running
    # this cell then reproduces exactly the same split.
    shuffled = list(by_nationality[nationality])
    np.random.shuffle(shuffled)

    n = len(shuffled)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)
    val_end = n_train + n_val

    # Give each data point a split attribute. int() truncates the train and
    # val sizes, so the rounding remainder lands in the test split.
    for position, item in enumerate(shuffled):
        if position < n_train:
            split = 'train'
        elif position < val_end:
            split = 'val'
        else:
            split = 'test'
        # Append a copy so the records held in by_nationality are not mutated.
        final_list.append({**item, 'split': split})

In [30]:
# Assemble the labelled records into a DataFrame (one row per surname).
final_surnames = pd.DataFrame(data=final_list)

In [31]:
# Number of rows assigned to each split.
final_surnames["split"].value_counts()

train    7680
test     1660
val      1640
Name: split, dtype: int64

In [32]:
# Preview the first five rows of the labelled data.
final_surnames.head(n=5)

Unnamed: 0,surname,nationality,split
0,Totah,Arabic,train
1,Abboud,Arabic,train
2,Fakhoury,Arabic,train
3,Srour,Arabic,train
4,Sayegh,Arabic,train


In [33]:
# Save the labelled table to the configured output path, omitting the row index.
final_surnames.to_csv(path_or_buf=args.output_munged_csv, index=False)