## Deep Learning Cohort Fertility Prediction

Trains the non-log DL model with rotating jump-off years (JOY; 1985-2010 in 5-year steps) and converts the period predictions to cohort rates for comparison with the Lee (1993) benchmark.

Aligned with the R Lee-Carter approach:
- Ages 15-44 (matching `age1=15, age2=44`)
- Country filtering: skips countries missing any age in 15-44
- Year gap checks: skips country/JOY combos with gaps in observed years
- Forecast horizon: JOY + 30 years (matching `len=30`)

In [22]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import random

In [23]:
import training_functions
import importlib

# Re-execute the local module so that edits made to training_functions.py
# since the first import take effect without restarting the kernel.
importlib.reload(training_functions)

<module 'training_functions' from '/Users/paigepark/Desktop/repos/deep-fert/code/training_functions.py'>

In [24]:
# Read the two pre-split ASFR tables and stack them row-wise into one array.
# Columns are used downstream as: 0 = country index, 1 = year, 2 = age, 3 = rate.
asfr_training = np.loadtxt('../data/asfr_training.txt')
asfr_test = np.loadtxt('../data/asfr_test.txt')
asfr_all_raw = np.vstack((asfr_training, asfr_test))

# Write the combined, unfiltered table to disk for the R scripts.
np.savetxt('../data/asfr_1950_to_2015.txt', asfr_all_raw)

# Keep only the age window used by R's Lee-Carter code (age1=15, age2=44).
AGE_MIN, AGE_MAX = 15, 44
raw_age_column = asfr_all_raw[:, 2]
age_mask = (AGE_MIN <= raw_age_column) & (raw_age_column <= AGE_MAX)
asfr_all = asfr_all_raw[age_mask, :]

# Summary of what survived the filter.
print(f"Raw: {asfr_all_raw.shape}")
print(f"Filtered (ages {AGE_MIN}-{AGE_MAX}): {asfr_all.shape}")
print(f"Years: {int(asfr_all[:,1].min())}-{int(asfr_all[:,1].max())}")
print(f"Ages: {int(asfr_all[:,2].min())}-{int(asfr_all[:,2].max())}")
# Note: this is the largest country index + 1, not a count of distinct indices.
print(f"Countries: {int(asfr_all[:,0].max()) + 1}")

Raw: (92190, 4)
Filtered (ages 15-44): (65850, 4)
Years: 1950-2015
Ages: 15-44
Countries: 39


In [25]:
# --- Experiment configuration ---
JUMP_OFF_YEARS = [1985, 1990, 1995, 2000, 2005, 2010]
YEAR_MIN = 1950
YEAR_MAX = 2015
FORECAST_LEN = 30  # Match R's len=30
# Multiplier used by the training cell: steps_per_epoch = rows * STEPS_RATIO / BATCH_SIZE.
# NOTE(review): the origin of the value 4.74 is not documented in this notebook - confirm.
STEPS_RATIO = 4.74
BATCH_SIZE = 256
METHOD_NAME = "DL_NonLog"

# Size of the country index space (largest index + 1), the age grid, and all candidate indices.
geo_dim = int(asfr_all[:, 0].max()) + 1
ages = np.arange(AGE_MIN, AGE_MAX + 1)
countries = np.arange(geo_dim)

# Country validation: skip countries missing any age in 15-44 (matching R)
required_ages = set(range(AGE_MIN, AGE_MAX + 1))
valid_countries = []
country_min_year = {}
country_years = {}  # country index -> set of years with at least one observed row

for c in countries:
    c_data = asfr_all[asfr_all[:, 0] == c]
    if len(c_data) == 0:
        # Index not present in the data at all; nothing to validate.
        continue
    c_ages = set(c_data[:, 2].astype(int))
    missing = required_ages - c_ages
    if missing:
        print(f"Skipping Country {int(c)} - Missing ages: {sorted(missing)}")
        continue
    valid_countries.append(int(c))
    country_min_year[int(c)] = int(c_data[:, 1].min())
    # Cache the observed years once here so the JOY loop below does not have to
    # re-filter the whole array for every (JOY, country) pair.
    country_years[int(c)] = set(c_data[:, 1].astype(int))

