In [None]:
# To be added to EDA_and_Data_Curation.ipynb after the data-cleaning cell.
# Calculate inverse-variance weights for the terms of the combined loss function.
# The function works on just the beta-value columns of the DataFrame.
def get_loss_weights(df):
    '''
    Calculate loss-function weights as the inverse variance of each disease's beta values.

    Parameters
    ----------
    df : Pandas DataFrame that includes the columns "AD_beta", "PD_beta",
        "FTD_beta", "ALS_beta". Any other columns are ignored.

    Returns
    -------
    ad_wt, pd_wt, ftd_wt, als_wt : weights for the AD, PD, FTD and ALS
        components of the combined loss function (1 / sample variance).

    Raises
    ------
    KeyError
        If any of the four beta columns is missing from ``df``.
    ValueError
        If a column's variance is zero or undefined (e.g. a constant column,
        or fewer than two non-missing values). Its inverse would be infinite
        or NaN and would corrupt the combined loss.
    '''
    columns = ["AD_beta", "PD_beta", "FTD_beta", "ALS_beta"]

    # Sample variance per disease (pandas defaults: ddof=1, missing values skipped).
    variances_by_disease = df[columns].var()

    # `not (v > 0)` also catches NaN, which compares False with everything.
    unusable = [col for col in columns if not (variances_by_disease[col] > 0)]
    if unusable:
        raise ValueError(
            "cannot compute inverse-variance loss weights; variance is zero "
            f"or undefined for: {unusable}"
        )

    # Inverse variance: noisier (higher-variance) targets get smaller weights.
    ad_wt, pd_wt, ftd_wt, als_wt = (1 / variances_by_disease[col] for col in columns)
    return ad_wt, pd_wt, ftd_wt, als_wt

In [None]:
# to be added in train() in train.py
# Total loss is the inverse-variance-weighted sum of the individual task losses.
# Compute the weights once instead of recomputing every variance for each term.
# NOTE(review): the weights depend only on features_df, so presumably this
# unpacking can live outside the epoch loop in train() — confirm placement.
ad_wt, pd_wt, ftd_wt, als_wt = get_loss_weights(features_df)
total_loss = ad_wt * loss_ad + pd_wt * loss_pd + ftd_wt * loss_ftd + als_wt * loss_als

In [None]:
# hyperparameter optimization using itertools grid search
# to be added to main() in train.py

from sklearn.model_selection import GridSearchCV
import itertools

# Define the parameter space (learning rate, hidden channels, attention heads,
# dropout rate, weight decay). The candidate values are based on the original
# settings in train.py.
param_grid = {
    'Learning_Rate': [0.01, 0.005, 0.001],
    'Hidden_Channels': [64, 128],
    'Attention_Heads': [4, 8],
    'Dropout_Rate': [0.3, 0.6],
    'Weight_Decay': [5e-4, 5e-3],
}

base_settings = {'Epochs': 300}  # shared by every trial; based on original value in train.py


def grid_search(grid, base, objective):
    '''
    Exhaustively evaluate every combination of the values in ``grid``.

    Parameters
    ----------
    grid : dict mapping a setting name to a list of candidate values
    base : dict of settings shared by every trial; a grid value with the same
        key overrides the base value
    objective : callable taking a settings dict and returning a loss
        (lower is better)

    Returns
    -------
    best_params, best_loss : the settings dict with the lowest loss and that
        loss; ``(None, inf)`` if the grid yields no combinations. Ties keep the
        first combination found.
    '''
    best_loss = float('inf')  # any finite loss beats the initial value
    best_params = None
    names = list(grid)
    # product(*lists) yields one tuple per combination; passing grid.values()
    # without unpacking would just iterate over the candidate lists themselves.
    for values in itertools.product(*grid.values()):
        settings = {**base, **dict(zip(names, values))}  # dict merge, not a set
        loss = objective(settings)
        if loss < best_loss:
            best_loss = loss
            best_params = settings
    return best_params, best_loss


def run_trial(settings):
    '''Return the validation loss for one grid point described by ``settings``.'''
    # TODO(review): build and train a fresh model from ``settings`` (using the
    # model/optimizer code in train.py) before evaluating. As written this
    # evaluates the pre-existing ``model``, so every grid point receives the
    # same score and the search cannot distinguish between settings.
    return evaluate(model, data, data.val_mask)  # calculate validation loss


best_params, best_loss = grid_search(param_grid, base_settings, run_trial)

print(f"Best params: {best_params}")
print(f"Best val loss: {best_loss}")