In [3]:
import os, sys, json, csv, pickle
from pathlib import Path
import collections
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.efficientnet import preprocess_input

# ----------------------------
# CONFIG
# ----------------------------
# Root folder holding the trained artifacts; every output below is written here too.
BASE_DIR   = Path(r"C:\Users\sagni\Downloads\MRI Scan")
# Saved Keras model, loaded with load_model() in predict_path().
MODEL_H5   = BASE_DIR / "model.h5"
# Pickled class-name -> index mapping, read by load_class_indices().
CLASS_PKL  = BASE_DIR / "class_indices.pkl"
# Input used by parse_input_path() when no valid path is given on the command line.
DEFAULT_INPUT = Path(r"C:\Users\sagni\Downloads\MRI Scan\archive\Testing")

# Size every image is resized to in load_tensor().
# NOTE(review): presumably matches the model's training input size -- confirm.
IMG_SIZE   = (256, 256)
# Default number of highest-probability classes recorded per image.
TOP_K      = 3
# Default for whether predict_path() writes an annotated copy of each image.
ANNOTATE   = True
# Destination folder for the annotated copies.
ANN_DIR    = BASE_DIR / "annotated_predictions"

# Per-image prediction rows (JSON and CSV), run summary, and class-count bar chart.
JSON_OUT   = BASE_DIR / "predictions.json"
CSV_OUT    = BASE_DIR / "predictions.csv"
SUMMARY_JSON = BASE_DIR / "summary.json"
BAR_PNG    = BASE_DIR / "class_counts.png"

# ----------------------------
# Utils
# ----------------------------
def ensure_artifacts():
    """Check that the saved model and the class-index pickle both exist.

    The model file is checked first, then the class map.

    Raises:
        FileNotFoundError: If ``MODEL_H5`` or ``CLASS_PKL`` is missing.
    """
    model_missing = not MODEL_H5.exists()
    if model_missing:
        raise FileNotFoundError(f"Missing model: {MODEL_H5}")

    if CLASS_PKL.exists():
        return
    raise FileNotFoundError(f"Missing class map: {CLASS_PKL}")

def load_class_indices(pkl_path: Path):
    """Load the pickled class mapping and invert it.

    The pickle is expected to hold a ``{class_name: index}`` dict.

    Args:
        pkl_path: Location of the pickle file.

    Returns:
        A tuple ``(idx_to_class, ordered_classes)`` where ``idx_to_class``
        maps index -> class name and ``ordered_classes`` lists the class
        names for indices ``0 .. n-1`` in order.

    Raises:
        KeyError: If the indices are not exactly ``0 .. n-1``.
    """
    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at a file you produced yourself.
    with open(pkl_path, "rb") as handle:
        name_to_index = pickle.load(handle)

    idx_to_class = {}
    for class_name, index in name_to_index.items():
        idx_to_class[index] = class_name

    ordered_classes = []
    for position in range(len(idx_to_class)):
        ordered_classes.append(idx_to_class[position])

    return idx_to_class, ordered_classes

def list_images(path: Path):
    """Collect the image files found at *path*.

    Args:
        path: Either a single file or a directory that is searched
            recursively.

    Returns:
        For a file: a one-element list if its extension (case-insensitive)
        is a known image extension, otherwise an empty list. For anything
        else: the sorted list of regular files below *path* with a known
        image extension.
    """
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

    def _is_image_file(candidate: Path) -> bool:
        # The suffix test is cheap, so run it before touching the filesystem.
        # The is_file() test keeps out directories that merely *look* like
        # images (e.g. a folder named "scans.png"), which would otherwise
        # blow up later when opened as an image.
        return candidate.suffix.lower() in exts and candidate.is_file()

    if path.is_file():
        return [path] if path.suffix.lower() in exts else []
    return sorted(p for p in path.rglob("*") if _is_image_file(p))

def load_tensor(img_path: Path):
    """Read one image and turn it into a single-item model input batch.

    The image is converted to RGB, resized to ``IMG_SIZE``, cast to
    float32, passed through ``preprocess_input`` and given a new leading
    axis.

    Args:
        img_path: Image file to read.

    Returns:
        The preprocessed array with an extra axis inserted at position 0.
    """
    rgb_image = Image.open(img_path).convert("RGB")
    resized = rgb_image.resize(IMG_SIZE)
    pixels = np.array(resized).astype(np.float32)
    prepared = preprocess_input(pixels)
    # The model is fed one image at a time, so wrap it in a batch of one.
    return np.expand_dims(prepared, axis=0)