# For each JOY, check for year gaps per country (matching R)
valid_combos = {}  # joy -> list of valid country indices
for joy in JUMP_OFF_YEARS:
    valid_for_joy = []
    for c in valid_countries:
        min_yr = country_min_year[c]
        # A JOY outside the country's observed window cannot be used as a jump-off.
        if joy > YEAR_MAX or joy < min_yr:
            continue
        # Every year from the country's first observed year through the JOY must be present.
        missing_years = set(range(min_yr, joy + 1)) - country_years[c]
        if missing_years:
            print(f"Skipping Country {c}, JOY {joy} - Missing years: {sorted(missing_years)[:5]}")
            continue
        valid_for_joy.append(c)
    valid_combos[joy] = valid_for_joy
    print(f"JOY {joy}: {len(valid_for_joy)} valid countries")

print(f"\ngeo_dim: {geo_dim}")
print(f"Valid countries: {len(valid_countries)}")
print(f"Jump-off years: {JUMP_OFF_YEARS}")

JOY 1985: 36 valid countries
JOY 1990: 36 valid countries
JOY 1995: 37 valid countries
JOY 2000: 38 valid countries
JOY 2005: 39 valid countries
JOY 2010: 39 valid countries

geo_dim: 39
Valid countries: 39
Jump-off years: [1985, 1990, 1995, 2000, 2005, 2010]


In [26]:
# Per-JOY results, keyed by jump-off year:
#   predicted_rates[joy] -> DataFrame of model-predicted period rates (years joy+1 .. joy+FORECAST_LEN)
#   observed_rates[joy]  -> DataFrame of observed period rates (years <= joy) for the JOY's valid countries
predicted_rates = {}
observed_rates = {}

SEED = 42
MODEL_DIR = "../models"

# Make sure the output folder exists before any training starts, so that
# model.save() cannot fail on a missing directory after a full training run.
os.makedirs(MODEL_DIR, exist_ok=True)

# Loop-invariant: rows belonging to a country that passed the age-completeness check.
in_valid_country = np.isin(asfr_all[:, 0].astype(int), valid_countries)

for joy in JUMP_OFF_YEARS:
    print(f"\n{'='*50}")
    print(f"Jump-off year: {joy}")
    print(f"{'='*50}")

    # 1. Training data: all valid countries, ages 15-44, year <= joy.
    #    This uses every age-complete country, including ones that the year-gap
    #    check excludes from this JOY's forecast grid (valid_combos[joy]).
    train_mask = (asfr_all[:, 1] <= joy) & in_valid_country
    train_data = asfr_all[train_mask]
    print(f"Training rows: {train_data.shape[0]}")

    # 2. Validation data: valid countries, year > joy (up to YEAR_MAX)
    #    NOTE(review): these rows lie inside the forecast period. If run_deep_model
    #    selects or schedules on val_loss, post-JOY observations influence the
    #    fitted model - confirm this is intended for the Lee-Carter comparison.
    val_mask = (asfr_all[:, 1] > joy) & in_valid_country
    val_data = asfr_all[val_mask]
    print(f"Validation rows: {val_data.shape[0]}")

    if train_data.shape[0] == 0:
        # No entry is stored in predicted_rates / observed_rates for this JOY.
        print("No training data, skipping")
        continue

    # 3. Scale steps_per_epoch with the amount of training data
    steps_per_epoch = int(train_data.shape[0] * STEPS_RATIO / BATCH_SIZE)
    print(f"Steps per epoch: {steps_per_epoch}")

    # 4. Set seeds BEFORE building the datasets, so that any randomness inside
    #    prep_data (presumably shuffling - TODO confirm) is seeded as well and
    #    every JOY starts from the same RNG state regardless of execution order.
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    random.seed(SEED)
    # Hash randomization is fixed at interpreter start-up, so this only affects
    # child processes launched from here on, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(SEED)

    # 5. Prep datasets (rates kept on the natural scale, not log)
    train_prepped = training_functions.prep_data(train_data, mode="train", changeratetolog=False)
    val_prepped = training_functions.prep_data(val_data, mode="test", changeratetolog=False)

    # 6. Train non-log model
    model, val_loss = training_functions.run_deep_model(
        train_prepped, val_prepped, geo_dim,
        epochs=50,
        steps_per_epoch=steps_per_epoch,
        lograte=False
    )
    print(f"Best val loss: {val_loss:.6f}")

    # 7. Forecast grid: valid countries for this JOY, ages 15-44, years JOY+1 to JOY+FORECAST_LEN
    forecast_year_max = joy + FORECAST_LEN
    forecast_years = np.arange(joy + 1, forecast_year_max + 1)
    valid_c_list = sorted(valid_combos[joy])
    grid = np.array([(c, y, a) for c in valid_c_list for y in forecast_years for a in ages])
    print(f"Forecast grid: {grid.shape[0]} ({len(valid_c_list)} countries, {len(forecast_years)} years, {len(ages)} ages)")

    # 8. Predict using same normalization as training_functions.py
    #    NOTE(review): the year scaling is duplicated here from the training module
    #    and must stay in sync with it. Forecast years beyond YEAR_MAX map to
    #    values greater than 1.
    forecast_features = (
        tf.convert_to_tensor((grid[:, 1] - YEAR_MIN) / (YEAR_MAX - YEAR_MIN), dtype=tf.float32),
        tf.convert_to_tensor(grid[:, 2], dtype=tf.float32),
        tf.convert_to_tensor(grid[:, 0], dtype=tf.float32),
    )
    preds = model.predict(forecast_features).flatten()

    # 9. Store predicted period rates
    pred_df = pd.DataFrame({
        'Country': grid[:, 0].astype(int),
        'Year': grid[:, 1].astype(int),
        'Age': grid[:, 2].astype(int),
        'Rate': preds,
    })
    predicted_rates[joy] = pred_df

    # 10. Store observed period rates (year <= joy), restricted to this JOY's valid countries
    obs_mask = np.isin(train_data[:, 0].astype(int), valid_c_list)
    obs_data = train_data[obs_mask]
    obs_df = pd.DataFrame({
        'Country': obs_data[:, 0].astype(int),
        'Year': obs_data[:, 1].astype(int),
        'Age': obs_data[:, 2].astype(int),
        'Rate': obs_data[:, 3],
    })
    observed_rates[joy] = obs_df

    # 11. Save model
    model_path = f"{MODEL_DIR}/dl_cohort_nonlog_joy{joy}.keras"
    model.save(model_path)
    print(f"Model saved: {model_path}")

