In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.inspection import permutation_importance
import joblib

# Input CSV; main() passes it to load_dataset().
DATA_PATH = "updated_pollution_dataset.csv"
# Where main() writes the fitted pipeline with joblib.dump().
MODEL_PATH = "knn_pollution_model.joblib"
# Directory used in the default `outpath` of every plot_* helper below.
PLOTS_DIR = "plots"

# Import-time side effects: create the plot directory once so the plot
# helpers can save without checking, and apply a global seaborn style.
os.makedirs(PLOTS_DIR, exist_ok=True)
sns.set(style="whitegrid")

def load_dataset(path):
    """Read the CSV at ``path`` into a DataFrame and report its shape.

    Raises
    ------
    FileNotFoundError
        If nothing exists at ``path``.
    """
    if os.path.exists(path):
        frame = pd.read_csv(path)
        print("Loaded dataset with shape:", frame.shape)
        return frame
    raise FileNotFoundError(f"Dataset not found at {path}. Please provide a CSV.")

def auto_select_features(df):
    """Pick the feature columns to model from ``df``.

    First looks for well-known pollutant / weather column names, matched
    case-insensitively. Matches are returned in ``candidate_features`` order,
    using the DataFrame's own spelling. If none is present, falls back to
    every numeric column except identifier- or timestamp-like ones.

    The fallback compares whole name tokens (split on non-alphanumerics and
    camelCase boundaries) against id/time/date/timestamp/datetime. Plain
    substring matching would also discard real measurements such as
    "Humidity" or "Nitrogen_Oxide" (both contain "id") or "Update_Count"
    (contains "date").

    Returns
    -------
    list
        Column labels from ``df``; may be empty.
    """
    import re

    candidate_features = [
        "PM2.5", "PM2_5", "PM25", "PM10", "NO2", "NO_2", "SO2", "SO_2",
        "CO", "O3", "O_3", "AQI", "Humidity", "Temperature", "RH", "Temp"
    ]
    # str() so non-string column labels (e.g. integers) do not crash .lower().
    lower_map = {str(c).lower(): c for c in df.columns}
    chosen_features = [
        lower_map[cand.lower()] for cand in candidate_features if cand.lower() in lower_map
    ]
    if chosen_features:
        return chosen_features

    excluded_tokens = {"id", "time", "date", "timestamp", "datetime"}

    def name_tokens(name):
        # "StationID" -> "Station_ID" -> {"station", "id"}
        split_camel = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", str(name))
        return {tok for tok in re.split(r"[^0-9a-z]+", split_camel.lower()) if tok}

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    return [c for c in numeric_cols if excluded_tokens.isdisjoint(name_tokens(c))]

def _threshold_category(value, upper_bounds):
    """Return the index of the first bound ``value`` does not exceed.

    Values above every bound map to ``len(upper_bounds)``; missing values
    (anything ``pd.isna`` accepts) map to ``np.nan``.
    """
    if pd.isna(value):
        return np.nan
    value = float(value)
    for category, bound in enumerate(upper_bounds):
        if value <= bound:
            return category
    return len(upper_bounds)


