In [1]:
import pandas as pd
import numpy as np
from sklearn.experimental import enable_iterative_imputer
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.impute import IterativeImputer
from sklearn.model_selection import RandomizedSearchCV

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import shap

from sklearn.linear_model import LinearRegression

import matplotlib.pyplot as plt
import xgboost as xgb

from longitudinal.settings.constants import DATA_PATH

gen1_train = pd.read_csv(DATA_PATH + "gen1_train_comp_final.csv")  # parent data (training)
gen2_train = pd.read_csv(DATA_PATH + "gen2_train_comp_final.csv")  # child data (training)

gen1_test = pd.read_csv(DATA_PATH + "gen1_test_comp_final.csv")    # parent data (test)
gen2_test = pd.read_csv(DATA_PATH + "gen2_test_upto9_comp_final.csv")  # child data, up to age 9

In [2]:
gen1_train.head()

Unnamed: 0,gen1_id,sex_assigned_at_birth,age,SHgt_cm
0,774,F,0.1,56.961812
1,774,F,0.25,64.82619
2,774,F,0.5,74.340764
3,774,F,0.75,79.747338
4,774,F,1.0,84.092569


In [3]:
gen2_train.head()

Unnamed: 0,gen2_id,sex_assigned_at_birth,study_parent_sex,study_parent_id_new,AgeGr,SHgt_cm,Wgt_kg
0,3012,M,mother,636,0.1,56.251625,4.636903
1,3012,M,mother,636,0.25,64.491579,
2,3012,M,mother,636,0.5,70.465927,
3,3012,M,mother,636,0.75,73.992677,
4,3012,M,mother,636,1.0,79.343537,


In [4]:
gen2_test

Unnamed: 0,gen2_id,sex_assigned_at_birth,study_parent_sex,study_parent_id_new,AgeGr,SHgt_cm,Wgt_kg
0,2831,F,mother,455,0.10,52.912025,
1,2831,F,mother,455,0.25,59.532779,
2,2831,F,mother,455,0.50,67.733527,
3,2831,F,mother,455,0.75,70.450677,
4,2831,F,mother,455,1.00,74.991937,
...,...,...,...,...,...,...,...
1227,2332,M,mother,274,5.00,103.870670,19.938306
1228,2332,M,mother,274,6.00,109.936726,22.661975
1229,2332,M,mother,274,7.00,116.386523,25.626890
1230,2332,M,mother,274,8.00,120.052957,26.943480


In [5]:
gen2_train[gen2_train["gen2_id"] == 3012].interpolate(method="polynomial", order=2)

  gen2_train[gen2_train["gen2_id"] == 3012].interpolate(method="polynomial", order=2)


Unnamed: 0,gen2_id,sex_assigned_at_birth,study_parent_sex,study_parent_id_new,AgeGr,SHgt_cm,Wgt_kg
0,3012,M,mother,636,0.1,56.251625,4.636903
1,3012,M,mother,636,0.25,64.491579,6.022203
2,3012,M,mother,636,0.5,70.465927,7.66482
3,3012,M,mother,636,0.75,73.992677,9.564754
4,3012,M,mother,636,1.0,79.343537,11.722006
5,3012,M,mother,636,1.5,86.061664,14.136575
6,3012,M,mother,636,2.0,92.497582,16.808461
7,3012,M,mother,636,3.0,101.012604,19.737665
8,3012,M,mother,636,4.0,111.528387,22.924187
9,3012,M,mother,636,5.0,118.245241,26.368025


In [6]:
gen1_train = gen1_train.rename(columns={"age": "AgeGr"})
gen1_test = gen1_test.rename(columns={"age": "AgeGr"})

gen1_train['sex_assigned_at_birth'] = gen1_train['sex_assigned_at_birth'].map({'M': 1, 'F': 0})
gen1_test['sex_assigned_at_birth'] = gen1_test['sex_assigned_at_birth'].map({'M': 1, 'F': 0})