print("\nAll jump-off years complete!")


Jump-off year: 1985
Training rows: 31890
Validation rows: 33960
Steps per epoch: 590
Epoch 1/50
590/590 - 3s - 5ms/step - loss: 0.0031 - val_loss: 9.6536e-04 - learning_rate: 0.0010
Epoch 2/50
590/590 - 2s - 3ms/step - loss: 8.4234e-04 - val_loss: 0.0011 - learning_rate: 0.0010
Epoch 3/50
590/590 - 2s - 3ms/step - loss: 6.3136e-04 - val_loss: 0.0012 - learning_rate: 0.0010
Epoch 4/50

Epoch 4: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
590/590 - 2s - 3ms/step - loss: 4.2410e-04 - val_loss: 0.0011 - learning_rate: 0.0010
Epoch 5/50
590/590 - 1s - 3ms/step - loss: 3.5307e-04 - val_loss: 0.0010 - learning_rate: 2.5000e-04
Epoch 6/50
590/590 - 1s - 3ms/step - loss: 3.2680e-04 - val_loss: 0.0010 - learning_rate: 2.5000e-04
Epoch 7/50

Epoch 7: ReduceLROnPlateau reducing learning rate to 6.25000029685907e-05.
590/590 - 1s - 3ms/step - loss: 3.0278e-04 - val_loss: 9.9931e-04 - learning_rate: 2.5000e-04
Epoch 8/50
590/590 - 1s - 3ms/step - loss: 2.9098e-04 - val_loss: 

In [27]:
def period_to_cohort(period_df):
    """Re-index period rates by birth cohort.

    Returns a new DataFrame with the same rows and columns as ``period_df``,
    except that 'Year' is replaced by ``Year - Age`` (the cohort's birth year).
    The input frame is not modified.
    """
    cohort_birth_year = period_df['Year'] - period_df['Age']
    return period_df.assign(Year=cohort_birth_year)

In [28]:
# Build the forecast cohort table: for each JOY, observed period rates up to the
# JOY are joined with the model's predicted period rates after it, then
# re-indexed by birth cohort.
pred_cohort_dfs = []

