### Load dataset

In [None]:
import pandas as pd
from enum import StrEnum, auto
RANDOM_STATE = 42

# NOTE: StrEnum requires Python 3.11+
#       Can refactor to CONSTANTS instead
class FeatureVariant(StrEnum):
    LITERATURE = auto()
    RESEARCH = auto()
    STATISTICAL = auto()
    AUTOMATED = auto()

# Can be replaced with desired variant for different feature sets
GENE_FILE_VARIANT = FeatureVariant.RESEARCH

FILE_PATH = f"../Data/patient_genes_{GENE_FILE_VARIANT}.csv"

variant = 'svm'
df = pd.read_csv(FILE_PATH)

df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 977 entries, 0 to 976
Data columns (total 21 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   CACNA2D2  977 non-null    int64 
 1   ESR1      977 non-null    int64 
 2   AGR2      977 non-null    int64 
 3   GATA3     977 non-null    int64 
 4   SLC16A6   977 non-null    int64 
 5   TBC1D9    977 non-null    int64 
 6   INPP4B    977 non-null    int64 
 7   LDHB      977 non-null    int64 
 8   MLPH      977 non-null    int64 
 9   TSPAN1    977 non-null    int64 
 10  STBD1     977 non-null    int64 
 11  STARD3    977 non-null    int64 
 12  RARA      977 non-null    int64 
 13  MCCC2     977 non-null    int64 
 14  PSAT1     977 non-null    int64 
 15  MFGE8     977 non-null    int64 
 16  ANXA9     977 non-null    int64 
 17  PPP1R14C  977 non-null    int64 
 18  SLC44A4   977 non-null    int64 
 19  tnbc      977 non-null    bool  
 20  case_id   977 non-null    object
dtypes: bool(1), int6

### Imports

In [62]:
from sklearn.svm import SVC

%run "DataHelpers.ipynb"

### Dataset split: training and test data

In [63]:
X1, y1, X_train1, X_test1, y_train1, y_test1, test_case_ids1 = split_data(df, "tnbc", True)
print("\nApplied Smote")
X_smote, y_smote, X_train_smote, X_test_smote, y_train_smote, y_test_smote, test_case_ids_smote = split_data_apply_smote(df, "tnbc")

X_train.shape=(781, 19)
X_test.shape=(196, 19)
y_train.shape=(781,)
y_test.shape=(196,)

Applied Smote
X_train.shape=(1379, 19)
X_test.shape=(345, 19)
y_train.shape=(1379,)
y_test.shape=(345,)


### Support Vector Machine (SVM)

In [64]:
# Create model

model = SVC(random_state=RANDOM_STATE, probability=True)

def run_model(X_train: pd.DataFrame, X_test: pd.DataFrame, y_train: pd.Series, y_test: pd.Series, test_case_ids: pd.Series, is_smote: bool):
    # Train the model
    model.fit(X_train, y_train)

    # Model predictions
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]  # For ROC curves etc.

    # Save it in a dataframe, to CSV
    predictions = pd.DataFrame({
        "case_id": test_case_ids,
        "y_test": y_test,
        "y_pred": y_pred,
        "y_prob": y_prob
    })
    predictions.to_csv(f"../Data/model_output_{variant}_{GENE_FILE_VARIANT}_{'smote' if is_smote else ''}.csv", index=False)

    return y_pred, y_prob

In [65]:
y_pred1, y_prod1 = run_model(X_train1, X_test1, y_train1, y_test1, test_case_ids1, False)

print_evaluated_model_accuracy(y_test1, y_pred1)

Accuracy: 0.96


## Smote applied

In [66]:
y_pred_smote, y_prod_smote = run_model(X_train_smote, X_test_smote, y_train_smote, y_test_smote, test_case_ids_smote, True)

print_evaluated_model_accuracy(y_test_smote, y_pred_smote)

Accuracy: 0.94


## Model cross validation

In [67]:
def run_cross_validation(X: pd.DataFrame, y: pd.Series, y_test: pd.Series, y_pred: pd.Series, y_prob: pd.Series, is_smote: bool) -> pd.DataFrame:
    metrics: pd.DataFrame = get_cross_validation_metrics(model, X, y, cv=5)
    test_metrics = get_metrics(y_test, y_pred, y_prob)
    test_metrics["fold"] = 0 # Initial test metrics (before cross validation)
    test = pd.DataFrame([test_metrics])
    test.set_index("fold", inplace=True)

    print_validated_model_accuracy(model, metrics)

    # Prepend test_metrics to metrics dataframe, export and display
    metrics = pd.concat([test, metrics])
    metrics.to_csv(f"../Data/model_metrics_{variant}_{GENE_FILE_VARIANT}_{'smote' if is_smote else ''}.csv", index=False)
    return metrics

In [70]:
# Still getting warning:
#   ConvergenceWarning: lbfgs failed to converge (status=1):
#                       STOP: TOTAL NO. OF F,G EVALUATIONS EXCEEDS LIMIT.
metrics = run_cross_validation(X, y, y_test, y_pred, y_prod, False)
metrics

Model validation for SVC:
[0.9285714285714286, 0.9285714285714286, 0.9384615384615385, 0.9230769230769231, 0.9128205128205128]

Mean accuracy: 0.9263



Unnamed: 0_level_0,accuracy,recall,precision,f1_score,roc_auc,true_positive,true_negative,false_positive,false_negative
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.943878,0.652174,0.833333,0.731707,0.975873,15,170,3,8
1,0.928571,0.608696,0.736842,0.666667,0.971852,14,168,5,9
2,0.928571,0.608696,0.736842,0.666667,0.966826,14,168,5,9
3,0.938462,0.565217,0.866667,0.684211,0.962588,13,170,2,10
4,0.923077,0.565217,0.722222,0.634146,0.966127,13,167,5,10
5,0.912821,0.521739,0.666667,0.585366,0.929474,12,166,6,11


In [69]:
# Still getting warning:
#   ConvergenceWarning: lbfgs failed to converge (status=1):
#                       STOP: TOTAL NO. OF F,G EVALUATIONS EXCEEDS LIMIT.
metric_smote = run_cross_validation(X_smote, y_smote, y_test_smote, y_pred_smote, y_prod_smote, True)
metric_smote

Model validation for SVC:
[0.9478260869565217, 0.9449275362318841, 0.9391304347826087, 0.927536231884058, 0.9563953488372093]

Mean accuracy: 0.9432



Unnamed: 0_level_0,accuracy,recall,precision,f1_score,roc_auc,true_positive,true_negative,false_positive,false_negative
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,0.944928,0.97093,0.922652,0.946176,0.968746,167,159,14,5
1,0.947826,0.97093,0.927778,0.948864,0.961655,167,160,13,5
2,0.944928,0.959302,0.932203,0.945559,0.969956,165,161,12,7
3,0.93913,0.953757,0.926966,0.940171,0.975198,165,159,13,8
4,0.927536,0.942197,0.91573,0.928775,0.969485,163,157,15,10
5,0.956395,0.97093,0.943503,0.95702,0.979144,167,162,10,5