gen2_train['sex_assigned_at_birth'] = gen2_train['sex_assigned_at_birth'].map({'M': 1, 'F': 0})
gen2_test['sex_assigned_at_birth'] = gen2_test['sex_assigned_at_birth'].map({'M': 1, 'F': 0})

gen2_train['study_parent_sex'] = gen2_train['study_parent_sex'].map({'mother': 1, 'father': 0})
gen2_test['study_parent_sex'] = gen2_test['study_parent_sex'].map({'mother': 1, 'father': 0})

In [7]:
# Define a non-negative estimator
non_negative_estimator = HistGradientBoostingRegressor(loss="poisson", random_state=0)

gen1_imputer = IterativeImputer(estimator=non_negative_estimator, max_iter=10, random_state=0)
gen1_train_imputed = pd.DataFrame(gen1_imputer.fit_transform(gen1_train), columns=gen1_train.columns)
gen1_test_imputed = pd.DataFrame(gen1_imputer.transform(gen1_test), columns=gen1_test.columns)

gen2_imputer = IterativeImputer(estimator=non_negative_estimator, max_iter=10, random_state=0)
gen2_train_imputed = pd.DataFrame(gen2_imputer.fit_transform(gen2_train), columns=gen2_train.columns)
gen2_test_imputed = pd.DataFrame(gen2_imputer.transform(gen2_test), columns=gen2_test.columns)

gen1_train_imputed.head()



Unnamed: 0,gen1_id,sex_assigned_at_birth,AgeGr,SHgt_cm
0,774.0,0.0,0.1,56.961812
1,774.0,0.0,0.25,64.82619
2,774.0,0.0,0.5,74.340764
3,774.0,0.0,0.75,79.747338
4,774.0,0.0,1.0,84.092569


In [8]:
gen2_train_imputed.head()

Unnamed: 0,gen2_id,sex_assigned_at_birth,study_parent_sex,study_parent_id_new,AgeGr,SHgt_cm,Wgt_kg
0,3012.0,1.0,1.0,636.0,0.1,56.251625,4.636903
1,3012.0,1.0,1.0,636.0,0.25,64.491579,16.11514
2,3012.0,1.0,1.0,636.0,0.5,70.465927,16.11514
3,3012.0,1.0,1.0,636.0,0.75,73.992677,16.11514
4,3012.0,1.0,1.0,636.0,1.0,79.343537,16.11514


In [9]:
gen2_train_imputed[gen2_train_imputed["AgeGr"] > 9.0].groupby(["gen2_id"]).agg({"AgeGr": "count"}).reset_index()["AgeGr"].unique()

array([8])

In [10]:
y = gen2_train_imputed[gen2_train_imputed["AgeGr"] > 9.0]
gen2_train_imputed_vals = gen2_train_imputed[gen2_train_imputed["AgeGr"] <= 9.0]
gen2_train_imputed_vals.head()

Unnamed: 0,gen2_id,sex_assigned_at_birth,study_parent_sex,study_parent_id_new,AgeGr,SHgt_cm,Wgt_kg
0,3012.0,1.0,1.0,636.0,0.1,56.251625,4.636903
1,3012.0,1.0,1.0,636.0,0.25,64.491579,16.11514
2,3012.0,1.0,1.0,636.0,0.5,70.465927,16.11514
3,3012.0,1.0,1.0,636.0,0.75,73.992677,16.11514
4,3012.0,1.0,1.0,636.0,1.0,79.343537,16.11514


In [11]:
gen2_pivot = gen2_train_imputed_vals.pivot(index = "gen2_id", columns=["AgeGr"], values=["SHgt_cm", "Wgt_kg"]).reset_index()

In [12]:
{gen2_id: gender for gen2_id, gender in zip(gen2_train_imputed_vals.gen2_id.unique(), gen2_train_imputed_vals.sex_assigned_at_birth.unique())}

{np.float64(3012.0): np.float64(1.0), np.float64(2830.0): np.float64(0.0)}

In [13]:
gen2_pivot = gen2_train_imputed_vals.pivot(index = "gen2_id", columns=["AgeGr"], values=["SHgt_cm", "Wgt_kg"]).reset_index()