for joy in JUMP_OFF_YEARS:
    # The training cell stores nothing for a JOY it skipped (no training rows),
    # so guard the lookups instead of failing with a KeyError.
    if joy not in predicted_rates or joy not in observed_rates:
        print(f"JOY {joy}: no stored rates, skipping")
        continue

    # Observed period rates (year <= JOY); already restricted to this JOY's
    # valid countries when they were stored.
    obs_period = observed_rates[joy]

    # DL-predicted period rates (JOY+1 to JOY+30)
    pred_period = predicted_rates[joy]

    # Combine observed + predicted period rates
    combined_period = pd.concat([obs_period, pred_period], ignore_index=True)

    # Convert to cohort ('Year' becomes the cohort birth year)
    cohort_df = period_to_cohort(combined_period)

    # Add metadata columns
    cohort_df['JumpOffYear'] = joy
    cohort_df['Method'] = METHOD_NAME
    cohort_df['Key'] = cohort_df['Method'] + '_' + cohort_df['Country'].astype(str)

    pred_cohort_dfs.append(cohort_df)

predCASFR = pd.concat(pred_cohort_dfs, ignore_index=True)
print(f"predCASFR shape: {predCASFR.shape}")
print(f"Columns: {list(predCASFR.columns)}")
predCASFR.head()

predCASFR shape: (477210, 7)
Columns: ['Country', 'Year', 'Age', 'Rate', 'JumpOffYear', 'Method', 'Key']


Unnamed: 0,Country,Year,Age,Rate,JumpOffYear,Method,Key
0,0,1936,15,0.0015,1985,DL_NonLog,DL_NonLog_0
1,0,1935,16,0.00846,1985,DL_NonLog,DL_NonLog_0
2,0,1934,17,0.02615,1985,DL_NonLog,DL_NonLog_0
3,0,1933,18,0.05113,1985,DL_NonLog,DL_NonLog_0
4,0,1932,19,0.07838,1985,DL_NonLog,DL_NonLog_0


In [29]:
# Build the observed cohort table: one copy of the full observed data per JOY,
# restricted to that JOY's valid countries and re-indexed by birth cohort.
obs_cohort_dfs = []

# Full observed period data (ages 15-44, already filtered); the first three
# array columns are integer-coded, the fourth holds the rate.
all_obs_df = pd.DataFrame({
    column_name: asfr_all[:, column_idx].astype(int)
    for column_idx, column_name in enumerate(['Country', 'Year', 'Age'])
})
all_obs_df['Rate'] = asfr_all[:, 3]

for joy in JUMP_OFF_YEARS:
    valid_c_set = set(valid_combos[joy])

    # Keep only countries valid for this JOY (matching R's per-country obsCASFR)
    obs_filtered = all_obs_df.loc[all_obs_df['Country'].isin(valid_c_set)]

    # Period -> cohort, then attach the metadata columns in a single chain
    cohort_df = period_to_cohort(obs_filtered).assign(
        JumpOffYear=joy,
        Method=METHOD_NAME,
        Key=lambda frame: frame['Method'] + '_' + frame['Country'].astype(str),
    )

    obs_cohort_dfs.append(cohort_df)

obsCASFR = pd.concat(obs_cohort_dfs, ignore_index=True)
print(f"obsCASFR shape: {obsCASFR.shape}")
print(f"Columns: {list(obsCASFR.columns)}")
obsCASFR.head()

obsCASFR shape: (390420, 7)
Columns: ['Country', 'Year', 'Age', 'Rate', 'JumpOffYear', 'Method', 'Key']


Unnamed: 0,Country,Year,Age,Rate,JumpOffYear,Method,Key
0,0,1936,15,0.0015,1985,DL_NonLog,DL_NonLog_0
1,0,1935,16,0.00846,1985,DL_NonLog,DL_NonLog_0
2,0,1934,17,0.02615,1985,DL_NonLog,DL_NonLog_0
3,0,1933,18,0.05113,1985,DL_NonLog,DL_NonLog_0
4,0,1932,19,0.07838,1985,DL_NonLog,DL_NonLog_0


In [31]:
# Export both cohort tables as CSV (without the pandas index column).
# Each path is defined once so the log message cannot drift from the file written.
FORECAST_CSV = '../data/dl_forecasts_cohort_fewer_ages.csv'
OBSERVED_CSV = '../data/dl_obs_cohort_fewer_ages.csv'

predCASFR.to_csv(FORECAST_CSV, index=False)
obsCASFR.to_csv(OBSERVED_CSV, index=False)

print(f"Saved: {FORECAST_CSV} ({predCASFR.shape[0]} rows)")
print(f"Saved: {OBSERVED_CSV} ({obsCASFR.shape[0]} rows)")

Saved: ../data/dl_forecasts_cohort_fewer_ages.csv (477210 rows)
Saved: ../data/dl_obs_cohort_fewer_ages.csv (390420 rows)
