In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

In [2]:
# Load the per-image segmentation measurements.
measurements = pd.read_csv('../data/pop_scale/pop_scale_segmentations_measurements/measurements.csv')

# Derive a "genotype" column from image_name: the genotype is the third
# '-'-separated field; names with fewer than three fields are tagged 'na'.
name_fields = [image_name.split('-') for image_name in measurements["image_name"]]
measurements["image_name"] = [
    fields[2] if len(fields) >= 3 else 'na'
    for fields in name_fields
]
measurements = measurements.rename(columns={"image_name": "genotype"})

# Discard the rows whose name could not be parsed.
has_genotype = measurements["genotype"] != 'na'
measurements = measurements[has_genotype]
measurements.head()

Unnamed: 0,genotype,seed_count,wing_area,env_area,seed_area,wing_perimeter,env_perimeter,seed_perimeter,wing_to_total_area,env_to_total_area,...,env_B,seed_r,seed_g,seed_b,seed_h,seed_s,seed_v,seed_l,seed_a,seed_B
0,2022.W_0.png,10.0,1.093654,0.330268,0.161344,1123.685425,643.026479,1161.643794,0.689887,0.208336,...,129.462098,36.445407,33.678663,34.873153,112.331334,23.542486,36.819282,34.121071,129.593691,127.578252
1,2022.W_1.png,8.0,0.964245,0.337849,0.136344,1074.271211,622.884343,919.678282,0.670342,0.234872,...,130.52753,35.791535,33.298239,33.369479,72.296793,24.167192,36.068743,33.487119,129.200841,128.339642
2,2022.W_2.png,9.0,1.04222,0.320465,0.155053,1095.744299,626.256926,1033.259018,0.686693,0.211147,...,130.071972,36.896209,34.795423,35.691516,94.988904,20.732779,37.416089,35.295192,129.201687,127.66528
3,2022.W_3.png,10.0,1.082148,0.363458,0.158028,1148.472222,662.683333,1086.229581,0.67481,0.226646,...,128.426014,37.857337,35.582445,37.32513,124.449762,21.882626,38.629168,36.34645,129.463484,127.156725
4,2022.W_4.png,10.0,1.046019,0.336182,0.162276,1105.38391,648.825469,1132.957503,0.677264,0.217667,...,128.720987,36.846383,33.564329,35.227057,122.581005,26.9619,37.334953,34.189619,129.95185,127.382551


In [3]:
# Average every numeric measurement per genotype, then turn the genotype
# group key back into an ordinary column.
avg_measurements = (
    measurements
    .groupby("genotype")
    .mean(numeric_only=True)
    .reset_index()
)
avg_measurements.head()

Unnamed: 0,genotype,seed_count,wing_area,env_area,seed_area,wing_perimeter,env_perimeter,seed_perimeter,wing_to_total_area,env_to_total_area,...,env_B,seed_r,seed_g,seed_b,seed_h,seed_s,seed_v,seed_l,seed_a,seed_B
0,15_06,10.352941,0.840818,0.38975,0.151808,1091.562955,687.59344,1124.124854,0.607532,0.282279,...,135.995502,35.031684,35.689299,28.337407,35.72181,59.424325,37.249197,35.105414,126.071655,132.832866
1,19_1,10.066667,1.114924,0.438166,0.167156,1219.73674,714.119076,1147.218758,0.64787,0.254641,...,135.213489,29.962345,32.515918,28.051981,50.108914,39.295631,32.84259,30.925064,125.77342,130.649978
2,2022.W copy_0.png,8.0,0.970769,0.336541,0.118512,1081.702669,619.854906,845.050865,0.680849,0.236033,...,135.978966,29.007258,37.455467,28.191744,58.651898,66.210041,37.455467,35.293966,121.828671,132.976259
3,2022.W copy_1.png,8.0,0.934729,0.319139,0.11077,1054.18795,598.399062,820.90873,0.684965,0.233863,...,134.833558,33.005824,42.349458,34.392008,64.56156,58.872836,42.349458,41.105485,121.909885,131.858275
4,2022.W copy_2.png,9.0,0.812346,0.287365,0.112831,985.134126,584.156421,879.820418,0.669953,0.236994,...,136.003056,32.080686,41.284784,31.495235,58.840375,63.087992,41.284784,39.647236,121.52033,133.094187


In [4]:
# Load the external yield phenotypes.
yield_data = pd.read_csv('../data/external/external_phenotypes.csv')

# Per-genotype mean of the numeric columns, reduced to yield_kg only,
# with genotype restored as a regular column.
avg_yield_data = (
    yield_data
    .groupby("genotype")
    .mean(numeric_only=True)[["yield_kg"]]
    .reset_index()
)
avg_yield_data.head()

