<a href="https://colab.research.google.com/github/sandeeepmedepalli/ml-colony-classification/blob/main/mixing_old_and_new_enhanced_code_of_ML_research.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os, random, math
import numpy as np
import pandas as pd

from google.colab import drive
drive.mount('/content/drive')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
import torchvision.models as models
from sklearn.model_selection import train_test_split


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# ====== PATHS (separate pretrain vs ground truth) ======
BASE_DIR = "/content/drive/MyDrive/22022540"

# Public dataset (10-class pretrain)
CSV_PATH   = os.path.join(BASE_DIR, "annot_tab.csv")
IMAGES_DIR_PUBLIC = BASE_DIR  # images are directly here


# ===== Dataset A (Fine-tune train) =====
FT_GT_ZIP_PATH   = "/content/drive/MyDrive/22022540/ground_truth/ground_truth_one/ground_truth_dataset.zip"
IMAGES_DIR_FT    = "/content/drive/MyDrive/22022540/ground_truth/ground_truth_one"

# ===== Dataset B (New test dataset) =====
TEST_GT_ZIP_PATH = "/content/drive/MyDrive/22022540/ground_truth/ground_truth_updated/updated_groud_truth.zip"
IMAGES_DIR_TEST  = "/content/drive/MyDrive/22022540/ground_truth/ground_truth_updated"


# Output workspaces (separate)
WORKDIR_PRETRAIN = "/content/colony_stage2_pretrain10"
WORKDIR_GT       = "/content/colony_stage2_groundtruth"
os.makedirs(WORKDIR_PRETRAIN, exist_ok=True)
os.makedirs(WORKDIR_GT, exist_ok=True)

# Patch settings
PATCH_SIZE = 100
PAD_TO_SQUARE = True

# How much extra context around the colony bbox
BOX_EXPAND = 0.20


# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
print("CSV_PATH:  ", CSV_PATH)

Device: cuda
CSV_PATH:   /content/drive/MyDrive/22022540/annot_tab.csv


In [3]:
#Copy images from Drive → local Colab
import shutil, os

LOCAL_IMG_DIR = "/content/colony_images"
os.makedirs(LOCAL_IMG_DIR, exist_ok=True)

for f in os.listdir(IMAGES_DIR_PUBLIC):
    src = os.path.join(IMAGES_DIR_PUBLIC, f)
    dst = os.path.join(LOCAL_IMG_DIR, f)
    if os.path.isfile(src) and not os.path.exists(dst):
        shutil.copy2(src, dst)

IMAGES_DIR_PUBLIC = LOCAL_IMG_DIR
print("Now using local images:", IMAGES_DIR_PUBLIC)

Now using local images: /content/colony_images


In [4]:
required_cols = [
    "label_name","bbox_x","bbox_y","bbox_width","bbox_height",
    "image_name","image_width","image_height"
]

df = pd.read_csv(CSV_PATH, encoding="utf-8-sig")

missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"CSV missing columns: {missing}\nFound: {list(df.columns)}")

# Make numeric columns numeric
for c in ["bbox_x","bbox_y","bbox_width","bbox_height","image_width","image_height"]:
    df[c] = pd.to_numeric(df[c], errors="coerce")

# Drop bad rows
df = df.dropna(subset=required_cols).copy()
df = df[(df["bbox_width"] > 0) & (df["bbox_height"] > 0)].copy()

print("Rows (boxes):", len(df))
print("Unique classes:", df["label_name"].nunique())
print(df["label_name"].value_counts().head(24))


Rows (boxes): 56862
Unique classes: 24
label_name
sp21    11160
sp23     7067
sp22     6814
sp06     5513
sp10     4364
sp05     4102
sp19     2782
sp13     1799
sp09     1775
sp02     1530
sp18     1383
sp16     1348
sp14     1102
sp07     1087
sp15      866
sp20      853
sp24      787
sp11      481
sp12      461
sp01      397
sp08      368
sp04      304
sp03      295
sp17      224
Name: count, dtype: int64


In [5]:
# 10-class list
CLASSES_10 = ["sp02","sp05","sp06","sp07","sp10","sp14","sp16","sp19","sp21","sp23"]

# Count check
counts = df["label_name"].value_counts()
def show_counts(class_list, name):
    print(f"\n{name} counts:")
    for c in class_list:
        print(f"  {c}: {int(counts.get(c, 0))}")