def annotate_image(img_path: Path, text: str, out_path: Path):
    """Save a copy of an image with a text banner drawn across the top.

    A black strip 30 pixels tall is painted over the full width of the
    image and *text* is written on it in white.

    Args:
        img_path: Image to annotate (left untouched on disk).
        text: Caption to draw on the banner.
        out_path: Where the annotated copy is written; missing parent
            directories are created.
    """
    # convert() returns a new image, so the source file can be released
    # as soon as the RGB copy exists.
    with Image.open(img_path) as source:
        img = source.convert("RGB")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except (OSError, ImportError):
        # Font file not found/unreadable, or Pillow built without FreeType:
        # fall back to the built-in bitmap font. A bare ``except`` here would
        # also swallow KeyboardInterrupt/SystemExit.
        font = ImageFont.load_default()
    # banner
    draw.rectangle([(0, 0), (img.width, 30)], fill="black")
    draw.text((5, 5), text, font=font, fill="white")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(out_path)

def plot_counts_bar(class_counts, out_png: Path):
    """Write a bar chart of predictions per class to *out_png*.

    Does nothing when *class_counts* is empty.

    Args:
        class_counts: Mapping of class name -> number of predictions.
        out_png: Destination image file.
    """
    if class_counts:
        names_and_totals = tuple(zip(*class_counts.items()))
        labels = names_and_totals[0]
        values = names_and_totals[1]

        plt.figure(figsize=(6,4))
        plt.bar(labels, values)
        # Slanted tick labels so longer class names do not overlap.
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.savefig(out_png, dpi=200)
        plt.close()

def parse_input_path(argv) -> Path:
    """
    Resolve the input path from a command-line style argument list.

    Resolution order:
    1. The value following the first ``--input`` flag, if it names an
       existing path.
    2. The first argument that does not start with ``-`` and names an
       existing path.
    3. ``DEFAULT_INPUT``.

    Jupyter's ``-f <kernel.json>`` pair is removed before any lookup.
    Empty-string arguments are never accepted: ``Path("")`` is the current
    directory, which always exists and would otherwise be picked silently.

    Args:
        argv: Argument strings, without the program name.

    Returns:
        The resolved path (its existence is only guaranteed for cases 1-2).
    """
    def _existing(arg):
        """Return ``Path(arg)`` if *arg* is non-empty and exists, else None."""
        if not arg:
            return None
        candidate = Path(arg)
        try:
            return candidate if candidate.exists() else None
        except OSError:
            # e.g. permission problems while probing: treat as "not usable"
            # and let the caller fall through to the next option.
            return None

    # strip Jupyter's -f <file>
    cleaned = []
    remaining = iter(argv)
    for arg in remaining:
        if arg == "-f":
            next(remaining, None)  # discard the value that belongs to -f
            continue
        cleaned.append(arg)

    # look for --input PATH
    if "--input" in cleaned:
        value_pos = cleaned.index("--input") + 1
        if value_pos < len(cleaned):  # the flag may be the last argument
            found = _existing(cleaned[value_pos])
            if found is not None:
                return found

    # first remaining positional, if exists & valid
    for arg in cleaned:
        if arg.startswith("-"):
            continue
        found = _existing(arg)
        if found is not None:
            return found

    # fallback
    return DEFAULT_INPUT

# ----------------------------
# Predict function
# ----------------------------
def _unique_annotation_name(stem, used_names):
    """Return ``<stem>_pred.png``, adding a counter if that name is already taken in this run."""
    candidate = f"{stem}_pred.png"
    serial = 1
    # Compare case-insensitively so names also stay distinct on
    # case-insensitive filesystems.
    while candidate.casefold() in used_names:
        serial += 1
        candidate = f"{stem}_pred_{serial}.png"
    used_names.add(candidate.casefold())
    return candidate


