<a href="https://colab.research.google.com/github/sandeeepmedepalli/ml-colony-classification/blob/main/ML_research.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os, random, math
import numpy as np
import pandas as pd

from google.colab import drive
drive.mount('/content/drive')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
import torchvision.models as models
from sklearn.model_selection import train_test_split


Mounted at /content/drive


In [2]:
# Drive folder that holds both the images and the annotation CSV (flat layout: images sit directly in it)
BASE_DIR  = "/content/drive/MyDrive/22022540"

IMAGES_DIR = BASE_DIR                     # images are read from BASE_DIR itself, not from a sub-folder
CSV_PATH   = os.path.join(BASE_DIR, "annot_tab.csv")   # change the file name here if your CSV is named differently

# Output workspace (temporary: lives inside the Colab runtime)
WORKDIR = "/content/colony_stage2"
os.makedirs(WORKDIR, exist_ok=True)

# Patch settings
PATCH_SIZE = 100       # every cropped patch is resized to PATCH_SIZE x PATCH_SIZE pixels
PAD_TO_SQUARE = True   # pad crops to a square before resizing

# Reproducibility: seed the Python, NumPy and PyTorch random number generators
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

print("CSV_PATH:  ", CSV_PATH)
print("IMAGES_DIR:", IMAGES_DIR)


Device: cuda
CSV_PATH:   /content/drive/MyDrive/22022540/annot_tab.csv
IMAGES_DIR: /content/drive/MyDrive/22022540


In [3]:
# Copy the files from Drive to local Colab storage and read images from the local copy afterwards
import shutil, os

LOCAL_IMG_DIR = "/content/colony_images"
os.makedirs(LOCAL_IMG_DIR, exist_ok=True)

# Copy each file only once: anything already present in LOCAL_IMG_DIR is skipped.
# NOTE(review): the existence check cannot tell a complete copy from a partial one left behind by an
# interrupted run, and every regular file in the folder is copied (not only images).
for f in os.listdir(IMAGES_DIR):
    src = os.path.join(IMAGES_DIR, f)
    dst = os.path.join(LOCAL_IMG_DIR, f)
    if os.path.isfile(src) and not os.path.exists(dst): # regular file that has not been copied yet
        shutil.copy2(src, dst)

# NOTE(review): this rebinds the IMAGES_DIR set in the configuration cell, so re-running this cell
# on its own lists the local folder instead of the Drive folder.
IMAGES_DIR = LOCAL_IMG_DIR   # from here on, images are read from local Colab storage
print("Now using local images dir:", IMAGES_DIR)


Now using local images dir: /content/colony_images


In [4]:
# Columns that must be present in the annotation CSV
required_cols = [
    "label_name","bbox_x","bbox_y","bbox_width","bbox_height",
    "image_name","image_width","image_height"
]

df = pd.read_csv(CSV_PATH, encoding="utf-8-sig")

# Fail early, listing what was actually found, if the CSV schema is not the expected one
missing = [col for col in required_cols if col not in df.columns]
if missing:
    raise ValueError(f"CSV missing columns: {missing}\nFound: {list(df.columns)}")

# Coerce the geometry columns to numbers; entries that cannot be parsed become NaN
numeric_cols = ["bbox_x","bbox_y","bbox_width","bbox_height","image_width","image_height"]
df = df.assign(**{col: pd.to_numeric(df[col], errors="coerce") for col in numeric_cols})

# Discard rows with a missing required value, then boxes with non-positive width or height
df = df.dropna(subset=required_cols)
has_area = df["bbox_width"].gt(0) & df["bbox_height"].gt(0)
df = df.loc[has_area].copy()

print("Rows (boxes):", len(df))
print("Unique classes:", df["label_name"].nunique())
print(df["label_name"].value_counts().head(24))


Rows (boxes): 56862
Unique classes: 24
label_name
sp21    11160
sp23     7067
sp22     6814
sp06     5513
sp10     4364
sp05     4102
sp19     2782
sp13     1799
sp09     1775
sp02     1530
sp18     1383
sp16     1348
sp14     1102
sp07     1087
sp15      866
sp20      853
sp24      787
sp11      481
sp12      461
sp01      397
sp08      368
sp04      304
sp03      295
sp17      224
Name: count, dtype: int64


In [5]:
# 10-class list (pretraining task)
CLASSES_10 = ["sp02","sp05","sp06","sp07","sp10","sp14","sp16","sp19","sp21","sp23"]

# 4-class list (fine-tuning task):
# - 2 classes that overlap with CLASSES_10 and have >=2000 boxes (so a "fresh 1000" can be drawn later)
# - 2 new classes NOT in CLASSES_10, each with >=1000 boxes
CLASSES_4_OVERLAP = ["sp05", "sp10"]     # both have >2000 boxes in your CSV
CLASSES_4_NEW     = ["sp22", "sp13"]     # both >1000 and NOT in CLASSES_10
CLASSES_4 = CLASSES_4_OVERLAP + CLASSES_4_NEW

# Number of boxes per class in the cleaned annotation table
counts = df["label_name"].value_counts()

def show_counts(class_list, name):
    """Print the number of boxes available for each class in `class_list` under the heading `name`."""
    print(f"\n{name} counts:")
    for c in class_list:
        print(f"  {c}: {int(counts.get(c, 0))}")