show_counts(CLASSES_10, "10-class")

# Safety checks for classes count to be greater than 1000 , as we are working on thousand
for c in CLASSES_10:
    if counts.get(c, 0) < 1000:
        raise ValueError(f"{c} has < 1000 boxes. Pick another class.")


10-class counts:
  sp02: 1530
  sp05: 4102
  sp06: 5513
  sp07: 1087
  sp10: 4364
  sp14: 1102
  sp16: 1348
  sp19: 2782
  sp21: 11160
  sp23: 7067


In [6]:
def sample_rows_for_class(df, cls, n, seed, exclude_index_set=None):
    """Sample n rows for one class. Optionally exclude some row indices."""
    sub = df[df["label_name"] == cls]
    if exclude_index_set is not None:
        sub = sub[~sub.index.isin(exclude_index_set)]
    if len(sub) < n:
        raise ValueError(f"Not enough rows for {cls}: need {n}, have {len(sub)} after exclusions.")
    return sub.sample(n=n, random_state=seed)

# 10-class pretrain sampling: 1000 per class
pretrain_parts = []
used_idx = set()

for i, cls in enumerate(CLASSES_10):
    samp = sample_rows_for_class(df, cls, n=1000, seed=SEED+i)
    pretrain_parts.append(samp)
    used_idx.update(samp.index.tolist())

df_pretrain = pd.concat(pretrain_parts).reset_index(drop=True)
print("Pretrain rows:", len(df_pretrain), " expected:", 10*1000)

Pretrain rows: 10000  expected: 10000


In [7]:
# Cropping function where we convert bbox to patch (with optional context expansion)
def crop_patch(
    img: Image.Image,
    x, y, w, h,
    pad_to_square=True,
    expand_factor=0.20,    # e.g., 0.20 means 20% of bbox size added on each side
    expand_px=0           # optional fixed extra pixels on each side
):
    # Expand bbox to include context
    dx = expand_px + expand_factor * float(w)
    dy = expand_px + expand_factor * float(h)

    x1 = int(round(x - dx))
    y1 = int(round(y - dy))
    x2 = int(round(x + w + dx))
    y2 = int(round(y + h + dy))

    # Clip to image boundaries
    x1 = max(0, x1); y1 = max(0, y1)
    x2 = min(img.width, x2); y2 = min(img.height, y2)

    # Safety: avoid empty crops
    if x2 <= x1: x2 = min(img.width, x1 + 1)
    if y2 <= y1: y2 = min(img.height, y1 + 1)

    patch = img.crop((x1, y1, x2, y2))

    if not pad_to_square:
        return patch

    # Pad to square (keeps aspect ratio before resize)
    side = max(patch.width, patch.height)
    new_img = Image.new("L", (side, side), color=0)  # grayscale canvas
    px = (side - patch.width) // 2
    py = (side - patch.height) // 2
    patch = patch.convert("L")
    new_img.paste(patch, (px, py))
    return new_img


In [8]:
from PIL import Image
from torch.utils.data import Dataset
import os
import torch

class ColonyPatchDatasetCached(Dataset):
    """
    REAL caching:
    - Caches the final tensor (after crop + transform) in memory
    - Reuses it on next epochs
    """
    def __init__(self, df, images_dir, class_to_idx, transform=None, cache_in_ram=True):
        self.df = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.cache_in_ram = cache_in_ram
        self._cache = {}  # idx -> (tensor_x, tensor_y)

    def __len__(self):
        return len(self.df)

    def _resolve_img_path(self, row):
        if "image_path" in self.df.columns:
            p = str(row["image_path"])
            if os.path.exists(p):
                return p

        name = str(row["image_name"])
        p2 = os.path.join(self.images_dir, name)
        if os.path.exists(p2):
            return p2

        p3 = os.path.join(self.images_dir, os.path.basename(name))
        if os.path.exists(p3):
            return p3

        raise FileNotFoundError(
            f"Could not find image.\n"
            f"image_name={name}\n"
            f"image_path={row.get('image_path', None)}\n"
            f"tried={p2}\ntried={p3}\nimages_dir={self.images_dir}"
        )

    def __getitem__(self, idx):
        # ✅ return cached sample if available
        if self.cache_in_ram and idx in self._cache:
            return self._cache[idx]

        r = self.df.iloc[idx]
        img_path = self._resolve_img_path(r)

        img = Image.open(img_path).convert("L")

        x = float(r["bbox_x"]); y = float(r["bbox_y"])
        w = float(r["bbox_width"]); h = float(r["bbox_height"])

        patch = crop_patch(img, x, y, w, h)  # must return fixed-size PIL patch
        patch_rgb = Image.merge("RGB", (patch, patch, patch))

        if self.transform is not None:
            x_out = self.transform(patch_rgb)
        else:
            # if no transform, convert to tensor so batching works
            x_out = torch.from_numpy(np.array(patch_rgb)).permute(2,0,1).float() / 255.0

        if "label_idx" in self.df.columns:
            y_out = torch.tensor(int(r["label_idx"]), dtype=torch.long)
        else:
            y_out = torch.tensor(int(self.class_to_idx[str(r["label_name"])]), dtype=torch.long)

        if self.cache_in_ram:
            self._cache[idx] = (x_out, y_out)

        return x_out, y_out


