### Load dataset

In [3]:
import pandas as pd

# Execute the helper notebook in this kernel's namespace. Names used in this notebook
# without an import (FeatureVariant, ModelVariant, RANDOM_STATE, split_data,
# split_data_apply_smote, get_metrics, ...) presumably come from there — TODO confirm.
%run "DataHelpers.ipynb"

# Can be replaced with desired variant for different feature sets.
# GENE_FILE_VARIANT selects the input CSV below; both GENE_FILE_VARIANT and variant are
# also interpolated into the names of the CSV files written by later cells.
GENE_FILE_VARIANT = FeatureVariant.RESEARCHPAPERS # For values, see FeatureVariant.print_info()
variant = ModelVariant.LG                      # For values, see ModelVariant.print_info()

# Relative path, resolved against the kernel's working directory.
FILE_PATH = f"../Data/patient_genes_{GENE_FILE_VARIANT}.csv"

df = pd.read_csv(FILE_PATH)

# Prints a column / non-null count / dtype summary of the loaded frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 977 entries, 0 to 976
Data columns (total 33 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   BRCA1    977 non-null    float64
 1   BRCA2    977 non-null    float64
 2   CD274    977 non-null    float64
 3   MKI67    977 non-null    float64
 4   PDCD1    977 non-null    float64
 5   PIK3CA   977 non-null    float64
 6   TP53     977 non-null    float64
 7   LRPPRC   977 non-null    float64
 8   YOD1     977 non-null    float64
 9   DCLK1    977 non-null    float64
 10  TOP2A    977 non-null    float64
 11  TACSTD2  977 non-null    float64
 12  ROR1     977 non-null    float64
 13  TTN      977 non-null    float64
 14  CTLA4    977 non-null    float64
 15  EGFR     977 non-null    float64
 16  EPCAM    977 non-null    float64
 17  MYC      977 non-null    float64
 18  PTEN     977 non-null    float64
 19  CDK6     977 non-null    float64
 20  DDX3X    977 non-null    float64
 21  SRC      977 non

### Import the model class

In [4]:
from sklearn.linear_model import LogisticRegression

### Dataset split: training and test data

In [5]:
# Plain train/test split. "tnbc" is presumably the target column, and the meaning of the
# third argument (True) is defined in split_data — TODO document it here.
X, y, X_train, X_test, y_train, y_test, test_case_ids = split_data(df, "tnbc", True)
print("\nApplied Smote")
# NOTE(review): the SMOTE train+test splits total 1379 + 345 = 1724 rows vs 977 rows in df,
# and the SMOTE test set (345) is larger than the plain one (196). Oversampling therefore
# appears to happen BEFORE the train/test split. If so, synthetic samples derived from
# training patients end up in the test set and the SMOTE test metrics are optimistic.
# SMOTE should normally be fitted on the training split only — TODO confirm in DataHelpers.
X_smote, y_smote, X_train_smote, X_test_smote, y_train_smote, y_test_smote, test_case_ids_smote = split_data_apply_smote(df, "tnbc")

X_train.shape=(781, 31)
X_test.shape=(196, 31)
y_train.shape=(781,)
y_test.shape=(196,)

Applied Smote
X_train.shape=(1379, 31)
X_test.shape=(345, 31)
y_train.shape=(1379,)
y_test.shape=(345,)


### Logistic Regression

In [6]:
# Create model
# max_iter is raised far above scikit-learn's default of 100 in an attempt to avoid
#       ConvergenceWarning: lbfgs failed to converge (status=1): STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.
# The cross-validation cells below still report an lbfgs ConvergenceWarning, so a larger
# max_iter alone has not been sufficient.
# TODO: Look into scaling the data (unscaled features are a common cause of lbfgs
#       convergence problems).

model = LogisticRegression(random_state=RANDOM_STATE, solver='lbfgs', max_iter=100_000)


def run_model(X_train: pd.DataFrame, X_test: pd.DataFrame, y_train: pd.Series, y_test: pd.Series,
              test_case_ids, is_smote: bool, estimator=None):
    """Fit a classifier, predict on the test split and save the predictions to CSV.

    Parameters
    ----------
    X_train, y_train : training features and labels.
    X_test, y_test : test features and labels.
    test_case_ids : identifiers of the test rows, written to the ``case_id`` column.
    is_smote : whether the split came from the SMOTE-oversampled data; selects the
        ``_smote`` suffix of the output file name.
    estimator : optional scikit-learn style classifier (``fit`` / ``predict`` /
        ``predict_proba``). Defaults to the notebook-level ``model`` so existing calls
        behave exactly as before.

    Returns
    -------
    (y_pred, y_prob) : predicted labels, and column 1 of ``predict_proba`` (the second
        entry of the estimator's ``classes_``, i.e. the positive class for 0/1 labels).

    Side effects
    ------------
    Fits the estimator in place (the shared ``model`` by default) and writes
    ``../Data/model_output_{variant}_{GENE_FILE_VARIANT}[_smote].csv``.
    """
    clf = model if estimator is None else estimator

    # Train the model
    clf.fit(X_train, y_train)

    # Model predictions
    y_pred = clf.predict(X_test)
    y_prob = clf.predict_proba(X_test)[:, 1]  # For ROC curves etc.

    # Save the per-case predictions to CSV
    predictions = pd.DataFrame({
        "case_id": test_case_ids,
        "y_test": y_test,
        "y_pred": y_pred,
        "y_prob": y_prob
    })
    suffix = "_smote" if is_smote else ""
    predictions.to_csv(f"../Data/model_output_{variant}_{GENE_FILE_VARIANT}{suffix}.csv", index=False)

    return y_pred, y_prob

In [7]:
# Fit on the original (non-oversampled) split and report test-set accuracy.
# Note: y_prod holds predicted probabilities; the name is kept because later cells use it.
y_pred, y_prod = run_model(X_train, X_test, y_train, y_test, test_case_ids, is_smote=False)

print_evaluated_model_accuracy(y_test, y_pred)

Accuracy: 0.95


## SMOTE applied

In [8]:
# Fit on the SMOTE-oversampled split and report test-set accuracy.
# Note: y_prod_smote holds predicted probabilities; the name is kept because later cells use it.
y_pred_smote, y_prod_smote = run_model(X_train_smote, X_test_smote, y_train_smote, y_test_smote,
                                       test_case_ids_smote, is_smote=True)

print_evaluated_model_accuracy(y_test_smote, y_pred_smote)

Accuracy: 0.96


## Model cross-validation

In [9]:
def run_cross_validation(X: pd.DataFrame, y: pd.Series, y_test: pd.Series, y_pred: pd.Series, y_prob: pd.Series,
                         is_smote: bool, estimator=None, cv: int = 5) -> pd.DataFrame:
    """Cross-validate a classifier and prepend the held-out test-set metrics.

    Parameters
    ----------
    X, y : features and labels passed to ``get_cross_validation_metrics``.
    y_test, y_pred, y_prob : held-out test labels, predicted labels and predicted
        probabilities (as returned by ``run_model``), scored with ``get_metrics``.
    is_smote : selects the ``_smote`` suffix of the output file name.
    estimator : optional classifier to cross-validate. Defaults to the notebook-level
        ``model`` so existing calls behave exactly as before.
    cv : number of cross-validation folds (default 5, as before).

    Returns
    -------
    pd.DataFrame indexed by ``fold``. Row 0 holds the held-out test-set metrics. The
    remaining rows are what ``get_cross_validation_metrics`` returns, one row per fold
    in the displayed output.

    Side effects
    ------------
    Prints the cross-validation summary and writes
    ``../Data/model_metrics_{variant}_{GENE_FILE_VARIANT}[_smote].csv``, including the
    ``fold`` column.

    NOTE(review): when X/y are the SMOTE-oversampled data, synthetic samples interpolated
    from the same original patients can land in different folds. That likely makes these
    CV scores optimistic — TODO confirm how DataHelpers builds X_smote.
    """
    clf = model if estimator is None else estimator

    cv_metrics = get_cross_validation_metrics(clf, X, y, cv=cv)

    test_metrics = get_metrics(y_test, y_pred, y_prob)
    test_metrics["fold"] = 0  # Row 0 = metrics on the held-out test split (before cross validation)
    test_row = pd.DataFrame([test_metrics]).set_index("fold")

    print_validated_model_accuracy(clf, cv_metrics)

    # Prepend the test-set row to the per-fold metrics, export, and return for display.
    metrics = pd.concat([test_row, cv_metrics])
    suffix = "_smote" if is_smote else ""
    # Write the index as well. With index=False the fold numbers were dropped, and the
    # held-out test row could only be told apart from the CV folds by its position.
    metrics.to_csv(f"../Data/model_metrics_{variant}_{GENE_FILE_VARIANT}{suffix}.csv", index_label="fold")
    return metrics

In [10]:
# Cross-validate on the non-oversampled data; row "fold 0" of the result is the held-out test split.
# This cell still emits:
#   ConvergenceWarning: lbfgs failed to converge (status=1):
#                       STOP: TOTAL NO. OF F,G EVALUATIONS EXCEEDS LIMIT.
metrics = run_cross_validation(X, y, y_test, y_pred, y_prod, is_smote=False)
metrics

Model validation for LogisticRegression:
[0.9336734693877551, 0.9540816326530612, 0.9435897435897436, 0.9333333333333333, 0.9282051282051282]

Mean accuracy: 0.9386



Unnamed: 0_level_0,accuracy,recall,precision,f1_score,roc_auc,true_positive,true_negative,false_positive,false_negative
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.94898,0.869565,0.740741,0.8,0.9809,20,166,7,3
1,0.933673,0.652174,0.75,0.697674,0.970847,15,168,5,8
2,0.954082,0.782609,0.818182,0.8,0.968585,18,169,4,5
3,0.94359,0.73913,0.772727,0.755556,0.956775,17,167,5,6
4,0.933333,0.695652,0.727273,0.711111,0.938322,16,166,6,7
5,0.928205,0.652174,0.714286,0.681818,0.935288,15,166,6,8


In [11]:
# Cross-validate on the SMOTE-oversampled data; row "fold 0" of the result is the held-out test split.
# This cell still emits:
#   ConvergenceWarning: lbfgs failed to converge (status=1):
#                       STOP: TOTAL NO. OF F,G EVALUATIONS EXCEEDS LIMIT.
metric_smote = run_cross_validation(X_smote, y_smote, y_test_smote, y_pred_smote, y_prod_smote, is_smote=True)
metric_smote

Model validation for LogisticRegression:
[0.9420289855072463, 0.9478260869565217, 0.936231884057971, 0.9420289855072463, 0.9418604651162791]

Mean accuracy: 0.9420



Unnamed: 0_level_0,accuracy,recall,precision,f1_score,roc_auc,true_positive,true_negative,false_positive,false_negative
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.956522,0.97093,0.943503,0.95702,0.98155,167,163,10,5
1,0.942029,0.959302,0.926966,0.942857,0.971434,165,160,13,7
2,0.947826,0.953488,0.942529,0.947977,0.974358,164,163,10,8
3,0.936232,0.924855,0.946746,0.935673,0.975669,160,163,9,13
4,0.942029,0.959538,0.927374,0.943182,0.961856,166,159,13,7
5,0.94186,0.94186,0.94186,0.94186,0.969105,162,162,10,10