show_counts(CLASSES_10, "10-class")
show_counts(CLASSES_4,  "4-class")

# Structural checks on the class lists. The sampling step only excludes already-used boxes for
# the classes listed in CLASSES_4_OVERLAP, so a class placed in the wrong list would silently
# reuse pretraining boxes during fine-tuning.
not_in_10 = [c for c in CLASSES_4_OVERLAP if c not in CLASSES_10]
if not_in_10:
    raise ValueError(f"Overlap classes must also be in CLASSES_10: {not_in_10}")

already_in_10 = [c for c in CLASSES_4_NEW if c in CLASSES_10]
if already_in_10:
    raise ValueError(f"NEW classes must not be in CLASSES_10: {already_in_10}")

for class_list, list_name in [(CLASSES_10, "CLASSES_10"), (CLASSES_4, "CLASSES_4")]:
    if len(set(class_list)) != len(class_list):
        raise ValueError(f"{list_name} contains duplicate class names: {class_list}")

# Count checks: every task samples 1000 boxes per class
for c in CLASSES_10:
    if counts.get(c, 0) < 1000:
        raise ValueError(f"{c} has < 1000 boxes. Pick another class.")

for c in CLASSES_4_NEW:
    if counts.get(c, 0) < 1000:
        raise ValueError(f"{c} has < 1000 boxes. Pick another NEW class.")

# Overlap classes need 1000 boxes for pretraining plus 1000 different ones for fine-tuning
for c in CLASSES_4_OVERLAP:
    if counts.get(c, 0) < 2000:
        raise ValueError(f"{c} overlap class must have >=2000 boxes to allow 'fresh 1000'.")



10-class counts:
  sp02: 1530
  sp05: 4102
  sp06: 5513
  sp07: 1087
  sp10: 4364
  sp14: 1102
  sp16: 1348
  sp19: 2782
  sp21: 11160
  sp23: 7067

4-class counts:
  sp05: 4102
  sp10: 4364
  sp22: 6814
  sp13: 1799


In [6]:
def sample_rows_for_class(df, cls, n, seed, exclude_index_set=None):
    """
    Draw `n` rows of class `cls` from `df` without replacement.

    Rows whose index label is in `exclude_index_set` (when given) are removed before sampling.
    `seed` is passed to pandas as `random_state`, so the draw is repeatable.
    Raises ValueError when fewer than `n` candidate rows remain.
    """
    candidates = df.loc[df["label_name"] == cls]
    if exclude_index_set is not None:
        keep = ~candidates.index.isin(exclude_index_set)
        candidates = candidates[keep]

    available = len(candidates)
    if available < n:
        raise ValueError(f"Not enough rows for {cls}: need {n}, have {available} after exclusions.")

    return candidates.sample(n=n, random_state=seed)

# 10-class pretraining sample: 1000 boxes per class, each class drawn with its own seed
pretrain_parts = [
    sample_rows_for_class(df, cls, n=1000, seed=SEED+i)
    for i, cls in enumerate(CLASSES_10)
]

# Index labels (in df) of every box used for pretraining
used_idx = set()
for part in pretrain_parts:
    used_idx.update(part.index.tolist())

df_pretrain = pd.concat(pretrain_parts).reset_index(drop=True)
print("Pretrain rows:", len(df_pretrain), " expected:", 10*1000)

# 4-class fine-tuning sample: 1000 boxes per class.
# Overlap classes must not reuse any box already drawn for pretraining.
finetune_parts = []
for j, cls in enumerate(CLASSES_4):
    if cls in CLASSES_4_OVERLAP:
        exclude = used_idx
    else:
        exclude = None
    finetune_parts.append(
        sample_rows_for_class(df, cls, n=1000, seed=SEED+100+j, exclude_index_set=exclude)
    )

df_finetune = pd.concat(finetune_parts).reset_index(drop=True)
print("Fine-tune rows:", len(df_finetune), " expected:", 4*1000)


Pretrain rows: 10000  expected: 10000
Fine-tune rows: 4000  expected: 4000


In [7]:
# Cropping function where we convert bbox to patch (with optional context expansion)
def crop_patch(
    img: "Image.Image",
    x, y, w, h,
    pad_to_square=True,
    expand_factor=0.20,    # e.g., 0.20 means 20% of bbox size added on each side
    expand_px=0           # optional fixed extra pixels on each side
):
    """
    Crop the bounding box (x, y, w, h) out of `img`, with extra context around it.

    The box is grown on every side by `expand_px + expand_factor * size` (width for the
    horizontal sides, height for the vertical ones), rounded to whole pixels and clipped to
    the image. The crop is always at least 1x1 pixel, even for a box that starts at or
    beyond the right/bottom edge of the image.

    Returns:
        If `pad_to_square` is False, the crop as returned by `img.crop` (same mode as `img`).
        Otherwise a grayscale ("L") square image whose side is the longer crop side, with the
        crop centred on a black canvas.
    """
    # Expand bbox to include context
    dx = expand_px + expand_factor * float(w)
    dy = expand_px + expand_factor * float(h)

    x1 = int(round(x - dx))
    y1 = int(round(y - dy))
    x2 = int(round(x + w + dx))
    y2 = int(round(y + h + dy))

    # Clip to image boundaries. The top-left corner is kept strictly inside the image so that
    # at least one pixel remains to its right and below it.
    x1 = max(0, min(x1, img.width - 1))
    y1 = max(0, min(y1, img.height - 1))
    x2 = min(img.width, x2)
    y2 = min(img.height, y2)

    # Safety: avoid empty crops (x1/y1 were clamped above, so x1 + 1 / y1 + 1 stay in bounds)
    if x2 <= x1:
        x2 = x1 + 1
    if y2 <= y1:
        y2 = y1 + 1

    patch = img.crop((x1, y1, x2, y2))

    if not pad_to_square:
        return patch

    # Pad to square so the later resize does not distort the aspect ratio
    side = max(patch.width, patch.height)
    new_img = Image.new("L", (side, side), color=0)  # black grayscale canvas
    px = (side - patch.width) // 2
    py = (side - patch.height) // 2
    patch = patch.convert("L")
    new_img.paste(patch, (px, py))
    return new_img