In [9]:
# ImageNet normalization stats (because we use an ImageNet-pretrained backbone)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_tf = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(10),
    T.RandomResizedCrop(100, scale=(0.85, 1.0)),   # helps robustness
    T.ColorJitter(brightness=0.2, contrast=0.2),   # important for new images
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


test_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


In [10]:
def accuracy(pred_logits, y):
    preds = pred_logits.argmax(dim=1)
    return (preds == y).float().mean().item()

def train_one_epoch(model, loader, optim, criterion):
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)

        optim.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()

        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy(logits, y) * bs
        n += bs

    return total_loss / n, total_acc / n

@torch.no_grad()  # helps to stop gradient tracking during evaluation
def eval_one_epoch(model, loader, criterion):
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits = model(x)
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy(logits, y) * bs
        n += bs

    return total_loss / n, total_acc / n


In [11]:
class_to_idx_10 = {c:i for i,c in enumerate(CLASSES_10)}

train_df_10, test_df_10 = train_test_split(
    df_pretrain,
    test_size=0.2,
    random_state=SEED,
    stratify=df_pretrain["label_name"]
)

ds_train_10 = ColonyPatchDatasetCached(train_df_10, IMAGES_DIR_PUBLIC, class_to_idx_10, transform=train_tf)
ds_test_10  = ColonyPatchDatasetCached(test_df_10,  IMAGES_DIR_PUBLIC, class_to_idx_10, transform=test_tf)

train_loader_10 = DataLoader(ds_train_10, batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
test_loader_10  = DataLoader(ds_test_10,  batch_size=128, shuffle=False, num_workers=0, pin_memory=True)


print("10-class train:", len(ds_train_10), "test:", len(ds_test_10))


10-class train: 8000 test: 2000


In [12]:
#we are going to use the pretrained back bone for imagenet weights
#ResNet-18 architecture suitable for efficient transfer learning.
#Load the official pretrained ResNet18 weights trained on ImageNet-1K(1000 classes), version 1
weights = models.ResNet18_Weights.IMAGENET1K_V1
model_10 = models.resnet18(weights=weights)

# change the resnet classifier head from 1000 to 10 classes
in_features = model_10.fc.in_features
model_10.fc = nn.Linear(in_features, len(CLASSES_10))

model_10 = model_10.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_10.parameters(), lr=1e-4)


In [13]:
# ============================
# FIX: Make all 10-class images same size (prevents stack error)
# ============================

import torchvision.transforms as T
from torch.utils.data import DataLoader

# 1) Set a fixed size (ResNet standard)
IMG_SIZE_10 = 224

# 2) Define transforms that ALWAYS resize
train_tf_10 = T.Compose([
    T.Resize((IMG_SIZE_10, IMG_SIZE_10)),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
])

test_tf_10 = T.Compose([
    T.Resize((IMG_SIZE_10, IMG_SIZE_10)),
    T.ToTensor(),
])

# 3) Rebuild datasets using the SAME dataset objects you already created
#    This works if your dataset has `.transform` field like torchvision ImageFolder
assert "ds_train_10" in globals(), "ds_train_10 not found. Find the cell where you created it."
assert "ds_test_10" in globals(),  "ds_test_10 not found. Find the cell where you created it."

ds_train_10.transform = train_tf_10
ds_test_10.transform  = test_tf_10

# 4) Rebuild loaders (important!)
BATCH_SIZE_10 = 64