temp = gen2_train_imputed_vals[["gen2_id", "sex_assigned_at_birth", "study_parent_sex", "study_parent_id_new"]].drop_duplicates()

gen2_pivot["sex_assigned_at_birth"] = gen2_pivot["gen2_id"].map({
    gen2_id: gender for gen2_id, gender in zip(temp.gen2_id, temp.sex_assigned_at_birth)
})
gen2_pivot["study_parent_sex"] = gen2_pivot["gen2_id"].map({
    gen2_id: gender for gen2_id, gender in zip(temp.gen2_id, temp.study_parent_sex)
})
gen2_pivot["study_parent_id_new"] = gen2_pivot["gen2_id"].map({
    gen2_id: id for gen2_id, id in zip(temp.gen2_id, temp.study_parent_id_new)
})

gen2_pivot.head()

Unnamed: 0_level_0,gen2_id,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,SHgt_cm,...,Wgt_kg,Wgt_kg,Wgt_kg,Wgt_kg,Wgt_kg,Wgt_kg,Wgt_kg,sex_assigned_at_birth,study_parent_sex,study_parent_id_new
AgeGr,Unnamed: 1_level_1,0.1,0.25,0.5,0.75,1.0,1.5,2.0,3.0,4.0,...,3.0,4.0,5.0,6.0,7.0,8.0,9.0,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,1332.0,54.098866,59.128432,67.256015,72.511388,74.88631,85.617761,91.007144,99.977795,109.291388,...,16.53866,17.585813,18.839206,20.709612,27.404934,28.566989,35.309511,0.0,0.0,724.0
1,2330.0,55.798251,61.940215,66.362172,69.582869,75.85507,80.153012,87.153373,93.533053,103.741448,...,16.418458,16.418458,16.204488,17.52323,19.673407,22.308756,25.653254,0.0,1.0,262.0
2,2331.0,57.920127,63.560415,70.718045,74.80765,79.118851,84.527483,89.661203,98.039558,103.095061,...,16.410179,16.410179,18.465756,20.533493,25.11748,27.049411,30.743697,1.0,1.0,274.0
3,2505.0,58.894508,62.822452,68.207623,72.622024,78.192172,85.216314,91.741224,100.252722,105.917235,...,16.914932,16.914932,20.744253,22.510856,27.159028,27.762489,33.13732,0.0,0.0,668.0
4,2507.0,54.436597,61.947728,68.443083,72.281629,74.854322,83.284637,90.811808,101.297097,110.018713,...,15.888589,17.201003,19.710666,22.20888,27.242827,30.970568,31.616559,0.0,0.0,350.0


In [14]:
gen1_pivot = gen1_train_imputed.pivot(index = "gen1_id", columns="AgeGr", values="SHgt_cm").reset_index()
gen1_pivot.head()

AgeGr,gen1_id,0.1,0.25,0.5,0.75,1.0,1.5,2.0,3.0,4.0,...,15.5,16.0,16.5,17.0,17.5,18.0,18.5,19.0,19.5,20.0
0,370.0,56.040541,62.560859,70.33854,76.673608,78.063317,86.018694,92.239281,101.812521,110.926569,...,172.871774,173.087439,173.093158,173.680642,173.717748,174.034011,173.378159,174.233099,173.395964,174.418909
1,371.0,57.418376,63.772686,69.596115,73.796708,76.127006,83.112889,88.70752,99.160057,106.209678,...,171.450506,172.261825,172.379268,172.646257,172.961566,172.20089,172.761957,172.075126,172.902702,173.100697
2,375.0,55.15319,63.371686,70.434241,74.452849,78.720473,84.513188,90.652232,105.983143,111.769232,...,181.163703,182.084762,183.101623,184.756931,184.706826,185.558144,185.759924,185.86422,185.278418,185.922876
3,376.0,54.528201,62.132331,68.23877,72.414316,75.119201,82.596449,88.046795,95.624714,102.118513,...,170.828557,173.213424,174.765932,176.794776,178.219085,177.999918,178.414874,179.618422,179.373829,180.724155
4,377.0,58.371971,65.624997,73.877182,78.443835,83.473532,89.733324,97.713124,107.311409,113.037642,...,189.379195,191.230623,191.842735,193.767378,194.11339,194.778597,194.873554,194.531218,193.187627,194.928493