In [8]:
class ColonyPatchDatasetCached(Dataset):
    """
    Dataset of colony patches cropped from full images, one item per annotation row.

    Each patch is cropped, resized to PATCH_SIZE x PATCH_SIZE and replicated to three
    channels the first time it is requested, then kept in an in-memory cache. `transform`
    is applied on every access, so random augmentations still differ between epochs.

    Args:
        df_rows: annotation rows; reads the columns image_name, bbox_x, bbox_y,
            bbox_width, bbox_height and label_name.
        images_dir: folder containing the image files (only the basename of image_name is used).
        class_to_idx: mapping from label name to integer class index.
        transform: optional callable applied to the cached PIL patch before it is returned.
    """

    def __init__(self, df_rows, images_dir, class_to_idx, transform=None):
        self.df = df_rows.reset_index(drop=True)
        self.images_dir = images_dir
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.cache = {}  # row position -> un-transformed PIL RGB patch

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (patch, class_index) for row `idx`, building and caching the patch on first use."""
        patch_rgb = self.cache.get(idx)
        if patch_rgb is None:
            r = self.df.iloc[idx]
            img_path = os.path.join(self.images_dir, os.path.basename(r["image_name"]))

            # Close the file handle as soon as the pixels are decoded; convert() returns a
            # new in-memory grayscale image that no longer needs the file.
            with Image.open(img_path) as src:
                img = src.convert("L")

            patch = crop_patch(   # crop the box, then resize to PATCH_SIZE x PATCH_SIZE
                img, r["bbox_x"], r["bbox_y"], r["bbox_width"], r["bbox_height"],
                pad_to_square=PAD_TO_SQUARE
            ).resize((PATCH_SIZE, PATCH_SIZE), resample=Image.BILINEAR)

            # Replicate the single gray channel to three channels
            patch_rgb = Image.merge("RGB", (patch, patch, patch))
            self.cache[idx] = patch_rgb

        y = self.class_to_idx[str(self.df.iloc[idx]["label_name"])]

        if self.transform:
            patch_rgb = self.transform(patch_rgb)

        return patch_rgb, y


In [9]:
# ImageNet normalization stats (because we use an ImageNet-pretrained backbone)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Training pipeline: random flips and a small random rotation (up to 10 degrees),
# then tensor conversion and per-channel normalization
train_tf = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(degrees=10),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same normalization
test_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


In [10]:
def accuracy(pred_logits, y):
    """Return, as a Python float, the fraction of rows whose arg-max logit equals the target in `y`."""
    hits = pred_logits.argmax(dim=1).eq(y)
    return hits.float().mean().item()

def train_one_epoch(model, loader, optim, criterion):
    """
    Run one optimisation pass over `loader` with the model in training mode.

    Returns (mean loss, mean accuracy), both weighted by batch size so that a smaller
    final batch does not skew the averages.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        # Standard step: clear old gradients, forward, backward, update
        optim.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optim.step()

        batch_size = inputs.size(0)
        loss_sum += batch_loss.item() * batch_size
        acc_sum += accuracy(outputs, targets) * batch_size
        seen += batch_size

    return loss_sum / seen, acc_sum / seen