Unnamed: 0,genotype,yield_kg
0,15_06,0.664974
1,18_01,0.641097
2,19_1,0.374994
3,7_10_#134,0.609996
4,950011,0.483615


In [5]:
# Inner join on genotype: only genotypes present in both the yield table
# and the measurement table are kept.
total_pheno_avg = avg_yield_data.merge(avg_measurements, on="genotype", how="inner")
print(len(total_pheno_avg))

# Index the combined table by genotype.
total_pheno_avg = total_pheno_avg.set_index("genotype")
total_pheno_avg.head()

213


Unnamed: 0_level_0,yield_kg,seed_count,wing_area,env_area,seed_area,wing_perimeter,env_perimeter,seed_perimeter,wing_to_total_area,env_to_total_area,...,env_B,seed_r,seed_g,seed_b,seed_h,seed_s,seed_v,seed_l,seed_a,seed_B
genotype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
15_06,0.664974,10.352941,0.840818,0.38975,0.151808,1091.562955,687.59344,1124.124854,0.607532,0.282279,...,135.995502,35.031684,35.689299,28.337407,35.72181,59.424325,37.249197,35.105414,126.071655,132.832866
19_1,0.374994,10.066667,1.114924,0.438166,0.167156,1219.73674,714.119076,1147.218758,0.64787,0.254641,...,135.213489,29.962345,32.515918,28.051981,50.108914,39.295631,32.84259,30.925064,125.77342,130.649978
950011,0.483615,10.266667,0.798652,0.382284,0.142907,1082.08921,676.844385,1087.752962,0.604148,0.287792,...,133.988947,32.416516,33.561788,29.94033,50.742347,36.879412,34.65014,32.623727,126.662987,130.278667
950016,0.541404,10.0,0.732712,0.316003,0.135651,1022.715733,614.708032,1014.212986,0.618639,0.266817,...,133.64552,29.289101,29.866962,28.133646,53.64944,23.774866,30.376187,28.318812,127.326269,129.130386
950021,0.343528,10.333333,0.87891,0.316569,0.14804,1101.089756,631.41619,1072.837767,0.654103,0.235886,...,134.480838,31.186241,30.211567,28.227778,39.772896,28.011206,31.57352,29.126821,128.034366,129.535674


In [11]:
# Convert yield from kilograms to grams (x1000) and drop the kg column.
# The guard makes the cell safe to re-run: once yield_kg has been replaced
# by yield_g, running the cell again is a no-op instead of a KeyError.
if "yield_kg" in total_pheno_avg.columns:
    total_pheno_avg = (
        total_pheno_avg
        .assign(yield_g=total_pheno_avg["yield_kg"] * 1000)
        .drop(columns=["yield_kg"])
    )

In [46]:
# NOTE(review): this scaling step is disabled -- every line below is commented
# out, so running the cell produces no output. Any table still displayed under
# it (one that shows a yield_kg column) presumably comes from an earlier run,
# made before yield_kg was replaced by yield_g; clear that stale output, or
# delete the cell if scaling is not going to be used.
# # normalize the data
# from sklearn.preprocessing import StandardScaler
# scaler = StandardScaler()
# scaler.fit(total_pheno_avg)
# total_pheno_avg_scaled = scaler.transform(total_pheno_avg)
# total_pheno_avg_scaled = pd.DataFrame(total_pheno_avg_scaled, columns=total_pheno_avg.columns, index=total_pheno_avg.index)
# total_pheno_avg_scaled.head()

Unnamed: 0_level_0,yield_kg,seed_count,wing_area,env_area,seed_area,wing_perimeter,env_perimeter,seed_perimeter,wing_to_total_area,env_to_total_area,...,env_B,seed_r,seed_g,seed_b,seed_h,seed_s,seed_v,seed_l,seed_a,seed_B
genotype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
15_06,0.079992,0.477866,-0.136294,0.736954,0.652534,0.02575,0.644561,0.719894,-1.218017,1.213487,...,1.340919,1.893548,0.969179,0.152861,-1.614092,1.455825,1.356347,1.161173,-0.077904,1.489663
19_1,-1.20487,0.144442,1.833864,1.605359,1.641253,1.701122,1.224758,0.954218,0.546797,-0.20538,...,0.628265,-0.298927,0.026743,-0.09455,-0.10162,-0.100477,-0.120549,-0.018978,-0.246954,0.028414
950011,-0.723585,0.377382,-0.439371,0.603036,0.07911,-0.098082,0.409446,0.350843,-1.366045,1.496501,...,-0.487669,0.762495,0.337347,1.542295,-0.035029,-0.287293,0.48526,0.46057,0.257283,-0.220145
950016,-0.46753,0.066795,-0.913322,-0.585806,-0.388294,-0.874159,-0.949667,-0.395337,-0.732052,0.419724,...,-0.800636,-0.590103,-0.759949,-0.023762,0.270585,-1.300504,-0.947176,-0.754747,0.633254,-0.988817
950021,-1.344291,0.455028,0.137491,-0.575649,0.409825,0.150276,-0.584208,0.199505,0.819472,-1.168148,...,-0.039405,0.230405,-0.657608,0.057833,-1.188214,-0.972961,-0.545884,-0.526638,1.034628,-0.717512


