In [None]:
import os
import numpy as np
import pandas as pd

from joblib                import dump
from sklearn.impute        import SimpleImputer
from sklearn.linear_model  import LogisticRegressionCV, LogisticRegression
from sklearn.preprocessing import StandardScaler

In [None]:
# Maps a readable disease name to the code used for that disease; the codes
# (the dict values) are what the rest of the notebook works with.
disease_codes = {
    'diabetes': 'HC221',
    'renal_failure': 'HC294',
    'gout': 'HC328',
    'myocardial_infarction': 'HC326',
    'asthma': 'HC382',
    'gall_stones': 'HC188',
    'ulcerative_colitis': 'HC201',
    'peripheral_vascular_disease': 'HC385',
    'atrial_flutter': 'HC440',
    'osteoarthritis': 'HC376',
    'arthritis_(nos)': 'HC78',
    'TTE_cystitis': 'HC1313',
    'TTE_chronic_renal_failure': 'HC1302',
    'TTE_psoriasis': 'HC1159',
    'TTE_cellulitis': 'HC1139',
    'TTE_cholelithiasis': 'HC1125',
    'glaucoma': 'HC276',
    'Blood_clot_or_DVT_diagnosed_by_doctor': 'BIN_FC11006152',
    'skin_cancer': 'cancer1003',
}

## Load data

In [None]:
%%time

# Read the metadata, PRS and phenotype tables, in that order, from CSV files.
meta, prs, pheno = (
    pd.read_csv(csv_path)
    for csv_path in ('data/meta.csv', 'data/prs.csv', 'data/pheno.csv')
)

## Combine the split columns into a single `final_split` column

In [None]:
# Collapse the two split columns into one: keep `split_nonWB` where it is set
# and fall back to `split` for the rows where it is missing.
prs['final_split'] = prs['split_nonWB'].fillna(prs['split'])

# The fallback for pheno should be pheno's own `split`. The original filled from
# `prs.split`, which is only right when prs and pheno rows share the same index.
# `prs.split` is kept as a last resort in case pheno carries no `split` column.
pheno_split = pheno['split'] if 'split' in pheno.columns else prs['split']
pheno['final_split'] = pheno['split_nonWB'].fillna(pheno_split)

In [None]:
# Remove the principal-component columns (Global_PC1..40 and PC1..40) and the
# listed identifier / covariate columns from the PRS table.
prs = (
    prs
    .drop(columns=[f"Global_PC{k}" for k in range(1, 41)])
    .drop(columns=[f"PC{k}" for k in range(1, 41)])
    .drop(columns=[
        'split_nonWB', 'IID', 'population',
        'age', 'age0', 'age1', 'age2', 'age3',
        'sex', 'BMI', 'N_CNV', 'LEN_CNV', 'Array',
    ])
)

In [None]:
# The disease codes (values of `disease_codes`), in dictionary order.
diseases = [*disease_codes.values()]

# Fit white British (WB) models

In [None]:
def cv_train_model(X_train, y_train, Cs=5, cv=2, max_iter=300, random_state=None):
    """
    Pick an L1 regularisation strength by cross-validation, then fit and return
    a plain LogisticRegression that uses it.

    Parameters:
    - X_train (array-like): Training features.
    - y_train (array-like): Training labels.
    - Cs (int or list of floats): Candidate inverse regularisation strengths,
      passed to LogisticRegressionCV.
    - cv (int or CV splitter): Cross-validation scheme used for the search.
    - max_iter (int): Iteration cap for the saga solver in both fits.
    - random_state (int or None): Seed forwarded to both estimators. The saga
      solver shuffles the data, so pass an int for reproducible fits; None keeps
      the previous unseeded behaviour.

    Returns:
    - LogisticRegression: Model fitted on all of X_train with the selected C.
    """

    # Settings that must be identical in the search and in the final fit.
    shared = dict(
        penalty='l1',
        class_weight='balanced',
        solver='saga',
        max_iter=max_iter,
        random_state=random_state,
    )

    search = LogisticRegressionCV(Cs=Cs, cv=cv, verbose=1, **shared)
    search.fit(X_train, y_train)

    # Only the first entry of C_ is used.
    # NOTE(review): that is the whole answer for a binary target; confirm no
    # multi-class targets are passed in.
    model = LogisticRegression(C=search.C_[0], verbose=0, **shared)
    model.fit(X_train, y_train)

    return model