def prepare_label(df):
    """Find or derive the classification target column.

    Resolution order:
      1. An existing column named in ``possible_label_names`` (exact match).
      2. A column named "aqi" (case-insensitive), binned at 50/100/150/200
         into categories 0-4.
      3. A PM2.5 column (case-insensitive, several spellings), binned at
         30/60/90/120 into categories 0-4.

    In cases 2 and 3 a "Category" column is added to ``df`` in place; missing
    source values give a NaN category.

    Returns
    -------
    (DataFrame, str)
        ``df`` (possibly with the new column) and the label column name.

    Raises
    ------
    ValueError
        If no label column exists and none can be derived.
    """
    possible_label_names = ["Category", "category", "Label", "label", "PollutionLevel", "pollution_level", "AQI_Category"]
    for name in possible_label_names:
        if name in df.columns:
            print("Found label column:", name)
            return df, name

    aqi_candidates = [c for c in df.columns if str(c).lower() == "aqi"]
    if aqi_candidates:
        aqi_col = aqi_candidates[0]
        print("Using AQI column to create categories:", aqi_col)
        df["Category"] = df[aqi_col].apply(lambda v: _threshold_category(v, (50, 100, 150, 200)))
        return df, "Category"

    # Candidates are compared against lower-cased column names, so they must be
    # lower-case themselves. The former "pm2.5 (Âµg/m3)" entry had an upper-case
    # "Â" and could never match. Both the correct µ spelling and its lower-cased
    # UTF-8-read-as-Latin-1 mojibake form are accepted.
    pm25_names = ("pm2.5", "pm2_5", "pm25", "pm2.5 (µg/m3)", "pm2.5 (âµg/m3)")
    pm25_candidates = [c for c in df.columns if str(c).lower() in pm25_names]
    if pm25_candidates:
        pm25 = pm25_candidates[0]
        print("Creating Category using PM2.5 thresholds on column:", pm25)
        df["Category"] = df[pm25].apply(lambda v: _threshold_category(v, (30, 60, 90, 120)))
        return df, "Category"

    raise ValueError("Could not determine or create a label column. Ensure CSV has 'Category' or 'AQI' or PM2.5.")

def build_and_train(X_train, y_train, param_grid=None):
    """Grid-search an impute -> scale -> KNN pipeline and return the winner.

    Parameters
    ----------
    X_train, y_train : training features and labels.
    param_grid : dict, optional
        GridSearchCV grid using ``knn__*`` keys. When omitted, odd k in 1..15,
        uniform/distance weighting and Manhattan/Euclidean (p=1/2) are tried.

    Returns
    -------
    (Pipeline, GridSearchCV)
        The refitted best pipeline and the fitted search object.

    Uses 5-fold stratified, shuffled CV (seed 42), scored by accuracy.
    """
    search_space = param_grid if param_grid is not None else {
        "knn__n_neighbors": list(range(1, 16, 2)),
        "knn__weights": ["uniform", "distance"],
        "knn__p": [1, 2],
    }
    steps = [
        ("imputer", SimpleImputer(strategy="mean")),
        ("scaler", StandardScaler()),
        ("knn", KNeighborsClassifier()),
    ]
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    search = GridSearchCV(
        Pipeline(steps),
        search_space,
        cv=folds,
        scoring="accuracy",
        n_jobs=-1,
        verbose=1,
        return_train_score=True,
    )
    search.fit(X_train, y_train)
    print("Best params:", search.best_params_)
    print("Best CV accuracy:", search.best_score_)
    return search.best_estimator_, search

def plot_correlation_heatmap(X, outpath=os.path.join(PLOTS_DIR, "correlation_heatmap.png")):
    """Save an annotated heatmap of the pairwise correlations of X's columns.

    Uses ``DataFrame.corr()`` (pandas' default Pearson method).
    """
    correlation = X.corr()
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(correlation, annot=True, fmt=".2f", cmap="vlag", square=True, ax=ax)
    ax.set_title("Feature Correlation Heatmap")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)
    print("Saved correlation heatmap to", outpath)

def plot_feature_distributions(X, y=None, outpath=os.path.join(PLOTS_DIR, "feature_distributions.png")):
    """Save a grid of histogram + KDE panels, one per column of X.

    Panels are laid out three per row; leftover grid cells are hidden.
    NaNs are dropped per column before plotting. ``y`` is accepted but unused.
    """
    per_row = 3
    column_count = X.shape[1]
    row_count = int(np.ceil(column_count / per_row))
    _, grid = plt.subplots(row_count, per_row, figsize=(per_row*5, row_count*3.5))
    panels = grid.flatten()
    for panel, column in zip(panels, X.columns):
        sns.histplot(X[column].dropna(), kde=True, ax=panel)
        panel.set_title(column)
    # Hide the empty cells in the last row.
    for spare in panels[column_count:]:
        spare.axis('off')
    plt.tight_layout()
    plt.savefig(outpath, dpi=300)
    plt.close()
    print("Saved feature distributions to", outpath)

