In [6]:
import pandas as pd
import joblib

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# --------------------
# Load data
# --------------------
df = pd.read_csv("../data/data.csv")

# Record which rows are fully populated BEFORE encoding. pd.factorize maps
# missing values to the sentinel code -1, which dropna() no longer recognises
# as missing, so rows with an unknown gender would otherwise leak into training
# as a bogus extra category.
complete_rows = df.notna().all(axis=1)

# Encode categorical column as integer codes. Codes are assigned in order of
# first appearance over the full file, so each category keeps the same code it
# had before the missing-value handling was fixed.
# NOTE(review): the fitted model only ever sees these integer codes; whatever
# loads ../api/model.joblib presumably has to apply the same mapping - confirm.
df["gender"] = pd.factorize(df["gender"])[0]

# Drop rows that had a missing value in any column (gender included)
df = df.loc[complete_rows]

# --------------------
# Split features and target
# --------------------
# "sno" is excluded from the features along with the label column.
X = df.drop(columns=["sno", "target"])
y = df["target"]


# Hold out 20% of the rows; stratify so both splits keep the class proportions
# of y, and fix the seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y
)

# --------------------
# Train model
# --------------------
# class_weight="balanced" reweights samples inversely to class frequency.
model = LogisticRegression(
    solver="liblinear",
    max_iter=1000,
    class_weight="balanced"
)
model.fit(X_train, y_train)

# --------------------
# Evaluate
# --------------------
# Accuracy is measured on the held-out 20% split created above.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

print(f"Validation Accuracy: {accuracy:.4f}")

# --------------------
# Save model
# --------------------
# Persist the fitted estimator next to the API code.
joblib.dump(model, "../api/model.joblib")
print("Model saved as model.joblib")


Validation Accuracy: 0.8644
Model saved as model.joblib


In [7]:
# Print column names, non-null counts and dtypes for df as left by the
# previously executed cell(s), to check that no missing values remain.
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 293 entries, 0 to 302
Data columns (total 15 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   sno       293 non-null    int64  
 1   age       293 non-null    int64  
 2   gender    293 non-null    int64  
 3   cp        293 non-null    int64  
 4   trestbps  293 non-null    float64
 5   chol      293 non-null    float64
 6   fbs       293 non-null    int64  
 7   restecg   293 non-null    int64  
 8   thalach   293 non-null    float64
 9   exang     293 non-null    int64  
 10  oldpeak   293 non-null    float64
 11  slope     293 non-null    int64  
 12  ca        293 non-null    int64  
 13  thal      293 non-null    int64  
 14  target    293 non-null    object 
dtypes: float64(4), int64(10), object(1)
memory usage: 36.6+ KB


In [8]:
# Summary statistics (count, mean, std, min, quartiles, max) for the numeric
# columns of df; non-numeric columns are left out of the default summary.
df.describe()

Unnamed: 0,sno,age,gender,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal
count,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0,293.0
mean,150.877133,54.348123,0.31058,0.96587,131.679181,246.177474,0.146758,0.518771,149.880546,0.331058,1.054266,1.392491,0.730375,2.320819
std,86.860137,9.182042,0.463523,1.033114,17.658077,51.405545,0.35447,0.527162,22.638525,0.471399,1.173169,0.618946,1.029862,0.613331
min,0.0,29.0,0.0,0.0,94.0,126.0,0.0,0.0,71.0,0.0,0.0,0.0,0.0,0.0
25%,76.0,47.0,0.0,0.0,120.0,211.0,0.0,0.0,134.0,0.0,0.0,1.0,0.0,2.0
50%,151.0,55.0,0.0,1.0,130.0,240.0,0.0,1.0,152.0,0.0,0.8,1.0,0.0,2.0
75%,226.0,61.0,1.0,2.0,140.0,275.0,0.0,1.0,167.0,1.0,1.8,2.0,1.0,3.0
max,302.0,77.0,1.0,3.0,200.0,564.0,1.0,2.0,202.0,1.0,6.2,2.0,4.0,3.0