In [15]:
# Flatten MultiIndex columns
gen2_pivot.columns = ['_'.join(map(str, col)).strip() if isinstance(col, tuple) else col for col in gen2_pivot.columns]

# Reset index if needed
gen2_pivot = gen2_pivot.reset_index(drop=True)
gen2_pivot

Unnamed: 0,gen2_id_,SHgt_cm_0.1,SHgt_cm_0.25,SHgt_cm_0.5,SHgt_cm_0.75,SHgt_cm_1.0,SHgt_cm_1.5,SHgt_cm_2.0,SHgt_cm_3.0,SHgt_cm_4.0,...,Wgt_kg_3.0,Wgt_kg_4.0,Wgt_kg_5.0,Wgt_kg_6.0,Wgt_kg_7.0,Wgt_kg_8.0,Wgt_kg_9.0,sex_assigned_at_birth_,study_parent_sex_,study_parent_id_new_
0,1332.0,54.098866,59.128432,67.256015,72.511388,74.886310,85.617761,91.007144,99.977795,109.291388,...,16.538660,17.585813,18.839206,20.709612,27.404934,28.566989,35.309511,0.0,0.0,724.0
1,2330.0,55.798251,61.940215,66.362172,69.582869,75.855070,80.153012,87.153373,93.533053,103.741448,...,16.418458,16.418458,16.204488,17.523230,19.673407,22.308756,25.653254,0.0,1.0,262.0
2,2331.0,57.920127,63.560415,70.718045,74.807650,79.118851,84.527483,89.661203,98.039558,103.095061,...,16.410179,16.410179,18.465756,20.533493,25.117480,27.049411,30.743697,1.0,1.0,274.0
3,2505.0,58.894508,62.822452,68.207623,72.622024,78.192172,85.216314,91.741224,100.252722,105.917235,...,16.914932,16.914932,20.744253,22.510856,27.159028,27.762489,33.137320,0.0,0.0,668.0
4,2507.0,54.436597,61.947728,68.443083,72.281629,74.854322,83.284637,90.811808,101.297097,110.018713,...,15.888589,17.201003,19.710666,22.208880,27.242827,30.970568,31.616559,0.0,0.0,350.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,2825.0,54.033054,63.544369,71.610702,77.332729,81.406418,88.032766,93.448435,103.088308,109.620688,...,15.283372,16.639492,20.289783,22.426626,25.919419,28.366638,27.774494,1.0,1.0,570.0
188,2827.0,54.059864,62.715721,69.375560,73.427055,76.093888,82.336733,86.024939,95.646433,103.449957,...,16.849228,16.849228,17.741160,21.173636,23.658468,27.754112,30.471963,1.0,1.0,744.0
189,2829.0,50.388464,58.598732,66.118497,71.914895,75.420543,82.315794,87.624148,98.262270,104.978502,...,15.454158,15.454158,17.981339,20.101313,22.220640,24.507504,27.064617,0.0,0.0,662.0
190,2830.0,55.222683,59.726152,66.428978,71.006008,76.178826,83.018114,87.720740,97.476010,107.817763,...,16.693587,16.693587,20.813682,27.477743,30.263063,32.425335,36.964074,0.0,1.0,712.0


In [16]:
gen2_train_merged = gen2_pivot.merge(
    gen1_pivot,
    left_on='study_parent_id_new_', 
    right_on='gen1_id',            
    suffixes=('_child', '_parent')
)
gen2_train_merged.head()