train_loader_10 = DataLoader(
    ds_train_10, batch_size=BATCH_SIZE_10, shuffle=True,
    num_workers=0, pin_memory=True
)

test_loader_10 = DataLoader(
    ds_test_10, batch_size=BATCH_SIZE_10, shuffle=False,
    num_workers=0, pin_memory=True
)

# 5) Quick verification
x, y = next(iter(test_loader_10))
print("✅ test_loader_10 batch shape:", x.shape)  # must be [B, 3, 224, 224]


✅ test_loader_10 batch shape: torch.Size([64, 3, 224, 224])


In [14]:
EPOCHS_10 = 10  # start small; you can increase later

best_acc = 0.0
for epoch in range(1, EPOCHS_10+1):
    tr_loss, tr_acc = train_one_epoch(model_10, train_loader_10, optimizer, criterion)
    te_loss, te_acc = eval_one_epoch(model_10, test_loader_10, criterion)

    print(f"[10-class] Epoch {epoch:02d} | train acc {tr_acc:.3f} | test acc {te_acc:.3f}")

    #save the model with best accuracy
    if te_acc > best_acc:
        best_acc = te_acc
        torch.save(model_10.state_dict(), os.path.join(WORKDIR_PRETRAIN, "resnet18_pretrained_10class.pt"))

print("Best 10-class test acc:", best_acc)


[10-class] Epoch 01 | train acc 0.792 | test acc 0.875
[10-class] Epoch 02 | train acc 0.933 | test acc 0.913
[10-class] Epoch 03 | train acc 0.961 | test acc 0.923
[10-class] Epoch 04 | train acc 0.971 | test acc 0.926
[10-class] Epoch 05 | train acc 0.976 | test acc 0.904
[10-class] Epoch 06 | train acc 0.977 | test acc 0.918
[10-class] Epoch 07 | train acc 0.985 | test acc 0.916
[10-class] Epoch 08 | train acc 0.990 | test acc 0.934
[10-class] Epoch 09 | train acc 0.992 | test acc 0.920
[10-class] Epoch 10 | train acc 0.992 | test acc 0.891
Best 10-class test acc: 0.934


In [15]:
import os, zipfile, yaml
import pandas as pd
from PIL import Image