In [12]:
# Pairwise correlations between all columns; the yield_g column of this
# matrix holds each feature's correlation with yield (in grams).
corrs = total_pheno_avg.corr()

# Remove the target's correlation with itself by label rather than by
# assuming it sorts into the first position (a tie at 1.0 or a NaN
# self-correlation would break a positional [1:11] slice).
yield_corrs = corrs["yield_g"].drop("yield_g")

print("Top Correlations:")
print(yield_corrs.sort_values(ascending=False).head(10))

Top Correlations:
seed_to_env_area           0.082369
wing_to_total_perimeter    0.056347
env_a                      0.053173
wing_to_seed_perimeter     0.052769
seed_to_total_area         0.052388
seed_h                     0.049383
wing_b                     0.047669
wing_to_env_area           0.043743
wing_a                     0.042077
seed_a                     0.041274
Name: yield_g, dtype: float64


In [13]:
# The ten lowest (most negative) correlations with yield_g, ascending.
bottom_corrs = corrs["yield_g"].sort_values().head(10)
print("Bottom Correlations:")
print(bottom_corrs)

Bottom Correlations:
seed_count         -0.115457
env_area           -0.110886
env_B              -0.102648
env_perimeter      -0.096926
env_s              -0.096797
seed_perimeter     -0.095298
seed_area          -0.084809
env_to_seed_area   -0.083180
wing_area          -0.080501
seed_s             -0.078685
Name: yield_g, dtype: float64


In [15]:
# Prepare Data: x is the feature table, shape (n_samples, n_features);
# y is the yield_g column as a 1-D Series, shape (n_samples,).
x = total_pheno_avg.drop(columns=["yield_g"])
y = total_pheno_avg["yield_g"]

# Train Test Split: hold out 20% of the rows (one row per genotype);
# the fixed random_state keeps the split identical between runs.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

In [16]:
# train decision tree
# options
max_depth = 5
# Seed for the tree itself. scikit-learn permutes the candidate features at
# every split, so without a fixed seed two runs can grow different trees
# (and report different scores) from the same training data.
random_state = 42

# train - look up sklearn documentation for DecisionTreeRegressor
dt = DecisionTreeRegressor(max_depth=max_depth, random_state=random_state)
dt.fit(x_train, y_train)

# predict on test set
y_pred = dt.predict(x_test)

# evaluate with r2 score and mean squared error on the held-out rows
r2 = r2_score(y_test, y_pred)
mse = np.mean((y_test - y_pred) ** 2)
print(f"R2 Score: {r2}, MSE: {mse}")

# print tree; feature_names is passed as a plain list because older
# scikit-learn releases reject a pandas Index here
tree = sklearn.tree.export_text(dt, feature_names=list(x.columns))
print(tree)

R2 Score: -0.554756465692603, MSE: 75586.42900333129
|--- env_B <= 135.40
|   |--- seed_count <= 11.47
|   |   |--- env_b <= 40.38
|   |   |   |--- env_area <= 0.45
|   |   |   |   |--- env_area <= 0.42
|   |   |   |   |   |--- value: [690.46]
|   |   |   |   |--- env_area >  0.42
|   |   |   |   |   |--- value: [414.24]
|   |   |   |--- env_area >  0.45
|   |   |   |   |--- env_a <= 123.84
|   |   |   |   |   |--- value: [1043.80]
|   |   |   |   |--- env_a >  123.84
|   |   |   |   |   |--- value: [966.33]
|   |   |--- env_b >  40.38
|   |   |   |--- value: [107.16]
|   |--- seed_count >  11.47
|   |   |--- seed_to_env_area <= 0.40
|   |   |   |--- seed_perimeter <= 1303.37
|   |   |   |   |--- env_to_total_perimeter <= 0.23
|   |   |   |   |   |--- value: [533.51]
|   |   |   |   |--- env_to_total_perimeter >  0.23
|   |   |   |   |   |--- value: [625.23]
|   |   |   |--- seed_perimeter >  1303.37
|   |   |   |   |--- seed_h <= 66.20
|   |   |   |   |   |--- value: [438.15]
|   |   