Unnamed: 0,gen2_id_,SHgt_cm_0.1,SHgt_cm_0.25,SHgt_cm_0.5,SHgt_cm_0.75,SHgt_cm_1.0,SHgt_cm_1.5,SHgt_cm_2.0,SHgt_cm_3.0,SHgt_cm_4.0,...,15.5,16.0,16.5,17.0,17.5,18.0,18.5,19.0,19.5,20.0
0,2509.0,58.340287,64.892308,69.87229,74.645393,77.889181,85.099498,90.794792,100.800558,109.603019,...,184.158103,185.352121,186.178159,186.566052,187.010304,187.320811,185.662955,187.292054,185.152801,188.107978
1,2510.0,60.23063,66.481798,72.892351,75.839962,79.379582,85.720324,94.149611,102.482859,111.842058,...,192.47303,194.305717,194.453334,194.755102,194.676736,194.970999,193.632844,195.213834,193.705696,195.714755
2,2513.0,52.931457,60.6703,67.65772,72.429313,77.864908,84.589526,88.423925,96.628471,103.67685,...,172.912806,175.870532,179.276514,180.730152,182.14176,182.991174,183.055302,184.77805,184.384969,185.051548
3,2514.0,59.747495,67.777098,73.293031,78.117079,81.205328,89.44596,96.569437,105.724538,114.28734,...,173.165838,174.747713,179.356033,182.232486,186.96584,188.761179,189.732545,191.797827,192.42812,193.183434
4,2515.0,57.874021,66.271674,71.975038,74.358996,81.091844,87.847264,92.586788,106.509702,113.283472,...,180.230267,184.040152,186.571063,187.859403,189.477404,189.772037,189.939614,190.50281,191.041194,191.686232


In [17]:
gen1_train_imputed.shape, gen2_train_imputed.shape, gen2_train_merged.shape

((3636, 4), (4224, 7), (151, 69))

In [18]:
y = y.pivot(index="gen2_id", columns="AgeGr", values="SHgt_cm")

In [19]:
y = y.reset_index()
y

AgeGr,gen2_id,10.0,11.0,12.0,13.0,14.0,15.0,16.0,18.0
0,1332.0,145.262229,150.353801,160.222405,164.272504,168.277689,168.095242,168.110161,167.920992
1,2330.0,131.799105,147.909535,159.677325,166.801102,169.204540,168.833029,168.950776,168.892500
2,2331.0,141.524124,150.015631,161.792710,171.815312,175.473667,181.793574,182.969942,183.312296
3,2505.0,139.257516,144.366273,150.663565,163.729845,167.539623,168.214010,168.072106,167.654046
4,2507.0,150.337479,156.747461,164.874826,166.797806,168.598731,169.725980,169.760017,171.565722
...,...,...,...,...,...,...,...,...,...
187,2825.0,143.574362,150.936887,157.004982,165.647166,174.381178,181.339468,182.456204,184.051217
188,2827.0,137.226438,147.350383,151.265462,155.343264,165.374217,170.255790,172.711099,174.151818
189,2829.0,141.468590,150.291349,160.823434,164.816989,166.242011,167.438765,168.077352,168.980405
190,2830.0,145.422500,152.627457,156.193087,158.433856,159.909045,159.748067,159.855986,163.199817


In [20]:
X = gen2_train_merged.sort_values(by = ["gen2_id_"])
y = y[y["gen2_id"].isin(X["gen2_id_"])].sort_values(by = ["gen2_id"]).drop(columns=["gen2_id"])
X = X.drop(columns=["study_parent_id_new_", "gen2_id_"])

In [21]:
y

AgeGr,10.0,11.0,12.0,13.0,14.0,15.0,16.0,18.0
5,147.289220,155.924073,164.750772,170.208080,171.386509,172.078449,171.959384,171.900070
6,151.265047,155.998662,161.689019,167.173962,176.680109,183.780534,189.187459,189.128872
8,137.312339,143.764537,147.231769,152.734873,161.902362,167.298785,166.513565,166.456130
9,147.781505,154.520668,157.834000,162.072648,168.498437,181.159295,180.499212,186.264141
10,150.670882,157.367574,163.940360,168.924856,171.338386,173.033062,172.445247,172.804436
...,...,...,...,...,...,...,...,...
186,143.097049,148.590110,156.629707,160.636276,163.913734,164.856109,166.266527,166.083892
187,143.574362,150.936887,157.004982,165.647166,174.381178,181.339468,182.456204,184.051217
189,141.468590,150.291349,160.823434,164.816989,166.242011,167.438765,168.077352,168.980405
190,145.422500,152.627457,156.193087,158.433856,159.909045,159.748067,159.855986,163.199817