In [None]:
def mean_fill_and_scale(df):
    """
    Mean-impute missing values and standardise every column of a DataFrame.

    The imputer and scaler are fitted on `df` itself and are not returned.

    Parameters:
    - df (DataFrame): Feature table to transform.

    Returns:
    - DataFrame: Imputed and scaled values with the same column labels and the
      same row index as `df`.

    Raises:
    - ValueError: If a column has no observed value at all, since no mean can be
      computed for it.
    """

    columns = df.columns

    # SimpleImputer discards columns that are entirely missing, which would make
    # the result narrower than `columns` and fail later with an opaque shape
    # error. Fail early with the offending column names instead.
    all_missing = df.isna().all(axis=0).to_numpy()
    if all_missing.any():
        raise ValueError(
            'Cannot mean-impute columns without any observed value: '
            + ', '.join(map(str, columns[all_missing]))
        )

    imp    = SimpleImputer(missing_values=np.nan, strategy='mean')
    scaler = StandardScaler()

    values = imp.fit_transform(df)
    values = scaler.fit_transform(values)

    # Keep the caller's row labels so the result can still be aligned with other
    # tables by index.
    return pd.DataFrame(values, columns=columns, index=df.index)

## Fit models

In [None]:
def fit_and_save_models(model_dir_path, population, disease_list, prs, meta, pheno):
    """
    Fit one model per disease and save it, skipping diseases whose model file
    already exists.

    Parameters:
    - model_dir_path (str): Directory for saving/loading the models; created if
      it does not exist.
    - population (str): Value of the `population` column selecting the rows to
      train on.
    - disease_list (list): Disease column names in `pheno`. Each one also names
      the output file (`<disease>.joblib`) and, with PRS, the `PRS_<disease>`
      column.
    - prs (DataFrame or None): PRS table with `PRS_<disease>` and `final_split`
      columns, or None to train on the `meta` columns only.
    - meta (DataFrame): Metadata joined to the features by index; must provide
      `IID`, `population`, `age0`..`age3` and `split`.
    - pheno (DataFrame): Phenotype table with `population`, `final_split` and one
      column per disease.
    """

    # dump() does not create directories; make sure the target exists before any
    # time is spent on fitting.
    if model_dir_path:
        os.makedirs(model_dir_path, exist_ok=True)

    for disease in disease_list:
        model_path = os.path.join(model_dir_path, disease + '.joblib')

        if os.path.isfile(model_path):
            print('Already fitted model for disease ' + disease)
            continue

        print('Fitting model for disease ' + disease)

        y = pheno[pheno.population == population]
        y_train = y[y.final_split == 'train'][disease]

        # Check to use PRS data or not
        if prs is not None:
            base = prs[['PRS_' + disease, 'final_split']]
            print('Using PRS data')
        else:
            # Without a PRS table the split labels come from pheno. The original
            # indexed into `prs` here, which always failed because prs is None.
            base = pheno[['final_split']]

        X = base.join(meta, how='inner')
        X = X[X.population == population]
        X = X.drop(['IID', 'population', 'age0', 'age1', 'age2', 'age3'], axis=1)

        X_train = X[X.final_split == 'train'].drop(['final_split', 'split'], axis=1)
        X_train = mean_fill_and_scale(X_train)

        # NOTE(review): X_train and y_train are filtered independently (from
        # prs/meta and from pheno); nothing here checks that their rows line up.
        print('Started to fit model')
        model = cv_train_model(X_train, y_train)
        print('Fitted model')
        dump(model, model_path)

In [None]:
%%time

fit_and_save_models('models/log_reg/WB_with_PRS/', 'white_british', diseases, prs, meta, pheno)