In [1]:

import os
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Load environment variables from the analysis .env file; it is expected to
# define EM_COLLAPSED and CLINICAL, which are read with os.getenv further down.
# NOTE(review): machine-specific absolute path. On any other machine this file
# presumably won't exist, nothing gets loaded, and the os.getenv lookups return
# None unless the variables are already set in the shell.
# TODO: resolve the .env relative to the repository root instead.
load_dotenv("/Users/sarah/Code/bioinformatics-tool/analysis/.env")

def find_repo_root(start_path: Path = None) -> Path:
    if start_path is None:
        start_path = Path().resolve()
    for parent in [start_path] + list(start_path.parents):
        if (parent / '.git').exists():
            return parent
    raise RuntimeError("Could not find repo root!")


repo_root = find_repo_root()

print(repo_root)


expression_path = repo_root / os.getenv("EM_COLLAPSED")
clinical_path = repo_root / os.getenv("CLINICAL")

expression = pd.read_csv(expression_path, index_col=0)
clinical = pd.read_csv(clinical_path, index_col=0)

# Align samples
common_samples = expression.index.intersection(clinical.index)
X = expression.loc[common_samples]
clinical = clinical.loc[common_samples]

## Drop samples with NaN in ER status
mask = ~clinical["pgr status"].isnull()
X = X.loc[mask]
clinical = clinical.loc[mask]
y = clinical["pgr status"]

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

# Train classifier
clf = LogisticRegression(max_iter=1000, class_weight="balanced")
clf.fit(X_train, y_train)

# Predict and evaluate
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, target_names=["pgr negative", "pgr positive"]))
print("Confusion matrix:\n", confusion_matrix(y_test, y_pred))

/Users/sarah/Code/bioinformatics-tool
              precision    recall  f1-score   support

pgr negative       0.72      0.61      0.66        77
pgr positive       0.94      0.96      0.95       511

    accuracy                           0.92       588
   macro avg       0.83      0.79      0.81       588
weighted avg       0.91      0.92      0.92       588

Confusion matrix:
 [[ 47  30]
 [ 18 493]]