In [22]:
X

Unnamed: 0,SHgt_cm_0.1,SHgt_cm_0.25,SHgt_cm_0.5,SHgt_cm_0.75,SHgt_cm_1.0,SHgt_cm_1.5,SHgt_cm_2.0,SHgt_cm_3.0,SHgt_cm_4.0,SHgt_cm_5.0,...,15.5,16.0,16.5,17.0,17.5,18.0,18.5,19.0,19.5,20.0
0,58.340287,64.892308,69.872290,74.645393,77.889181,85.099498,90.794792,100.800558,109.603019,114.060892,...,184.158103,185.352121,186.178159,186.566052,187.010304,187.320811,185.662955,187.292054,185.152801,188.107978
1,60.230630,66.481798,72.892351,75.839962,79.379582,85.720324,94.149611,102.482859,111.842058,117.804483,...,192.473030,194.305717,194.453334,194.755102,194.676736,194.970999,193.632844,195.213834,193.705696,195.714755
2,52.931457,60.670300,67.657720,72.429313,77.864908,84.589526,88.423925,96.628471,103.676850,108.883130,...,172.912806,175.870532,179.276514,180.730152,182.141760,182.991174,183.055302,184.778050,184.384969,185.051548
3,59.747495,67.777098,73.293031,78.117079,81.205328,89.445960,96.569437,105.724538,114.287340,120.079372,...,173.165838,174.747713,179.356033,182.232486,186.965840,188.761179,189.732545,191.797827,192.428120,193.183434
4,57.874021,66.271674,71.975038,74.358996,81.091844,87.847264,92.586788,106.509702,113.283472,119.793981,...,180.230267,184.040152,186.571063,187.859403,189.477404,189.772037,189.939614,190.502810,191.041194,191.686232
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
146,53.989574,60.674943,67.938858,73.079333,77.316115,84.708157,90.235990,100.565127,107.803022,115.342704,...,165.949538,166.490693,165.969672,166.596265,165.978469,165.712481,165.069240,165.950971,164.923126,166.019612
147,54.033054,63.544369,71.610702,77.332729,81.406418,88.032766,93.448435,103.088308,109.620688,117.115971,...,169.353936,169.022755,169.543562,169.533395,169.085862,169.085862,169.363957,169.941494,169.161917,169.584448
148,50.388464,58.598732,66.118497,71.914895,75.420543,82.315794,87.624148,98.262270,104.978502,110.794429,...,174.413205,176.049775,177.211263,176.959935,179.080076,177.561546,179.189076,177.082916,179.037537,177.526634
149,55.222683,59.726152,66.428978,71.006008,76.178826,83.018114,87.720740,97.476010,107.817763,112.990150,...,165.538833,166.333153,165.768158,166.031561,166.115044,165.811419,166.283725,166.097378,166.239095,166.062266


In [23]:
from xgboost import XGBRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split

# Assuming X is your input features and Y contains 9 future height values per sample
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Wrap XGBoost with MultiOutputRegressor
model = MultiOutputRegressor(XGBRegressor(objective="reg:squarederror", n_estimators=200))

# Train
model.fit(X_train, y_train)

# Predict
y_pred = model.predict(X_test)

In [24]:
# Compute errors across all outputs
mae = mean_absolute_error(y_test, y_pred, multioutput='uniform_average')  # MAE across all outputs
mse = mean_squared_error(y_test, y_pred, multioutput='uniform_average')   # MSE across all outputs
rmse = np.sqrt(mse)  # RMSE
r2 = r2_score(y_test, y_pred, multioutput='uniform_average')  # R² score

# Print results
print(f"Mean Absolute Error (MAE): {mae:.4f}")
print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
print(f"R² Score: {r2:.4f}")

Mean Absolute Error (MAE): 2.9268
Root Mean Squared Error (RMSE): 3.9066
R² Score: 0.4396