def parse_cvat_yolo_zip_to_df(gt_zip_path, images_dir, seed=42, source_tag=""):
    GT_DIR = "/content/_tmp_gt_extract"
    if os.path.exists(GT_DIR):
        import shutil; shutil.rmtree(GT_DIR)
    os.makedirs(GT_DIR, exist_ok=True)

    with zipfile.ZipFile(gt_zip_path, "r") as z:
        z.extractall(GT_DIR)

    yaml_path = os.path.join(GT_DIR, "data.yaml")
    with open(yaml_path, "r") as f:
        y = yaml.safe_load(f)

    names = y["names"]
    if isinstance(names, dict):
        id_to_class = {int(k): v for k, v in names.items()}
    else:
        id_to_class = {i: v for i, v in enumerate(names)}

    classes = [id_to_class[i] for i in sorted(id_to_class.keys())]

    labels_dir = os.path.join(GT_DIR, "labels", "train")
    if not os.path.isdir(labels_dir):
        labels_dir = os.path.join(GT_DIR, "labels")
    if not os.path.isdir(labels_dir):
        raise FileNotFoundError("Could not find labels folder inside GT zip.")

    def find_image_path(stem):
        for ext in [".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"]:
            p = os.path.join(images_dir, stem + ext)
            if os.path.exists(p):
                return p
        return None

    rows, missing = [], []
    for txt_name in os.listdir(labels_dir):
        if not txt_name.endswith(".txt"):
            continue

        stem = os.path.splitext(txt_name)[0]
        img_path = find_image_path(stem)
        if img_path is None:
            missing.append(stem)
            continue

        img = Image.open(img_path)
        W, H = img.size

        with open(os.path.join(labels_dir, txt_name), "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) != 5:
                    continue

                cls_id = int(float(parts[0]))
                cx, cy, w, h = map(float, parts[1:])

                bw = w * W
                bh = h * H
                x1 = (cx * W) - bw/2
                y1 = (cy * H) - bh/2

                rows.append({
                    "label_name": id_to_class[cls_id],
                    "bbox_x": x1, "bbox_y": y1,
                    "bbox_width": bw, "bbox_height": bh,
                    "image_name": os.path.basename(img_path),
                    "image_path": img_path,                 # ✅ FIX
                    "source": source_tag,                   # ✅ helps debugging
                    "image_width": W, "image_height": H
                })

    df = pd.DataFrame(rows)

    print("Parsed:", gt_zip_path)
    print("Boxes:", len(df), "Images:", df["image_path"].nunique())
    print("Classes:", classes)
    if missing:
        print("WARNING: Missing images for some labels (examples):", missing[:10])

    return df, classes


In [16]:
df_ft, classes_ft   = parse_cvat_yolo_zip_to_df(FT_GT_ZIP_PATH, IMAGES_DIR_FT, seed=SEED, source_tag="old")
df_test, classes_ts = parse_cvat_yolo_zip_to_df(TEST_GT_ZIP_PATH, IMAGES_DIR_TEST, seed=SEED, source_tag="new")

# ✅ FIX: compare as sets (order doesn't matter)
if set(classes_ft) != set(classes_ts):
    raise ValueError(
        f"Class mismatch!\nTrain classes: {classes_ft}\nTest classes:  {classes_ts}"
    )

VALID_CLASSES = ["sp01", "sp02", "sp03", "sp04"]  # keep these classes

df_ft   = df_ft[df_ft["label_name"].isin(VALID_CLASSES)].reset_index(drop=True)
df_test = df_test[df_test["label_name"].isin(VALID_CLASSES)].reset_index(drop=True)

df_all = pd.concat([df_ft, df_test], ignore_index=True)
df_all = df_all.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

# ✅ Canonical order (stable)
CLASSES_FT = sorted(df_all["label_name"].unique().tolist())
class_to_idx_ft = {c:i for i,c in enumerate(CLASSES_FT)}

print("Using classes:", CLASSES_FT)
print("Class → idx mapping:", class_to_idx_ft)
print("Class counts:\n", df_all["label_name"].value_counts())
print("Unique images per class:\n", df_all.groupby("label_name")["image_path"].nunique())


Parsed: /content/drive/MyDrive/22022540/ground_truth/ground_truth_one/ground_truth_dataset.zip
Boxes: 254 Images: 4
Classes: ['sp01', 'sp02', 'sp03', 'sp04']
Parsed: /content/drive/MyDrive/22022540/ground_truth/ground_truth_updated/updated_groud_truth.zip
Boxes: 265 Images: 4
Classes: ['sp01', 'sp03', 'sp04', 'sp02']
Using classes: ['sp01', 'sp02', 'sp03', 'sp04']
Class → idx mapping: {'sp01': 0, 'sp02': 1, 'sp03': 2, 'sp04': 3}
Class counts:
 label_name
sp02    230
sp01    100
sp04    100
sp03     89
Name: count, dtype: int64
Unique images per class:
 label_name
sp01    2
sp02    2
sp03    2
sp04    2
Name: image_path, dtype: int64


In [17]:
from sklearn.model_selection import train_test_split
import pandas as pd

print("Old set boxes:", len(df_ft))
print("New set boxes:", len(df_test))

# 1) MIX old + new
df_all = pd.concat([df_ft, df_test], ignore_index=True)

# (Optional) shuffle once for safety
df_all = df_all.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

print("\nMixed boxes total:", len(df_all))
print("Class distribution (mixed):")
print(df_all["label_name"].value_counts())

# 2) 80/20 split (stratified by class)
try:
    train_df_ft, test_df_new = train_test_split(
        df_all,
        test_size=0.2,
        random_state=SEED,
        stratify=df_all["label_name"]
    )
    print("\n✅ Using STRATIFIED 80/20 split on mixed data")
except ValueError as e:
    print("\n⚠️ Stratified split failed (some class too small). Using random split.")
    train_df_ft, test_df_new = train_test_split(
        df_all,
        test_size=0.2,
        random_state=SEED,
        shuffle=True
    )

print("Train (80%):", len(train_df_ft))
print("Test  (20%):", len(test_df_new))
print("\nTrain class counts:")
print(train_df_ft["label_name"].value_counts())
print("\nTest class counts:")
print(test_df_new["label_name"].value_counts())


Old set boxes: 254
New set boxes: 265

Mixed boxes total: 519
Class distribution (mixed):
label_name
sp02    230
sp01    100
sp04    100
sp03     89
Name: count, dtype: int64

✅ Using STRATIFIED 80/20 split on mixed data
Train (80%): 415
Test  (20%): 104

Train class counts:
label_name
sp02    184
sp04     80
sp01     80
sp03     71
Name: count, dtype: int64

Test class counts:
label_name
sp02    46
sp01    20
sp04    20
sp03    18
Name: count, dtype: int64


In [22]:
# ============================
# FIX: Always return same-size patches so DataLoader can stack batches
# ============================
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# --- REQUIRED: PATCH_SIZE must exist (or default) ---
PATCH_SIZE = globals().get("PATCH_SIZE", 100)

class ColonyPatchDatasetFixed(Dataset):
    """
    Fixes the batching crash:
    - Resolves image path robustly (image_path -> images_dir/name -> basename fallback)
    - Crops patch using crop_patch()
    - FORCES output patch to PATCH_SIZE x PATCH_SIZE (always)
    - Converts L -> RGB
    - Applies transform (if transform already includes resize, that's fine too)
    """
    def __init__(self, df, images_dir, class_to_idx, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _resolve_img_path(self, row):
        if "image_path" in self.df.columns:
            p = str(row["image_path"])
            if os.path.exists(p):
                return p

        name = str(row["image_name"])
        p2 = os.path.join(self.images_dir, name)
        if os.path.exists(p2):
            return p2

        p3 = os.path.join(self.images_dir, os.path.basename(name))
        if os.path.exists(p3):
            return p3

        raise FileNotFoundError(
            f"Could not find image:\n"
            f"  image_name: {name}\n"
            f"  image_path col: {row.get('image_path', None)}\n"
            f"  tried: {p2}\n"
            f"  tried: {p3}\n"
            f"  images_dir: {self.images_dir}"
        )

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        img_path = self._resolve_img_path(r)

        # IMPORTANT: keep same as your pipeline (grayscale first)
        img_L = Image.open(img_path).convert("L")

        x = float(r["bbox_x"]); y = float(r["bbox_y"])
        w = float(r["bbox_width"]); h = float(r["bbox_height"])

        # ---- crop patch (whatever size it returns) ----
        patch_L = crop_patch(img_L, x, y, w, h)

        # ---- FORCE fixed size (this is the key fix) ----
        if patch_L.size != (PATCH_SIZE, PATCH_SIZE):
            patch_L = patch_L.resize((PATCH_SIZE, PATCH_SIZE))

        # L -> RGB
        patch_rgb = Image.merge("RGB", (patch_L, patch_L, patch_L))

        # Apply transforms
        if self.transform is not None:
            patch_rgb = self.transform(patch_rgb)

        # Label
        if "label_idx" in self.df.columns:
            y_idx = int(r["label_idx"])
        else:
            y_idx = int(self.class_to_idx[str(r["label_name"])])

        return patch_rgb, torch.tensor(y_idx, dtype=torch.long)




In [23]:
# ============================
# Rebuild datasets + loaders (same variable names as your training code)
# ============================
class_to_idx_ft = {c:i for i,c in enumerate(CLASSES_FT)}

ds_train_ft = ColonyPatchDatasetFixed(train_df_ft, IMAGES_DIR_FT, class_to_idx_ft, transform=train_tf)
ds_test_new = ColonyPatchDatasetFixed(test_df_new, IMAGES_DIR_TEST, class_to_idx_ft, transform=test_tf)

train_loader_ft = DataLoader(ds_train_ft, batch_size=64, shuffle=True,  num_workers=2, pin_memory=True)
test_loader_new = DataLoader(ds_test_new, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

print("✅ Train patches (A):", len(ds_train_ft))
print("✅ Test  patches (B):", len(ds_test_new))



✅ Train patches (A): 415
✅ Test  patches (B): 104


In [24]:
# ============================
# Model load (10-class -> replace head to FT classes)
# ============================
import torch.nn as nn
import torchvision.models as models

ckpt_10 = os.path.join(WORKDIR_PRETRAIN, "resnet18_pretrained_10class.pt")

model = models.resnet18(weights=None)
in_features = model.fc.in_features

# load 10-class checkpoint
model.fc = nn.Linear(in_features, 10)
model.load_state_dict(torch.load(ckpt_10, map_location="cpu"))

# replace head for in-house fine-tune
model.fc = nn.Linear(in_features, len(CLASSES_FT))
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()



In [25]:
# Directory to save fine-tuned models
WORKDIR = "/content/work_finetune"
os.makedirs(WORKDIR, exist_ok=True)

print("Fine-tuning models will be saved to:", WORKDIR)


Fine-tuning models will be saved to: /content/work_finetune


In [26]:
import os
import torch

# ============================
# Training (your exact loop)
# ============================
best_path = os.path.join(WORKDIR, "best_finetuned_on_newtest.pt")
best_acc = 0.0

def maybe_save_best(model, acc, best_acc, path):
    if acc > best_acc:
        torch.save(model.state_dict(), path)
        return acc, True
    return best_acc, False

print("\n===== PHASE 1: Head-only (features frozen) | Train=A (100%) → Test=B =====")

for p in model.parameters():
    p.requires_grad = False
for p in model.fc.parameters():
    p.requires_grad = True

optimizer1 = torch.optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer1, mode="max", factor=0.5, patience=2, threshold=1e-3, min_lr=1e-6
)

EPOCHS_HEAD = 10
for epoch in range(1, EPOCHS_HEAD + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader_ft, optimizer1, criterion)
    te_loss, te_acc = eval_one_epoch(model, test_loader_new, criterion)

    scheduler1.step(te_acc)
    lr1 = optimizer1.param_groups[0]["lr"]

    print(f"[HEAD] Epoch {epoch:02d} | lr {lr1:.2e} | train acc(A) {tr_acc:.3f} | test acc(B) {te_acc:.3f}")
    best_acc, saved = maybe_save_best(model, te_acc, best_acc, best_path)

print("Best test(B) acc after Phase 1:", best_acc)

print("\n===== PHASE 2: Full fine-tune (unfreeze all) | Train=A (100%) → Test=B =====")

for p in model.parameters():
    p.requires_grad = True

optimizer2 = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-4)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer2, mode="max", factor=0.5, patience=2, threshold=1e-3, min_lr=1e-6
)