def predict_path(input_path: Path, topk=TOP_K, annotate_images=ANNOTATE):
    """Classify every image under *input_path* and write the result files.

    For each image the model is run once, the ``topk`` most probable
    classes are recorded, and (optionally) an annotated copy is saved under
    ``ANN_DIR``. Afterwards the per-image rows are written to ``JSON_OUT``
    and ``CSV_OUT``, a class-count bar chart to ``BAR_PNG`` and a run
    summary to ``SUMMARY_JSON``. Progress is printed to stdout.

    Annotated copies are named ``<stem>_pred.png``. When two inputs share a
    stem (same file name in different subfolders, or ``a.jpg`` and
    ``a.png``), later ones get ``<stem>_pred_2.png``, ``_3`` ... instead of
    overwriting the earlier copy.

    Args:
        input_path: Image file or directory of images.
        topk: Number of highest-probability classes stored per image.
        annotate_images: Whether to write annotated copies.

    Raises:
        FileNotFoundError: If the model/class map is missing, *input_path*
            does not exist, or it contains no images.
    """
    ensure_artifacts()
    if not input_path.exists():
        raise FileNotFoundError(f"Input not found: {input_path}")

    print("[INFO] Loading model and classes…")
    model = load_model(str(MODEL_H5))
    idx_to_class, _ = load_class_indices(CLASS_PKL)

    files = list_images(input_path)
    if not files:
        raise FileNotFoundError(f"No images found in {input_path}")
    print(f"[INFO] Found {len(files)} image(s) in: {input_path}")

    results = []
    class_counts = collections.Counter()
    used_annotation_names = set()

    for i, img_path in enumerate(files, 1):
        arr = load_tensor(img_path)
        probs = model.predict(arr, verbose=0)[0]
        # Indices of the topk largest probabilities, best first.
        top_idx = np.argsort(probs)[::-1][:topk]
        top_classes = [idx_to_class[int(t)] for t in top_idx]
        top_scores = [float(probs[int(t)]) for t in top_idx]
        pred_class, pred_conf = top_classes[0], top_scores[0]

        class_counts[pred_class] += 1
        row = {"file": str(img_path), "pred_class": pred_class, "confidence": round(pred_conf, 4)}
        for j, (c, s) in enumerate(zip(top_classes, top_scores), 1):
            row[f"top{j}_class"] = c
            row[f"top{j}_p"] = round(s, 4)
        results.append(row)

        if annotate_images:
            out_img = ANN_DIR / _unique_annotation_name(img_path.stem, used_annotation_names)
            annotate_image(img_path, f"{pred_class} ({pred_conf*100:.1f}%)", out_img)

        print(f"[{i}/{len(files)}] {img_path.name} → {pred_class} ({pred_conf*100:.2f}%)")

    # Save outputs
    with open(JSON_OUT, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    with open(CSV_OUT, "w", newline="", encoding="utf-8") as f:
        # Every row carries the same keys, so the first row defines the header.
        writer = csv.DictWriter(f, fieldnames=results[0].keys())
        writer.writeheader()
        writer.writerows(results)
    plot_counts_bar(class_counts, BAR_PNG)

    summary = {"input": str(input_path), "num_images": len(results), "class_counts": dict(class_counts)}
    with open(SUMMARY_JSON, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print(f"[INFO] Saved predictions → {CSV_OUT} & {JSON_OUT}")
    print(f"[INFO] Saved summary → {SUMMARY_JSON}")
    if annotate_images:
        print(f"[INFO] Annotated images → {ANN_DIR}")

# ----------------------------
# Entrypoint
# ----------------------------
if __name__ == "__main__":
    # Everything after the program name is handed to the parser, which copes
    # with both a plain command line and the arguments Jupyter injects.
    cli_args = sys.argv[1:]
    predict_path(parse_input_path(cli_args), topk=TOP_K, annotate_images=ANNOTATE)


[INFO] Loading model and classes…




[INFO] Found 1311 image(s) in: C:\Users\sagni\Downloads\MRI Scan\archive\Testing
[1/1311] Te-gl_0010.jpg → glioma (98.40%)
[2/1311] Te-gl_0011.jpg → glioma (96.95%)
[3/1311] Te-gl_0012.jpg → glioma (98.99%)
[4/1311] Te-gl_0013.jpg → glioma (96.31%)
[5/1311] Te-gl_0014.jpg → glioma (99.32%)
[6/1311] Te-gl_0015.jpg → glioma (90.42%)
[7/1311] Te-gl_0016.jpg → glioma (98.16%)
[8/1311] Te-gl_0017.jpg → glioma (96.79%)
[9/1311] Te-gl_0018.jpg → glioma (99.19%)
[10/1311] Te-gl_0019.jpg → glioma (65.99%)
[11/1311] Te-gl_0020.jpg → glioma (94.08%)
[12/1311] Te-gl_0021.jpg → glioma (99.21%)
[13/1311] Te-gl_0022.jpg → glioma (99.04%)
[14/1311] Te-gl_0023.jpg → glioma (96.76%)
[15/1311] Te-gl_0024.jpg → glioma (93.80%)
[16/1311] Te-gl_0025.jpg → glioma (99.42%)
[17/1311] Te-gl_0026.jpg → glioma (99.36%)
[18/1311] Te-gl_0027.jpg → glioma (99.57%)
[19/1311] Te-gl_0028.jpg → glioma (87.26%)
[20/1311] Te-gl_0029.jpg → glioma (93.80%)
[21/1311] Te-gl_0030.jpg → glioma (77.60%)
[22/1311] Te-gl_0031.jpg 