def plot_pairplot(X, y, max_features=6, outpath=os.path.join(PLOTS_DIR, "pairplot.png")):
    """Save a class-coloured pairplot of the first ``max_features`` columns of X.

    ``y`` is joined to X with ``pd.concat(axis=1)``, so a Series ``y`` is
    aligned on X's index.

    The suptitle is placed at y=1.02, above the top edge of the figure. The
    figure is therefore saved with ``bbox_inches="tight"``; without it the
    title falls outside the saved canvas and is cropped out of the PNG.
    """
    # Limit to a few features to keep the plot readable.
    use_cols = list(X.columns[:max_features])
    df_small = pd.concat([X[use_cols], pd.Series(y, name="label")], axis=1)
    sns.pairplot(df_small, hue="label", corner=True, plot_kws={"alpha":0.6, "s":30})
    plt.suptitle("Pairplot (first {} features)".format(len(use_cols)), y=1.02)
    plt.savefig(outpath, dpi=300, bbox_inches="tight")
    plt.close()
    print("Saved pairplot to", outpath)

def plot_pca(X, y, outpath_exp=os.path.join(PLOTS_DIR, "pca_explained_variance.png"),
             outpath_2d=os.path.join(PLOTS_DIR, "pca_2d.png")):
    """Save PCA diagnostics for X: cumulative explained variance and a 2-D scatter.

    X is mean-imputed and standardized before fitting PCA with up to 10
    components. The PC1/PC2 scatter, coloured by ``y``, is written only
    when at least two components exist.
    """
    standardized = StandardScaler().fit_transform(SimpleImputer(strategy="mean").fit_transform(X))
    pca = PCA(n_components=min(10, standardized.shape[1]))
    projected = pca.fit_transform(standardized)
    ratios = pca.explained_variance_ratio_
    component_numbers = np.arange(1, len(ratios) + 1)
    cumulative_pct = np.cumsum(ratios) * 100

    # Cumulative explained-variance curve.
    plt.figure(figsize=(6,4))
    plt.plot(component_numbers, cumulative_pct, marker='o')
    plt.xlabel("Number of components")
    plt.ylabel("Cumulative explained variance (%)")
    plt.title("PCA cumulative explained variance")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(outpath_exp, dpi=300)
    plt.close()

    if projected.shape[1] < 2:
        print("Not enough PCA components for 2D plot; saved explained variance only.")
        return

    # First two principal components, coloured by class.
    plt.figure(figsize=(7,6))
    sns.scatterplot(x=projected[:, 0], y=projected[:, 1], hue=y, palette="tab10", alpha=0.8)
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.title("PCA 2D projection colored by class")
    plt.legend(title="label", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(outpath_2d, dpi=300)
    plt.close()
    print("Saved PCA plots to", outpath_exp, "and", outpath_2d)

def plot_tsne(X, y, outpath=os.path.join(PLOTS_DIR, "tsne_2d.png")):
    """Save a 2-D t-SNE embedding of X coloured by ``y`` (best-effort).

    X is mean-imputed and standardized first. The embedding uses PCA
    initialization with random_state=42. Any exception is printed and
    swallowed so the caller's reporting keeps going.
    """
    try:
        prepared = StandardScaler().fit_transform(
            SimpleImputer(strategy="mean").fit_transform(X)
        )
        embedding = TSNE(
            n_components=2, random_state=42, init="pca", learning_rate="auto"
        ).fit_transform(prepared)
        plt.figure(figsize=(7,6))
        sns.scatterplot(x=embedding[:, 0], y=embedding[:, 1], hue=y, palette="tab10", alpha=0.8)
        plt.title("t-SNE 2D embedding")
        plt.legend(title="label", bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(outpath, dpi=300)
        plt.close()
        print("Saved t-SNE plot to", outpath)
    except Exception as err:
        print("t-SNE plotting failed:", err)

def plot_grid_search_results(grid, outpath=os.path.join(PLOTS_DIR, "grid_search_results.png")):
    """Plot mean CV accuracy against ``n_neighbors`` from a fitted GridSearchCV.

    Scores are averaged over the other grid parameters (weights, p), giving
    one point per k. This requires a ``param_knn__n_neighbors`` column in
    ``cv_results_``; otherwise the plot is skipped. Failures are printed and
    swallowed (best-effort).
    """
    k_column = 'param_knn__n_neighbors'
    try:
        results = pd.DataFrame(grid.cv_results_)
        if k_column not in results.columns:
            print("GridSearch results do not contain 'param_knn__n_neighbors', skipping grid plot.")
            return
        per_k = results.groupby(k_column)['mean_test_score'].mean().reset_index()
        plt.figure(figsize=(6,4))
        plt.plot(per_k[k_column], per_k['mean_test_score'], marker='o')
        plt.xlabel("n_neighbors")
        plt.ylabel("Mean CV accuracy")
        plt.title("CV accuracy vs n_neighbors")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(outpath, dpi=300)
        plt.close()
        print("Saved grid search results to", outpath)
    except Exception as err:
        print("Grid search plotting failed:", err)

def plot_confusion_matrix(cm, labels, outpath=os.path.join(PLOTS_DIR, "confusion_matrix.png")):
    """Save ``cm`` (rows = true, columns = predicted) as an annotated heatmap.

    ``labels`` supplies the tick labels for both axes, in matrix order.
    Annotations use the integer format, so ``cm`` is expected to hold counts.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)
    print("Saved confusion matrix to", outpath)

def plot_roc_multiclass(model, X_test, y_test, outpath=os.path.join(PLOTS_DIR, "roc_curves.png")):
    """Save one-vs-rest ROC curves for each class, with per-class AUC.

    Parameters
    ----------
    model : fitted estimator (typically the training Pipeline)
        Must expose ``predict_proba``; otherwise the plot is skipped. The
        model applies its own preprocessing, so ``X_test`` is passed raw.
    X_test, y_test : held-out features and labels.
    outpath : destination PNG path.

    Probability columns are matched to classes through ``model.classes_``,
    the order ``predict_proba`` uses, rather than the classes that happen to
    appear in ``y_test``. Matching on ``y_test`` would shift every later curve
    onto the wrong column whenever a class is missing from the test split,
    and it broke the two-class case. Classes without both positive and
    negative test samples are skipped because their ROC is undefined.

    Best-effort: any exception is printed and swallowed.
    """
    try:
        if not hasattr(model, "predict_proba"):
            print("Model doesn't support predict_proba; skipping ROC curves.")
            return
        y_score = np.asarray(model.predict_proba(X_test))
        y_true = np.asarray(y_test)
        # Fall back to sorted test labels only for estimators without classes_.
        model_classes = getattr(model, "classes_", None)
        if model_classes is None:
            model_classes = np.unique(y_true)
        plt.figure(figsize=(8,6))
        plotted = 0
        for column, cls in enumerate(model_classes):
            is_positive = (y_true == cls)
            if not is_positive.any() or is_positive.all():
                continue
            fpr, tpr, _ = roc_curve(is_positive.astype(int), y_score[:, column])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2, label=f"Class {cls} (AUC = {roc_auc:.2f})")
            plotted += 1
        if plotted == 0:
            plt.close()
            print("No class has both positive and negative test samples; skipping ROC curves.")
            return
        # Chance diagonal.
        plt.plot([0,1], [0,1], linestyle='--', color='grey')
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Multiclass ROC Curves")
        plt.legend(loc="lower right", bbox_to_anchor=(1.05, 0))
        plt.tight_layout()
        plt.savefig(outpath, dpi=300, bbox_inches='tight')
        plt.close()
        print("Saved ROC curves to", outpath)
    except Exception as e:
        print("ROC plotting failed:", e)

def plot_permutation_importance(model, X_val, y_val, outpath=os.path.join(PLOTS_DIR, "permutation_importance.png"), n_repeats=10):
    """Save a horizontal bar chart of permutation feature importances.

    Pass the full fitted Pipeline as ``model``. ``permutation_importance``
    then calls it on each permuted copy of ``X_val``, so the training-time
    imputation and scaling are applied automatically. Bars are sorted by
    mean importance, largest at the top, with the std over ``n_repeats``
    shuffles as error bars. Feature names come from ``X_val.columns``; plain
    arrays get ``feature_<i>`` names. Best-effort: failures are printed and
    swallowed.
    """
    try:
        # Uses the module-level sklearn.inspection import.
        result = permutation_importance(model, X_val, y_val, n_repeats=n_repeats, random_state=42, n_jobs=-1)
        column_names = getattr(X_val, "columns", None)
        if column_names is None:
            column_names = [f"feature_{i}" for i in range(len(result.importances_mean))]
        order = result.importances_mean.argsort()[::-1]
        # str() keeps matplotlib on a categorical axis even for non-string labels.
        names = np.array([str(c) for c in column_names])[order]
        means = result.importances_mean[order]
        stds = result.importances_std[order]
        plt.figure(figsize=(8, max(4, 0.4 * len(names))))
        plt.barh(names, means, xerr=stds)
        plt.xlabel("Permutation importance (mean decrease in score)")
        plt.title("Permutation Feature Importance")
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.savefig(outpath, dpi=300)
        plt.close()
        print("Saved permutation importance to", outpath)
    except Exception as e:
        print("Permutation importance failed:", e)

def evaluate_and_plot(model, X_test, y_test, features):
    """Score ``model`` on the test split and save confusion-matrix / ROC plots.

    Prints accuracy and sklearn's classification report. The confusion-matrix
    tick labels are the sorted union of true and predicted labels.
    ``features`` is accepted but not used.

    Returns
    -------
    (float, ndarray)
        Test accuracy and the confusion matrix.
    """
    predictions = model.predict(X_test)
    test_accuracy = accuracy_score(y_test, predictions)
    print("Test Accuracy:", test_accuracy)
    report = classification_report(y_test, predictions)
    print("Classification report:\n", report)
    matrix = confusion_matrix(y_test, predictions)
    class_labels = np.unique(pd.concat([y_test, pd.Series(predictions)]))
    plot_confusion_matrix(matrix, class_labels)
    try:
        plot_roc_multiclass(model, X_test, y_test)
    except Exception as exc:
        print("ROC plotting exception:", exc)
    return test_accuracy, matrix

# Lower-cased column names from which prepare_label() derives "Category",
# in its priority order (AQI first, then PM2.5). Keep in sync with it.
_LABEL_SOURCE_CANDIDATES = (
    ("aqi",),
    ("pm2.5", "pm2_5", "pm25", "pm2.5 (µg/m3)", "pm2.5 (âµg/m3)"),
)


def _label_source_column(columns):
    """Return the column a derived "Category" label is computed from, or None."""
    for names in _LABEL_SOURCE_CANDIDATES:
        for col in columns:
            if str(col).lower() in names:
                return col
    return None


def main(drop_label_source=True):
    """Run the end-to-end KNN pollution-category experiment.

    Steps: load DATA_PATH, select features, find or derive the label, train
    a grid-searched KNN pipeline, write evaluation and exploratory plots to
    PLOTS_DIR, save the model to MODEL_PATH, and print two example
    predictions (median and 95th-percentile inputs).

    Parameters
    ----------
    drop_label_source : bool, default True
        Applies when prepare_label() had to *derive* the label by binning a
        feature column (AQI or PM2.5). If True, that column is removed from
        the features. Keeping it is target leakage: the label is then a
        deterministic function of an input, so the classifier only re-learns
        the bin thresholds and the accuracy it reports is inflated. Pass
        False to reproduce the previous behaviour.

    Raises
    ------
    ValueError
        If fewer than two candidate features are found, or no feature
        remains after removing the label source.
    """
    df = load_dataset(DATA_PATH)
    chosen_features = auto_select_features(df)
    print("Automatically selected features:", chosen_features)
    if len(chosen_features) < 2:
        raise ValueError("Too few numeric features detected.")

    # Snapshot the columns first: prepare_label() may add "Category" in place.
    columns_before_label = list(df.columns)
    df, label_col = prepare_label(df)
    print("Using label column:", label_col)
    print(df[label_col].value_counts(dropna=False))

    df = df.dropna(subset=[label_col]).reset_index(drop=True)

    features = [f for f in chosen_features if f != label_col and f in df.columns]
    label_was_derived = label_col not in columns_before_label
    if drop_label_source and label_was_derived:
        source_col = _label_source_column(columns_before_label)
        if source_col in features:
            print(f"Dropping feature {source_col!r}: the label was derived from it (target leakage).")
            features.remove(source_col)
    if not features:
        raise ValueError("No usable feature columns remain after removing the label source.")

    X = df[features].copy()
    # NOTE(review): assumes integer-like labels; a string label column would fail here.
    y = df[label_col].astype(int).copy()

    print("Feature columns used:", features)
    print("Target distribution:\n", y.value_counts())

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42, stratify=y)
    print("Train shape:", X_train.shape, "Test shape:", X_test.shape)

    # Train
    best_model, grid = build_and_train(X_train, y_train)

    # Evaluate & plot confusion matrix + ROC
    acc, cm = evaluate_and_plot(best_model, X_test, y_test, features)

    # Additional plots for paper; each is best-effort so one failure does not stop the rest.
    try:
        plot_correlation_heatmap(X)
    except Exception as e:
        print("Correlation heatmap error:", e)

    try:
        plot_feature_distributions(X, y)
    except Exception as e:
        print("Feature distribution error:", e)

    try:
        plot_pairplot(X, y, max_features=6)
    except Exception as e:
        print("Pairplot error:", e)

    try:
        plot_pca(X, y)
    except Exception as e:
        print("PCA plotting error:", e)

    try:
        plot_tsne(X, y)
    except Exception as e:
        print("t-SNE/UMAP error:", e)

    try:
        plot_grid_search_results(grid)
    except Exception as e:
        print("Grid search plot error:", e)

    try:
        plot_permutation_importance(best_model, X_test, y_test)
    except Exception as e:
        print("Permutation importance error:", e)

    # Save model
    joblib.dump(best_model, MODEL_PATH)
    print(f"Saved trained model to {MODEL_PATH}")

    # Example predictions (median / 95th percentile); fall back to the column
    # max when the 95th percentile is not finite.
    median_vals = X.median().to_dict()
    high_vals = {}
    for col in X.columns:
        q95 = X[col].quantile(0.95)
        high_vals[col] = q95 if np.isfinite(q95) else X[col].max()

    def predict_pollution(sample_dict):
        sample = pd.DataFrame([sample_dict], columns=features)
        pred = best_model.predict(sample)[0]
        meaning = {0: "Good", 1: "Moderate", 2: "Poor", 3: "Very Poor", 4: "Severe"}.get(pred, str(pred))
        return {"predicted_category": int(pred), "meaning": meaning}

    print("\nExample prediction (median values):")
    print(median_vals)
    print(predict_pollution(median_vals))

    print("\nExample prediction (high pollution sample - 95th percentile):")
    print(high_vals)
    print(predict_pollution(high_vals))

# Script entry point: run the full pipeline only when executed directly, not
# when the module is imported (e.g. to reuse the plotting helpers).
if __name__ == "__main__":
    main()


Loaded dataset with shape: (5000, 10)
Automatically selected features: ['PM2.5', 'PM10', 'NO2', 'SO2', 'CO', 'Humidity', 'Temperature']
Creating Category using PM2.5 thresholds on column: PM2.5
Using label column: Category
Category
0    3956
1     713
2     215
3      68
4      48
Name: count, dtype: int64
Feature columns used: ['PM2.5', 'PM10', 'NO2', 'SO2', 'CO', 'Humidity', 'Temperature']
Target distribution:
 Category
0    3956
1     713
2     215
3      68
4      48
Name: count, dtype: int64
Train shape: (4000, 7) Test shape: (1000, 7)
Fitting 5 folds for each of 32 candidates, totalling 160 fits
Best params: {'knn__n_neighbors': 7, 'knn__p': 1, 'knn__weights': 'distance'}
Best CV accuracy: 0.96275
Test Accuracy: 0.966
Classification report:
               precision    recall  f1-score   support

           0       0.98      0.99      0.99       791
           1       0.91      0.87      0.89       143
           2       0.88      0.88      0.88        43
           3       0.92  