EPOCHS_FULL = 10
for epoch in range(1, EPOCHS_FULL + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader_ft, optimizer2, criterion)
    te_loss, te_acc = eval_one_epoch(model, test_loader_new, criterion)

    scheduler2.step(te_acc)
    lr2 = optimizer2.param_groups[0]["lr"]

    print(f"[FULL] Epoch {epoch:02d} | lr {lr2:.2e} | train acc(A) {tr_acc:.3f} | test acc(B) {te_acc:.3f}")
    best_acc, saved = maybe_save_best(model, te_acc, best_acc, best_path)

print("\n✅ FINAL Best test accuracy on NEW dataset (B):", best_acc)
print("✅ Best model saved at:", best_path)



===== PHASE 1: Head-only (features frozen) | Train=A (100%) → Test=B =====
[HEAD] Epoch 01 | lr 1.00e-03 | train acc(A) 0.390 | test acc(B) 0.462
[HEAD] Epoch 02 | lr 1.00e-03 | train acc(A) 0.499 | test acc(B) 0.558
[HEAD] Epoch 03 | lr 1.00e-03 | train acc(A) 0.545 | test acc(B) 0.587
[HEAD] Epoch 04 | lr 1.00e-03 | train acc(A) 0.593 | test acc(B) 0.596
[HEAD] Epoch 05 | lr 1.00e-03 | train acc(A) 0.593 | test acc(B) 0.615
[HEAD] Epoch 06 | lr 1.00e-03 | train acc(A) 0.614 | test acc(B) 0.635
[HEAD] Epoch 07 | lr 1.00e-03 | train acc(A) 0.634 | test acc(B) 0.635
[HEAD] Epoch 08 | lr 1.00e-03 | train acc(A) 0.614 | test acc(B) 0.644
[HEAD] Epoch 09 | lr 1.00e-03 | train acc(A) 0.684 | test acc(B) 0.644
[HEAD] Epoch 10 | lr 1.00e-03 | train acc(A) 0.680 | test acc(B) 0.673
Best test(B) acc after Phase 1: 0.6730769184919504

===== PHASE 2: Full fine-tune (unfreeze all) | Train=A (100%) → Test=B =====
[FULL] Epoch 01 | lr 3.00e-05 | train acc(A) 0.694 | test acc(B) 0.702
[FULL] Epoch 0