# Stratified randomization of sessions into testing and validation groups

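To use this, first run `evaluate_data.ipynb` on all the sessions in the batch. Then list the per-trial results CSV file(s) for the batch in `per_trial_csvs` in the first code cell (they are read from `../data/results/`) and run the cells from top to bottom.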

In [None]:
import pandas as pd
import numpy as np

# Seed for the randomization.
#   - Leave as None for a real randomization: a fresh seed is drawn and printed,
#     so the run can be reproduced later by pasting the printed value here.
#   - Set to an int (e.g. 0 for testing) to get the same assignment every time.
random_seed = None
if random_seed is None:
    np.random.seed()    # no argument: re-seed from fresh, unpredictable entropy
    random_seed = int(np.random.randint(0, 2**31 - 1))
np.random.seed(random_seed)
print(f'Random seed: {random_seed}')

# Inputs

# Per-trial result files to randomize; each one is read from ../data/results/
per_trial_csvs = [
    'nimr_Tr1_fe-beyond_frame_wsize-5_winc-5_dynamic-scalar-2.1_batch-1_20250529-092422_per_trial.csv',
    # 'nimr_Tr1_fe-beyond_frame_wsize-5_winc-5_batch_2_20250515-163038_per_trial.csv',
]

# Age strata, in years. The edges are used as half-open intervals [low, high),
# and sit just above the whole-month values so that, for example, an age of
# 0.25 (3 months) still falls in '0-3 mo' and 3.00 (36 months) in '13-36 mo'.
age_bins = [0, 0.26, 0.51, 1.01, 3.01, 5]
age_labels = ['0-3 mo', '4-6 mo', '7-12 mo', '13-36 mo', '37-59 mo']

# Study site of each system, looked up by the first three characters of the
# session id. Written per site here and flattened into a system -> site dict.
nxx_to_location = {
    system: site
    for site, systems in {
        'Lagos': ('N07', 'N19'),
        'eHA': ('N08', 'N18'),
        'Abuja': ('N09', 'N14', 'N17'),
        'NH': ('N10',),
        'Kano': ('N11', 'N12', 'N15', 'N16'),
    }.items()
    for system in systems
}

# Read the results, skipping the second row, which has sub-headings we don't need.
# Every file is read with the same options, whether there is one file or several.
if not per_trial_csvs:
    raise ValueError('per_trial_csvs is empty; list at least one per-trial results file.')

# ignore_index=True gives the combined frame a single 0..n-1 index instead of
# repeating each file's own row numbers, so every row label stays unique.
# The copy() is kept from the original, where it was meant to avoid
# fragmentation warnings later.
df = pd.concat(
    [
        pd.read_csv(f"../data/results/{filename}", skiprows=[1], low_memory=False)
        for filename in per_trial_csvs
    ],
    ignore_index=True,
).copy()

# Keep only the session id and the age; anything that is not a number becomes NaN
df = df[['id', 'age']].assign(
    age=lambda sessions: pd.to_numeric(sessions['age'], errors='coerce'),
)

# Stratify by age bins, in years. Each bin includes its lower edge and excludes
# its upper edge; ages outside the bins (or NaN) get no age group.
df['age group'] = pd.cut(
    df['age'],
    bins=age_bins,
    labels=age_labels,
    right=False,
    include_lowest=True,
)

# The system is the first three characters of the id; look up its location
df['system'] = df['id'].str[:3]
df['location'] = df['system'].map(nxx_to_location)

# Report (but keep) any session whose system has no entry in nxx_to_location
df_no_loc = df[df['location'].isna()]
if not df_no_loc.empty:
    print('Warning: No location (study site) was found for the following sessions:')
    display(df_no_loc)

# Define groups (buckets): ten groups, numbered 1 to 10
categories = [*range(1, 11)]

# Start the cycle at a randomly chosen group, keeping the cyclic order, so that
# group 1 doesn't always receive the first participant
rotation = np.random.randint(0, len(categories))
categories = [
    categories[(rotation + step) % len(categories)]
    for step in range(len(categories))
]

# Exclude any rows without an age group: the age is missing, not a number,
# or outside the range covered by age_bins
df_exceed = df[df['age group'].isna()]
if df_exceed.empty:
    print('All sessions have a valid age between 0 and 5 years.')
else:
    print('Warning: The following sessions were excluded for having an invalid age.')
    df = df[df['age group'].notna()]
    display(df_exceed)

All sessions have a valid age between 0 and 5 years.


In [2]:
# Assign to groups

# Start from a frame without a previous assignment so that this cell can be
# re-run: DataFrame.insert raises ValueError if the column already exists.
# Re-running draws new random numbers and therefore gives a new assignment.
df = df.drop(columns=['assigned group'], errors='ignore')

# Add a dummy column for shuffling within strata
df['_shuffle'] = np.random.rand(len(df))

# Sort by strata + random shuffle: sessions of the same age group and location
# end up next to each other, in random order within the stratum
df = df.sort_values(by=['age group', 'location', '_shuffle'])

# Assign groups cyclically. The cycle runs over the whole sorted frame (it does
# not restart for each stratum), so group sizes differ by at most one session.
df.insert(0, 'assigned group', [categories[i % len(categories)] for i in range(len(df))])

# Cleanup
df = df.drop(columns=['_shuffle'])

df = df.sort_values(by=['assigned group', 'id'])
display(df)

# TODO: Save to .csv

Unnamed: 0,assigned group,id,age,age group,system,location
8,1,N07-025,1.08,13-36 mo,N07,Lagos
33,1,N09-047,4.92,37-59 mo,N09,Abuja
45,1,N09-059,0.25,0-3 mo,N09,Abuja
48,1,N09-062,0.83,7-12 mo,N09,Abuja
53,1,N09-067,3.00,13-36 mo,N09,Abuja
...,...,...,...,...,...,...
59,10,N12-038,4.00,37-59 mo,N12,Kano
60,10,N12-039,2.00,13-36 mo,N12,Kano
80,10,N12-059,4.00,37-59 mo,N12,Kano
83,10,N12-062,0.25,0-3 mo,N12,Kano


In [3]:
# See the distribution: number of sessions per stratum (age group x location)
# in each assigned group, with row and column totals ('All')

# To hide the decimals in the table: pd.set_option('display.precision', 0)
display(
    df.pivot_table(
        values='id',
        index=['age group', 'location'],
        columns=['assigned group'],
        aggfunc='count',
        observed=False,
        margins=True,
    ).fillna(0)
)

Unnamed: 0_level_0,assigned group,1,2,3,4,5,6,7,8,9,10,All
age group,location,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
0-3 mo,Abuja,1,1,1,1,1,0,0,0,0,1,6
0-3 mo,Kano,1,0,0,0,0,1,1,1,1,1,6
0-3 mo,Lagos,0,1,1,0,0,0,0,0,0,0,2
4-6 mo,Abuja,0,0,0,1,1,1,0,0,0,0,3
4-6 mo,Kano,0,0,0,0,0,0,1,0,0,0,1
4-6 mo,Lagos,0,0,0,0,0,0,0,1,0,0,1
7-12 mo,Abuja,1,1,1,0,0,0,0,0,1,1,5
7-12 mo,Kano,0,0,0,1,1,1,1,0,0,0,4
7-12 mo,Lagos,0,0,0,0,0,0,0,1,0,0,1
13-36 mo,Abuja,1,1,1,1,1,1,1,1,2,1,11


In [None]:
# See the sessions assigned to one group (here: group 2)

display(df.loc[df["assigned group"].eq(2)])

Unnamed: 0,assigned group,id,age,age group,system,location
3,2,N07-020,0.0,0-3 mo,N07,Lagos
12,2,N07-029,2.42,13-36 mo,N07,Lagos
21,2,N09-035,3.42,37-59 mo,N09,Abuja
31,2,N09-045,0.58,7-12 mo,N09,Abuja
32,2,N09-046,0.17,0-3 mo,N09,Abuja
44,2,N09-058,1.92,13-36 mo,N09,Abuja
57,2,N12-036,2.33,13-36 mo,N12,Kano
71,2,N12-050,4.33,37-59 mo,N12,Kano
92,2,N12-071,4.42,37-59 mo,N12,Kano
94,2,N12-073,2.0,13-36 mo,N12,Kano


In [4]:
# See the distribution of sessions, without regard to our new group assignments:
# number of sessions per age group (rows) and location (columns), with totals

display(
    df.pivot_table(
        values='id',
        index=['age group'],
        columns=['location'],
        aggfunc='count',
        fill_value=0,
        observed=False,
        margins=True,
    )
)

location,Abuja,Kano,Lagos,All
age group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0-3 mo,6,6,2,14
4-6 mo,3,1,1,5
7-12 mo,5,4,1,10
13-36 mo,11,13,14,38
37-59 mo,10,17,6,33
All,35,41,24,100


In [5]:
# See the distribution of sessions per age group, with the columns broken down
# by location and, within each location, by system

display(
    df.pivot_table(
        values='id',
        index=['age group'],
        columns=['location', 'system'],
        aggfunc='count',
        fill_value=0,
        observed=True,
        margins=True,
    )
)

location,Abuja,Kano,Lagos,Lagos,All
system,N09,N12,N07,N19,Unnamed: 5_level_1
age group,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
0-3 mo,6,6,2,0,14
4-6 mo,3,1,1,0,5
7-12 mo,5,4,1,0,10
13-36 mo,11,13,10,4,38
37-59 mo,10,17,5,1,33
All,35,41,19,5,100
