In [1]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace

In [2]:
# Notebook configuration, gathered in a single namespace.
args = Namespace(
    # location of the raw CSV and of the labelled CSV produced by this notebook
    raw_data="../data/surnames/surnames.csv",
    processed_data="../data/surnames/surnames_processed_data.csv",
    # fractions of each nationality's records intended for train / val / test
    train_prop=0.7,
    val_prop=0.15,
    test_prop=0.15,
    # seed handed to NumPy's global RNG before shuffling
    seed=1337,
)

In [3]:
# Read the raw surname table; row 0 of the CSV supplies the column names.
surnames = pd.read_csv(filepath_or_buffer=args.raw_data, header=0)

In [4]:
# Preview the first five rows of the raw table.
surnames.head(n=5)

Unnamed: 0,surname,nationality
0,Woodford,English
1,Coté,French
2,Kore,English
3,Koury,Arabic
4,Lebzak,Russian


In [5]:
# Distinct values of the nationality column, i.e. the set of class labels.
set(surnames["nationality"])

{'Arabic',
 'Chinese',
 'Czech',
 'Dutch',
 'English',
 'French',
 'German',
 'Greek',
 'Irish',
 'Italian',
 'Japanese',
 'Korean',
 'Polish',
 'Portuguese',
 'Russian',
 'Scottish',
 'Spanish',
 'Vietnamese'}

In [6]:
# Group the rows by nationality: maps each nationality to a list of per-row
# dicts (column name -> value), kept in the order the rows appear in `surnames`.
# The rows are converted to dicts in a single to_dict call rather than with
# iterrows(), which builds a Series per row and is slow on larger tables.
by_nationality = collections.defaultdict(list)
for record in surnames.to_dict(orient="records"):
    by_nationality[record["nationality"]].append(record)

In [7]:
# Display one record: the first row dict stored under the 'Arabic' key.
by_nationality['Arabic'][0]

{'surname': 'Koury', 'nationality': 'Arabic'}

In [8]:
# Build the labelled record list: within each nationality, shuffle the records
# and tag every one with a 'split' value, so each nationality is divided in the
# same train/val/test proportions.
final_list = []
np.random.seed(args.seed)  # fixed seed -> the same assignment on every run
# Visit nationalities in sorted order so the sequence of RNG draws is stable.
for _, item_list in sorted(by_nationality.items()):
    # Shuffle a copy. Shuffling by_nationality's own lists in place would make
    # a second run of this cell start from already-shuffled data and yield a
    # different split despite the fixed seed.
    shuffled = list(item_list)
    np.random.shuffle(shuffled)

    n = len(shuffled)
    n_train = int(args.train_prop * n)
    n_val = int(args.val_prop * n)
    val_end = n_train + n_val
    # There is no explicit test count: the test split is whatever remains after
    # the train and val slices, so it also absorbs the records lost to int()
    # truncation. args.test_prop is therefore implied rather than applied.
    labelled_slices = (
        ('train', shuffled[:n_train]),
        ('val', shuffled[n_train:val_end]),
        ('test', shuffled[val_end:]),
    )

    # Emit new dicts carrying the split label instead of adding the key to the
    # dicts held by by_nationality, leaving the grouped input untouched.
    for split_name, chunk in labelled_slices:
        final_list.extend(dict(item, split=split_name) for item in chunk)

In [9]:
# Display the first labelled record to check that the 'split' key is present.
final_list[0]

{'surname': 'Totah', 'nationality': 'Arabic', 'split': 'train'}

In [10]:
# Assemble the labelled records (one dict per row) into a DataFrame.
processed_surnames_data = pd.DataFrame(data=final_list)

In [11]:
# Number of rows assigned to each split.
processed_surnames_data["split"].value_counts()

train    7680
test     1660
val      1640
Name: split, dtype: int64

In [12]:
# Preview the first five rows of the labelled table.
processed_surnames_data.head(n=5)

Unnamed: 0,surname,nationality,split
0,Totah,Arabic,train
1,Abboud,Arabic,train
2,Fakhoury,Arabic,train
3,Srour,Arabic,train
4,Sayegh,Arabic,train


In [13]:
# Write the labelled table to disk; index=False keeps the row index out of the CSV.
processed_surnames_data.to_csv(path_or_buf=args.processed_data, index=False)