@torch.no_grad()  # evaluation needs no gradients, so tracking is switched off
def eval_one_epoch(model, loader, criterion):
    """
    Evaluate `model` over `loader` in eval mode, without updating any weights.

    Returns (mean loss, mean accuracy), both weighted by batch size.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)

        batch_size = inputs.size(0)
        loss_sum += batch_loss.item() * batch_size
        acc_sum += accuracy(outputs, targets) * batch_size
        seen += batch_size

    return loss_sum / seen, acc_sum / seen


In [11]:
# Map each 10-class label to an integer index (its position in CLASSES_10)
class_to_idx_10 = {c:i for i,c in enumerate(CLASSES_10)}

# Stratified 80/20 split of the pretraining sample, so every class keeps the same proportion
train_df_10, test_df_10 = train_test_split(
    df_pretrain,
    test_size=0.2,
    random_state=SEED,
    stratify=df_pretrain["label_name"]
)

# Augmentation is applied to the training set only
ds_train_10 = ColonyPatchDatasetCached(train_df_10, IMAGES_DIR, class_to_idx_10, transform=train_tf)
ds_test_10  = ColonyPatchDatasetCached(test_df_10,  IMAGES_DIR, class_to_idx_10, transform=test_tf)

# num_workers=0 loads data in the main process, so each dataset's in-memory patch cache
# is filled once and reused on later epochs
train_loader_10 = DataLoader(ds_train_10, batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
test_loader_10  = DataLoader(ds_test_10,  batch_size=128, shuffle=False, num_workers=0, pin_memory=True)


print("10-class train:", len(ds_train_10), "test:", len(ds_test_10))


10-class train: 8000 test: 2000


In [12]:
# Use a backbone pretrained on ImageNet.
# ResNet-18 is a small architecture suited to efficient transfer learning.
# Load the official ResNet-18 weights trained on ImageNet-1K (1000 classes), version 1.
weights = models.ResNet18_Weights.IMAGENET1K_V1
model_10 = models.resnet18(weights=weights)

# Replace the ResNet classifier head: 1000 ImageNet outputs -> one output per class in CLASSES_10
in_features = model_10.fc.in_features
model_10.fc = nn.Linear(in_features, len(CLASSES_10))

model_10 = model_10.to(DEVICE)

# All parameters (backbone and new head) are trained with a single learning rate
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_10.parameters(), lr=1e-4)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 184MB/s]


In [13]:
EPOCHS_10 = 10  # start small; you can increase later

# Train for EPOCHS_10 epochs and keep the checkpoint of the epoch with the highest test accuracy.
# NOTE(review): the checkpoint is chosen on the same split that is reported as "test", so the
# best accuracy printed below is an optimistic estimate; a separate validation split would avoid this.
best_acc = 0.0
for epoch in range(1, EPOCHS_10+1):
    tr_loss, tr_acc = train_one_epoch(model_10, train_loader_10, optimizer, criterion)
    te_loss, te_acc = eval_one_epoch(model_10, test_loader_10, criterion)

    print(f"[10-class] Epoch {epoch:02d} | train acc {tr_acc:.3f} | test acc {te_acc:.3f}")

    # Save the weights whenever the test accuracy improves (overwrites the previous best)
    if te_acc > best_acc:
        best_acc = te_acc
        torch.save(model_10.state_dict(), os.path.join(WORKDIR, "resnet18_pretrained_10class.pt"))

print("Best 10-class test acc:", best_acc)


[10-class] Epoch 01 | train acc 0.666 | test acc 0.816
[10-class] Epoch 02 | train acc 0.865 | test acc 0.878
[10-class] Epoch 03 | train acc 0.901 | test acc 0.881
[10-class] Epoch 04 | train acc 0.914 | test acc 0.903
[10-class] Epoch 05 | train acc 0.925 | test acc 0.904
[10-class] Epoch 06 | train acc 0.936 | test acc 0.911
[10-class] Epoch 07 | train acc 0.943 | test acc 0.918
[10-class] Epoch 08 | train acc 0.942 | test acc 0.922
[10-class] Epoch 09 | train acc 0.956 | test acc 0.911
[10-class] Epoch 10 | train acc 0.956 | test acc 0.894
Best 10-class test acc: 0.9215000014305115


In [14]:
import os
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import torchvision.models as models
import torch.nn as nn

@torch.no_grad()
def evaluate_model(model, dataloader, class_names, device, title=""):
    """
    Run `model` over `dataloader` and print a confusion matrix and classification report.

    Args:
        model: classifier returning one logit per class.
        dataloader: yields (inputs, targets) batches; targets are integer class indices.
        class_names: display names, where class_names[i] is the name of class index i.
        device: device the inputs are moved to before the forward pass.
        title: heading printed above the results.

    Returns:
        (cm, rep): the confusion matrix (rows = true class, columns = predicted class) and
        the text classification report. Both always cover every entry of `class_names`,
        including classes that do not occur in the evaluated data.
    """
    model.eval()

    all_preds, all_targets = [], []

    for x, y in dataloader:
        logits = model(x.to(device))
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(y.cpu().numpy())

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)

    print("\n" + "="*70)
    print(title)
    print("="*70)

    # Pass the full label set explicitly. Without it scikit-learn infers the labels from the
    # data, so a class absent from both y_true and y_pred would shrink the matrix and make
    # the number of target_names disagree with the number of labels.
    labels = list(range(len(class_names)))

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    print("Confusion Matrix:")
    print(cm)
    print()

    rep = classification_report(y_true, y_pred, labels=labels, target_names=class_names, digits=3)
    print("Classification Report:")
    print(rep)

    return cm, rep

In [15]:
# =========================================================
# 1) Evaluate BEST 10-class PUBLIC model (ResNet18)
# =========================================================
path_10 = os.path.join(WORKDIR, "resnet18_pretrained_10class.pt")  # checkpoint written by the 10-class training loop

# Rebuild the architecture with a 10-output head (no pretrained weights needed),
# then load the saved best weights into it
model_10_eval = models.resnet18(weights=None)
model_10_eval.fc = nn.Linear(model_10_eval.fc.in_features, len(CLASSES_10))
model_10_eval.load_state_dict(torch.load(path_10, map_location=DEVICE))
model_10_eval = model_10_eval.to(DEVICE)

# Confusion matrix and per-class report on the 10-class test split
cm_10, report_10 = evaluate_model(
    model_10_eval,
    test_loader_10,
    CLASSES_10,
    DEVICE,
    title="Evaluation: 10-Class PUBLIC (Best Saved Model)"
)



Evaluation: 10-Class PUBLIC (Best Saved Model)
Confusion Matrix:
[[192   2   0   0   0   0   1   0   4   1]
 [  3 191   0   2   0   2   0   0   1   1]
 [  0   0 176   1   2  10   0   3   3   5]
 [  0   0   4 185   2   1   0   2   1   5]
 [  0   0   8   2 177   3   0   2   0   8]
 [  0   0   5   0   0 192   0   3   0   0]
 [  0   0   2   0   1   1 179   8   6   3]
 [  1   0   1   0   0   2   0 195   1   0]
 [  1   0   1   1   0   0   0   0 196   1]
 [  3   2  23   3   3   2   4   0   0 160]]

Classification Report:
              precision    recall  f1-score   support

        sp02      0.960     0.960     0.960       200
        sp05      0.979     0.955     0.967       200
        sp06      0.800     0.880     0.838       200
        sp07      0.954     0.925     0.939       200
        sp10      0.957     0.885     0.919       200
        sp14      0.901     0.960     0.930       200
        sp16      0.973     0.895     0.932       200
        sp19      0.915     0.975     0.944   

In [16]:
# Map each 4-class label to an integer index (its position in CLASSES_4)
class_to_idx_4 = dict(zip(CLASSES_4, range(len(CLASSES_4))))

# Shuffle each class's rows with a fixed seed, then take the first 800 for training
# and the next 200 for testing
shuffled_by_class = [
    df_finetune.loc[df_finetune["label_name"] == cls].sample(frac=1.0, random_state=SEED)
    for cls in CLASSES_4
]
train_parts = [rows.iloc[:800] for rows in shuffled_by_class]
test_parts = [rows.iloc[800:1000] for rows in shuffled_by_class]

train_df_4 = pd.concat(train_parts).reset_index(drop=True)
test_df_4  = pd.concat(test_parts).reset_index(drop=True)

print("4-class train:", len(train_df_4), " expected:", 4*800)
print("4-class test: ", len(test_df_4),  " expected:", 4*200)
print(train_df_4["label_name"].value_counts())


4-class train: 3200  expected: 3200
4-class test:  800  expected: 800
label_name
sp05    800
sp10    800
sp22    800
sp13    800
Name: count, dtype: int64


In [17]:
# 4-class datasets: augmentation on the training set only
ds_train_4 = ColonyPatchDatasetCached(train_df_4, IMAGES_DIR, class_to_idx_4, transform=train_tf)
ds_test_4  = ColonyPatchDatasetCached(test_df_4,  IMAGES_DIR, class_to_idx_4, transform=test_tf)

# Same loader settings as the 10-class task; only the training loader shuffles
train_loader_4 = DataLoader(ds_train_4, batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
test_loader_4  = DataLoader(ds_test_4,  batch_size=128, shuffle=False, num_workers=0, pin_memory=True)


In [18]:
# Start from the best 10-class checkpoint
model_ft = models.resnet18(weights=None)    # fresh ResNet-18 without pretrained weights
in_features = model_ft.fc.in_features   # input width of the final layer, needed to rebuild the head
model_ft.fc = nn.Linear(in_features, len(CLASSES_10))  # temporary 10-output head so the checkpoint shapes match
model_ft.load_state_dict(torch.load(os.path.join(WORKDIR, "resnet18_pretrained_10class.pt"), map_location="cpu"))

# Freeze everything (feature extractor)
# NOTE(review): requires_grad=False only stops gradient updates. train_one_epoch puts the whole
# model in train mode, so BatchNorm running statistics in the backbone presumably still change
# during head-only training; confirm this is intended.
for p in model_ft.parameters():
    p.requires_grad = False

# Replace final layer with a NEW, randomly initialised 4-class head
model_ft.fc = nn.Linear(in_features, len(CLASSES_4))

# Only the head will train
for p in model_ft.fc.parameters():
    p.requires_grad = True

model_ft = model_ft.to(DEVICE)

criterion_ft = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.Adam(model_ft.fc.parameters(), lr=1e-3)  # a bit higher since only head trains


In [19]:
EPOCHS_4 = 10

# Head-only training: only the parameters of model_ft.fc are given to optimizer_ft.
# The checkpoint with the highest test accuracy is kept.
# NOTE(review): the checkpoint is selected on the same split reported as "test".
best_acc_4 = 0.0
for epoch in range(1, EPOCHS_4+1):
    tr_loss, tr_acc = train_one_epoch(model_ft, train_loader_4, optimizer_ft, criterion_ft)
    te_loss, te_acc = eval_one_epoch(model_ft, test_loader_4, criterion_ft)

    print(f"[4-class FT] Epoch {epoch:02d} | train acc {tr_acc:.3f} | test acc {te_acc:.3f}")

    # Save the weights whenever the test accuracy improves (overwrites the previous best)
    if te_acc > best_acc_4:
        best_acc_4 = te_acc
        torch.save(model_ft.state_dict(), os.path.join(WORKDIR, "resnet18_finetuned_4class.pt"))

print("Best 4-class test acc:", best_acc_4)
print("4 classes:", CLASSES_4)


[4-class FT] Epoch 01 | train acc 0.748 | test acc 0.895
[4-class FT] Epoch 02 | train acc 0.907 | test acc 0.917
[4-class FT] Epoch 03 | train acc 0.919 | test acc 0.930
[4-class FT] Epoch 04 | train acc 0.922 | test acc 0.934
[4-class FT] Epoch 05 | train acc 0.931 | test acc 0.936
[4-class FT] Epoch 06 | train acc 0.931 | test acc 0.936
[4-class FT] Epoch 07 | train acc 0.938 | test acc 0.941
[4-class FT] Epoch 08 | train acc 0.941 | test acc 0.940
[4-class FT] Epoch 09 | train acc 0.938 | test acc 0.945
[4-class FT] Epoch 10 | train acc 0.941 | test acc 0.940
Best 4-class test acc: 0.945
4 classes: ['sp05', 'sp10', 'sp22', 'sp13']


In [20]:
# ============================================================
# Full fine-tuning (unfreeze backbone + keep head)
# ============================================================
# NOTE(review): training continues from model_ft as left in memory by the head-only loop
# (its last epoch); the best head-only checkpoint on disk is not reloaded here.

import os
import torch

# Unfreeze ALL parameters (backbone + head)
for p in model_ft.parameters():
    p.requires_grad = True

# Build optimizer for full fine-tuning
# smaller LR for backbone, bigger LR for head
backbone_params = []
head_params = []
for name, p in model_ft.named_parameters():
    # Skip frozen parameters (none remain frozen after the loop above)
    if not p.requires_grad:
        continue
    # Parameters of the final layer are named "fc.*"; everything else is backbone
    if name.startswith("fc."):
        head_params.append(p)
    else:
        backbone_params.append(p)

optimizer_full = torch.optim.Adam([
    {"params": backbone_params, "lr": 1e-5},  # backbone learns slowly
    {"params": head_params,     "lr": 1e-4},  # head learns faster
])

# Continue training and track train/test accuracy
EPOCHS_FULL = 10
best_acc_full = 0.0
best_path = os.path.join(WORKDIR, "resnet18_finetuned_4class_full.pt")

for epoch in range(1, EPOCHS_FULL + 1):
    tr_loss, tr_acc = train_one_epoch(model_ft, train_loader_4, optimizer_full, criterion_ft)
    te_loss, te_acc = eval_one_epoch(model_ft, test_loader_4, criterion_ft)

    print(f"[4-class FULL FT] Epoch {epoch:02d} | train acc {tr_acc:.3f} | test acc {te_acc:.3f}")

    # Save best model by test accuracy
    # NOTE(review): selection uses the same split that is reported as "test"
    if te_acc > best_acc_full:
        best_acc_full = te_acc
        torch.save(model_ft.state_dict(), best_path)

print("Best 4-class test acc (full fine-tune):", best_acc_full)
print("Saved best full-finetuned model to:", best_path)


[4-class FULL FT] Epoch 01 | train acc 0.950 | test acc 0.943
[4-class FULL FT] Epoch 02 | train acc 0.961 | test acc 0.951
[4-class FULL FT] Epoch 03 | train acc 0.970 | test acc 0.951
[4-class FULL FT] Epoch 04 | train acc 0.971 | test acc 0.956
[4-class FULL FT] Epoch 05 | train acc 0.975 | test acc 0.959
[4-class FULL FT] Epoch 06 | train acc 0.973 | test acc 0.958
[4-class FULL FT] Epoch 07 | train acc 0.978 | test acc 0.958
[4-class FULL FT] Epoch 08 | train acc 0.978 | test acc 0.959
[4-class FULL FT] Epoch 09 | train acc 0.983 | test acc 0.963
[4-class FULL FT] Epoch 10 | train acc 0.981 | test acc 0.963
Best 4-class test acc (full fine-tune): 0.9625
Saved best full-finetuned model to: /content/colony_stage2/resnet18_finetuned_4class_full.pt


In [21]:
# =========================================================
# 2) Evaluate BEST 4-class FULL fine-tuned model (ResNet18)
# =========================================================
path_4 = os.path.join(WORKDIR, "resnet18_finetuned_4class_full.pt")  # checkpoint written by the full fine-tuning loop

model_4_eval = models.resnet18(weights=None)

# IMPORTANT:
# the checkpoint was saved from a model with a 4-class head (after fine-tuning),
# so the ResNet is rebuilt with 4 outputs before the weights are loaded.
model_4_eval.fc = nn.Linear(model_4_eval.fc.in_features, len(CLASSES_4))

model_4_eval.load_state_dict(torch.load(path_4, map_location=DEVICE))
model_4_eval = model_4_eval.to(DEVICE)

# Confusion matrix and per-class report on the 4-class test split
cm_4, report_4 = evaluate_model(
    model_4_eval,
    test_loader_4,
    CLASSES_4,
    DEVICE,
    title="Evaluation: 4-Class PUBLIC Fine-Tuned (Best Saved Model)"
)


Evaluation: 4-Class PUBLIC Fine-Tuned (Best Saved Model)
Confusion Matrix:
[[197   1   0   2]
 [  2 185  13   0]
 [  0   3 196   1]
 [  3   0   5 192]]

Classification Report:
              precision    recall  f1-score   support

        sp05      0.975     0.985     0.980       200
        sp10      0.979     0.925     0.951       200
        sp22      0.916     0.980     0.947       200
        sp13      0.985     0.960     0.972       200

    accuracy                          0.963       800
   macro avg      0.964     0.963     0.963       800
weighted avg      0.964     0.963     0.963       800



In [23]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# ============================================================
# 1) Confusion Matrix PNG (Blue/White like PDF)
# ============================================================
def save_cm_png(cm, class_names, out_path, title="Confusion Matrix", normalize=False, show_colorbar=True):
    """
    Save a paper-style confusion-matrix figure as a PNG.

    Style: blue/white colormap ("Blues"), large fonts, thick black grid lines, square cells.

    Args:
        cm: square confusion matrix (rows = true class, columns = predicted class).
        class_names: tick labels, one per row/column of `cm`.
        out_path: destination file. Its parent folder is created when the path has one;
            a bare file name is written to the current working directory.
        title: figure title.
        normalize: if True, colour cells by row-normalised rate and annotate each cell with
            the percentage plus the raw count; otherwise show raw counts.
        show_colorbar: whether to draw a colour bar next to the matrix.
    """
    cm = np.array(cm, dtype=float)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True) + 1e-9  # small epsilon avoids division by zero for empty rows
        cm_show = cm / row_sums
        vmin, vmax = 0.0, 1.0
    else:
        cm_show = cm
        vmin, vmax = 0.0, float(cm_show.max()) if cm_show.size else 1.0

    fig, ax = plt.subplots(figsize=(11, 9))
    try:
        ax.set_facecolor("white")

        im = ax.imshow(cm_show, interpolation="nearest", cmap="Blues", vmin=vmin, vmax=vmax)

        # Title and axis labels
        ax.set_title(title, fontsize=18, fontweight="bold", pad=16)
        ax.set_xlabel("Predicted label", fontsize=16, fontweight="bold", labelpad=10)
        ax.set_ylabel("True label", fontsize=16, fontweight="bold", labelpad=10)

        # One major tick per class
        ticks = np.arange(len(class_names))
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=14)
        ax.set_yticklabels(class_names, fontsize=14)

        if show_colorbar:
            cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(labelsize=12)

        # Minor ticks sit on the cell borders; the grid is drawn on them to outline each cell
        ax.set_xticks(np.arange(-.5, len(class_names), 1), minor=True)
        ax.set_yticks(np.arange(-.5, len(class_names), 1), minor=True)
        ax.grid(which="minor", color="black", linestyle="-", linewidth=1.5)
        ax.tick_params(which="minor", bottom=False, left=False)

        # Cells darker than this threshold get white text so the number stays readable
        thresh = (vmax - vmin) * 0.6 + vmin

        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                if normalize:
                    txt = f"{cm_show[i, j]*100:.1f}%\n({int(cm[i, j])})"
                else:
                    txt = str(int(cm[i, j]))

                ax.text(
                    j, i, txt,
                    ha="center", va="center",
                    fontsize=13, fontweight="bold",
                    color="white" if cm_show[i, j] > thresh else "black"
                )

        ax.set_aspect("equal")  # square cells
        fig.tight_layout()

        # os.makedirs("") raises, so only create the folder when the path actually has one
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails
        plt.close(fig)

    print("✅ Saved CM PNG (paper-style):", out_path)


# ============================================================
# 2) Classification Report -> Excel (.xlsx)
#    Works even if report_* is a STRING (sklearn text report)
# ============================================================
def _parse_classification_report_string(report_str: str) -> pd.DataFrame:
    """
    Parses sklearn.metrics.classification_report (text) into a DataFrame.
    Handles typical sklearn formatting.
    """
    lines = [ln.strip() for ln in str(report_str).splitlines() if ln.strip()]
    rows = []

    # Example lines look like:
    # classA   0.95   0.90   0.92   50
    # accuracy 0.93   200
    # macro avg 0.94  0.93  0.93  200
    # weighted avg ...

    for ln in lines:
        # skip header line if present
        if ln.lower().startswith("precision") and "recall" in ln.lower():
            continue

        parts = re.split(r"\s+", ln)
        if len(parts) < 2:
            continue

        # accuracy line is special: "accuracy 0.932 200"
        if parts[0].lower() == "accuracy":
            # Sometimes: accuracy 0.93 200
            if len(parts) >= 3:
                rows.append({
                    "label": "accuracy",
                    "precision": np.nan,
                    "recall": np.nan,
                    "f1-score": float(parts[1]),
                    "support": float(parts[2])
                })
            continue

        # macro avg / weighted avg are 2-word labels
        if parts[0].lower() in ["macro", "weighted"] and len(parts) >= 6:
            label = parts[0] + " " + parts[1]  # "macro avg"
            precision, recall, f1, support = parts[2], parts[3], parts[4], parts[5]
            rows.append({
                "label": label,
                "precision": float(precision),
                "recall": float(recall),
                "f1-score": float(f1),
                "support": float(support)
            })
            continue

        # normal per-class line: label + 4 numbers
        if len(parts) >= 5:
            label = parts[0]
            precision, recall, f1, support = parts[1], parts[2], parts[3], parts[4]
            # Some class names may contain spaces; if so, this parser won't catch it.
            # In your code, class names look like "sp01" etc, so it's safe.
            try:
                rows.append({
                    "label": label,
                    "precision": float(precision),
                    "recall": float(recall),
                    "f1-score": float(f1),
                    "support": float(support)
                })
            except:
                pass

    df = pd.DataFrame(rows)
    return df

def save_report_excel(report_obj, out_xlsx_path, sheet_name="Report"):
    """
    Save a classification report to an Excel (.xlsx) file.

    Args:
        report_obj: either the text report returned by sklearn's classification_report
            (parsed with _parse_classification_report_string) or a DataFrame, which is
            written as-is.
        out_xlsx_path: destination file. Its parent folder is created when the path has
            one; a bare file name is written to the current working directory.
        sheet_name: name of the worksheet the table is written to.
    """
    if isinstance(report_obj, pd.DataFrame):
        df = report_obj.copy()
    else:
        df = _parse_classification_report_string(report_obj)

    # os.makedirs("") raises, so only create the folder when the path actually has one
    out_dir = os.path.dirname(out_xlsx_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with pd.ExcelWriter(out_xlsx_path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name=sheet_name)

        # Basic formatting: size each column to its longest cell text, capped at 35 characters
        ws = writer.sheets[sheet_name]
        for col in ws.columns:
            col_letter = col[0].column_letter
            max_len = max((len(str(cell.value)) for cell in col), default=0)
            ws.column_dimensions[col_letter].width = min(max_len + 2, 35)

    print("✅ Saved report Excel:", out_xlsx_path)

# ============================================================
# 3) OUTPUT FOLDER (uses WORKDIR if present)
# ============================================================
# Falls back to the current directory when WORKDIR has not been defined in this session
OUT_DIR = os.path.join(WORKDIR if "WORKDIR" in globals() else ".", "reports_outputs")
os.makedirs(OUT_DIR, exist_ok=True)
print("✅ Outputs will be saved under:", OUT_DIR)

# ============================================================
# 4) SAVE: 10-class (public)
# ============================================================
# Guard against running this cell before the 10-class evaluation cell
assert "cm_10" in globals() and "report_10" in globals(), "Run the 10-class evaluation cell first (cm_10/report_10)."
assert "CLASSES_10" in globals(), "CLASSES_10 not found."

# Raw-count matrix
save_cm_png(cm_10, CLASSES_10, os.path.join(OUT_DIR, "cm_10class_blue.png"),
            title="Confusion Matrix — Best 10-Class Public Model", normalize=False, show_colorbar=True)

# Row-normalised matrix
save_cm_png(cm_10, CLASSES_10, os.path.join(OUT_DIR, "cm_10class_blue_norm.png"),
            title="Confusion Matrix (Normalized) — Best 10-Class Public Model", normalize=True, show_colorbar=True)

save_report_excel(report_10, os.path.join(OUT_DIR, "report_10class.xlsx"), sheet_name="10-class report")

# ============================================================
# 5) SAVE: 4-class (fine-tuned)
# ============================================================
# Guard against running this cell before the 4-class evaluation cell
assert "cm_4" in globals() and "report_4" in globals(), "Run the 4-class evaluation cell first (cm_4/report_4)."
assert "CLASSES_4" in globals(), "CLASSES_4 not found."

# Raw-count matrix
save_cm_png(cm_4, CLASSES_4, os.path.join(OUT_DIR, "cm_4class_blue.png"),
            title="Confusion Matrix — Best 4-Class Fine-Tuned Model", normalize=False, show_colorbar=True)

# Row-normalised matrix
save_cm_png(cm_4, CLASSES_4, os.path.join(OUT_DIR, "cm_4class_blue_norm.png"),
            title="Confusion Matrix (Normalized) — Best 4-Class Fine-Tuned Model", normalize=True, show_colorbar=True)

save_report_excel(report_4, os.path.join(OUT_DIR, "report_4class.xlsx"), sheet_name="4-class report")

print("\n✅ Done. Check outputs in:", OUT_DIR)


✅ Outputs will be saved under: /content/colony_stage2/reports_outputs
✅ Saved CM PNG (paper-style): /content/colony_stage2/reports_outputs/cm_10class_blue.png
✅ Saved CM PNG (paper-style): /content/colony_stage2/reports_outputs/cm_10class_blue_norm.png
✅ Saved report Excel: /content/colony_stage2/reports_outputs/report_10class.xlsx
✅ Saved CM PNG (paper-style): /content/colony_stage2/reports_outputs/cm_4class_blue.png
✅ Saved CM PNG (paper-style): /content/colony_stage2/reports_outputs/cm_4class_blue_norm.png
✅ Saved report Excel: /content/colony_stage2/reports_outputs/report_4class.xlsx

✅ Done. Check outputs in: /content/colony_stage2/reports_outputs
