In [2]:
# Synthetic PMAT Patch Generator — **uses your project root**
# This notebook **uses your folder layout and variable names** to generate *training-only* synthetic patches for PMAT classes, with optional ControlNet (Seg) conditioning. It avoids leakage by only using **train annotations**.

# **Outputs (relative to your repo root):**
# - `synthetic_patches/<Class>/...` — New images
# - `artifacts/synth_manifest.json` — Provenance (class, source patch, seed, method)
# - Optional preview grids

# > Toggle `MODE = "DIFFUSION"` to use Stable Diffusion + ControlNet (seg). Keep it `"MOCK"` if you just want the wiring verified first.

In [3]:
# Cell 1 — Config & root discovery
import os, json, random, glob
from pathlib import Path
from typing import Dict
from PIL import Image, ImageOps, ImageFilter

# --- Choose projection folder to use ---
PROJECTION = "mean_projection"   # change to "max_projection" if you prefer
project_path = "."
repo_path = os.path.join(project_path, "borg-main")
data_root_dir = os.path.join(repo_path, f"data/{PROJECTION}")
train_json_path = os.path.join(data_root_dir, "organoid_coco_train.json")
val_json_path   = os.path.join(data_root_dir, "organoid_coco_val.json")
train_images_dir = os.path.join(data_root_dir, "images", "train")
val_images_dir   = os.path.join(data_root_dir, "images", "val")

# Output dirs
PATCH_CACHE_TRAIN = "patch_cache_train"   # new cache for cropped patches (train only)
MASKS_DIR         = "processed_masks_train"
SYNTH_DIR         = "synthetic_patches"
ARTIFACTS         = "artifacts"

IMG_SIZE = 96
CLASSES = ["Prophase","Metaphase","Anaphase","Telophase"]

for p in [PATCH_CACHE_TRAIN, MASKS_DIR, SYNTH_DIR, ARTIFACTS]:
    Path(p).mkdir(parents=True, exist_ok=True)

print("Using projection:", PROJECTION)
print("Repo path:", os.path.abspath(repo_path))
print("Data root:", os.path.abspath(data_root_dir))
print("Train images:", os.path.abspath(train_images_dir))
print("Val images:", os.path.abspath(val_images_dir))

Using projection: mean_projection
Repo path: /home/sebas_dev_linux/projects/snn_project/Diffusion_SNN_ML_v0/borg-main
Data root: /home/sebas_dev_linux/projects/snn_project/Diffusion_SNN_ML_v0/borg-main/data/mean_projection
Train images: /home/sebas_dev_linux/projects/snn_project/Diffusion_SNN_ML_v0/borg-main/data/mean_projection/images/train
Val images: /home/sebas_dev_linux/projects/snn_project/Diffusion_SNN_ML_v0/borg-main/data/mean_projection/images/val


In [4]:
# Cell 2 — Load train/val annotations (COCO JSON)
with open(train_json_path, "r") as f:
    train_coco = json.load(f)
with open(val_json_path, "r") as f:
    val_coco = json.load(f)

categories_map: Dict[int, str] = {c["id"]: c["name"] for c in train_coco["categories"]}
train_id_to_filename = {img["id"]: img["file_name"] for img in train_coco["images"]}
val_id_to_filename   = {img["id"]: img["file_name"] for img in val_coco["images"]}

train_annotations = train_coco["annotations"]   # we will synthesize from these only
val_annotations   = val_coco["annotations"]

print("Classes:", categories_map)
print("Train annotations:", len(train_annotations))
print("Val annotations:", len(val_annotations))

Classes: {1: 'Prophase', 2: 'Metaphase', 3: 'Anaphase', 4: 'Telophase'}
Train annotations: 556
Val annotations: 181


In [5]:
# Cell 3 — Cache train patches directly from COCO bboxes (robust path resolution)
from collections import defaultdict
from pathlib import Path
from PIL import Image

# Ensure cache dirs exist
for cls in CLASSES:
    Path(PATCH_CACHE_TRAIN, cls).mkdir(parents=True, exist_ok=True)

def resolve_image_path(file_name: str) -> Path | None:
    """Resolve image path regardless of whether file_name already contains images/train/…"""
    p = Path(file_name)
    cands = []
    if p.is_absolute():
        cands.append(p)
    # Typical COCO: file_name already 'images/train/xxx.png' relative to the projection root
    cands.append(Path(data_root_dir) / file_name)
    # Fallbacks: try train-images dir and basename-only
    cands.append(Path(train_images_dir) / file_name)
    cands.append(Path(train_images_dir) / p.name)
    for c in cands:
        if c.exists():
            return c
    return None

def _si(v):  # safe int
    try:
        return int(round(float(v)))
    except Exception:
        return int(v)

PAD_FRAC = 0.10  # small padding around bbox
ann_to_patch = {}
class_to_train_ann_ids = defaultdict(list)
missing_imgs, skipped_bad_bbox = 0, 0

for ann in train_annotations:
    cls = categories_map[ann["category_id"]]
    img_id = ann["image_id"]
    # Look up file_name via the IMAGES table
    file_name = train_id_to_filename.get(img_id)
    if file_name is None:
        missing_imgs += 1
        continue
    img_path = resolve_image_path(file_name)
    if img_path is None:
        missing_imgs += 1
        continue

    img = Image.open(img_path).convert("RGB")
    W, H = img.width, img.height

    # COCO bbox: [x, y, width, height] (xywh, pixels)
    bb = ann.get("bbox")
    if not bb or len(bb) != 4:
        skipped_bad_bbox += 1
        continue
    x, y, bw, bh = float(bb[0]), float(bb[1]), float(bb[2]), float(bb[3])
    if bw <= 1 or bh <= 1:
        skipped_bad_bbox += 1
        continue

    # pad and clamp to image bounds
    x0 = max(0, _si(x - PAD_FRAC * bw))
    y0 = max(0, _si(y - PAD_FRAC * bh))
    x1 = min(W, _si(x + bw + PAD_FRAC * bw))
    y1 = min(H, _si(y + bh + PAD_FRAC * bh))
    if x1 <= x0 or y1 <= y0:
        skipped_bad_bbox += 1
        continue

    patch = img.crop((x0, y0, x1, y1)).resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
    outp = Path(PATCH_CACHE_TRAIN) / cls / f"patch_{ann['id']}.png"
    patch.save(outp)
    ann_to_patch[ann["id"]] = str(outp)
    class_to_train_ann_ids[cls].append(ann["id"])

print("Cached train patches:", len(ann_to_patch))
print("Per-class cached:", {c: len(v) for c, v in class_to_train_ann_ids.items()})
print("Missing source images:", missing_imgs, "| Skipped (bad/degenerate bbox):", skipped_bad_bbox)

# Optional quick sanity check: print a couple of resolved examples
some = next(iter(train_annotations), None)
if some:
    fn = train_id_to_filename.get(some["image_id"])
    print("Example file_name from JSON:", fn)
    print("Resolved path:", resolve_image_path(fn))


Cached train patches: 556
Per-class cached: {'Prophase': 282, 'Metaphase': 146, 'Anaphase': 69, 'Telophase': 59}
Missing source images: 0 | Skipped (bad/degenerate bbox): 0
Example file_name from JSON: images/train/phase_1_new_v10_frame_02.png
Resolved path: borg-main/data/mean_projection/images/train/phase_1_new_v10_frame_02.png


In [6]:
# Cell 4 — Build masks from processed patches (robust Otsu + closing). 
# If you have Cellpose/Ilastik masks, plug them here instead.
import numpy as np
from PIL import ImageFilter, Image

MASKS_DIR = 'processed_masks_train'  # predicted/derived masks for train only
Path(MASKS_DIR).mkdir(parents=True, exist_ok=True)

def mask_from_patch(png_path: str) -> Image.Image:
    img = Image.open(png_path).convert('RGB').resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
    # Use red+green emphasis (your dataset is dual-channel fluorescence; nucleus often brighter in R)
    arr = np.array(img).astype(np.uint8)
    # Weighted grayscale favoring R, then Otsu threshold
    gray = (0.6*arr[...,0] + 0.3*arr[...,1] + 0.1*arr[...,2]).astype(np.uint8)
    # Otsu threshold
    th = gray.mean() + 0.5*gray.std()   # fallback if cv2 not available; simple adaptive rule
    m = (gray > th).astype(np.uint8)*255
    m_img = Image.fromarray(m, mode='L').filter(ImageFilter.MaxFilter(3)).filter(ImageFilter.MinFilter(3))
    return m_img

# Write masks for train patches
for cls in CLASSES:
    (Path(MASKS_DIR)/cls).mkdir(parents=True, exist_ok=True)

wrote = 0
for cls, ann_ids in class_to_train_ann_ids.items():
    for ann_id in ann_ids:
        src = ann_to_patch[ann_id]
        outp = Path(MASKS_DIR)/cls/f"mask_{ann_id}.png"
        if not outp.exists():
            m = mask_from_patch(src)
            m.save(outp)
            wrote += 1
print("Masks written:", wrote)

  m_img = Image.fromarray(m, mode='L').filter(ImageFilter.MaxFilter(3)).filter(ImageFilter.MinFilter(3))


Masks written: 556


In [7]:
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=dtype,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()   # VRAM saver (use this instead of xFormers)
pipe.enable_vae_slicing()
pipe.to(device)


  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...: 100%|██████████| 6/6 [00:03<00:00,  1.72it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


StableDiffusionControlNetPipeline {
  "_class_name": "StableDiffusionControlNetPipeline",
  "_diffusers_version": "0.35.2",
  "_name_or_path": "runwayml/stable-diffusion-v1-5",
  "controlnet": [
    "diffusers",
    "ControlNetModel"
  ],
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "UniPCMultistepScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [8]:
# Cell 5 — Diffusion scaffolding (ControlNet Seg) OR MOCK texture generator
MODE = "DIFFUSION" #"MOCK"   # change to "DIFFUSION" to enable Stable Diffusion + ControlNet(seg)

PROMPTS = {
    "Prophase":  "fluorescence microscopy, nucleus in prophase, confocal, cellular texture",
    "Metaphase": "fluorescence microscopy, metaphase, chromosomes aligned at the equatorial plate",
    "Anaphase":  "fluorescence microscopy, anaphase, chromosomes to opposite poles",
    "Telophase": "fluorescence microscopy, telophase, two daughter nuclei forming",
}
NEGATIVE = "cartoon, text, watermark, artifacts, blurry, out of focus"

pipe = None
device = None

def try_load_diffusion():
    try:
        from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
        import torch
    except Exception as e:
        print("Diffusers not available; remaining in MOCK mode.\\n", e)
        return None, None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        safety_checker=None,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    print(f"Loaded SD1.5 + ControlNet(seg) on {device}")
    return pipe, device

def generate_with_controlnet(pipe, device, rgb_img: Image.Image, seg_mask: Image.Image, prompt: str,
                             strength=0.55, guidance_scale=6.5, steps=24, seed=None):
    import torch
    if seed is None:
        seed = random.randint(0, 1_000_000)
    g = torch.Generator(device=device).manual_seed(seed)
    rgb512 = rgb_img.resize((512,512), Image.Resampling.BICUBIC)
    seg512 = seg_mask.resize((512,512), Image.Resampling.NEAREST).convert('RGB')
    out = pipe(
        prompt=prompt, image=rgb512, control_image=seg512,
        negative_prompt=NEGATIVE, generator=g,
        num_inference_steps=steps, guidance_scale=guidance_scale, strength=strength,
    )
    im = out.images[0].resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
    return im, {"seed": seed, "steps": steps, "guidance_scale": guidance_scale, "strength": strength}

# MOCK generator — no internet/GPU required
import numpy as np
def mock_generate_from_mask(mask_img: Image.Image, seed=None) -> Image.Image:
    r = random.Random(seed or random.randint(0, 999_999))
    m = mask_img if mask_img.mode == "L" else mask_img.convert("L")
    base = Image.new("RGB", m.size, (r.randint(0,25),)*3)
    # noise texture
    tex = Image.effect_noise(m.size, r.randint(30,90)).convert("L")
    tex = tex.filter(ImageFilter.GaussianBlur(radius=r.uniform(0.8, 1.6)))
    color = Image.new("RGB", m.size, (r.randint(80,240), r.randint(80,240), r.randint(80,240)))
    color = Image.blend(color, Image.new("RGB", m.size, (r.randint(0,40),)*3), r.uniform(0.2, 0.5))
    arr_c = np.array(color).astype(np.int16)
    arr_t = np.array(tex).astype(np.float32) / 255.0
    arr_c = np.clip(arr_c * (0.7 + 0.6*arr_t[...,None]), 0, 255).astype(np.uint8)
    tex_color = Image.fromarray(arr_c, mode="RGB")
    m_blur = m.filter(ImageFilter.GaussianBlur(radius=0.8))
    out = Image.composite(tex_color, base, m_blur)
    # gentle halo
    edge = m.filter(ImageFilter.FIND_EDGES).filter(ImageFilter.GaussianBlur(radius=0.7))
    edge_alpha = ImageOps.autocontrast(edge).point(lambda x: int(x*0.15))
    overlay_rgba = Image.new("RGBA", m.size, (255,255,255,0)); overlay_rgba.putalpha(edge_alpha)
    out_rgba = out.convert("RGBA")
    final = Image.alpha_composite(out_rgba, overlay_rgba).convert("RGB")
    return final

if MODE.upper().startswith("DIFF"):
    pipe, device = try_load_diffusion()
else:
    print("MODE=MOCK — using texture-based generator for quick wiring test.")

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 29.67it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Loaded SD1.5 + ControlNet(seg) on cuda


In [None]:
# Cell 6 — Synthesize per class (train-only), with target counts and manifest
from collections import Counter
from pathlib import Path
import json

# Choose a per-class target: bring all classes up to the max train count
train_counts = {cls: len(class_to_train_ann_ids[cls]) for cls in CLASSES}
target = max(train_counts.values())
print("Train counts:", train_counts, "=> target per class:", target)

manifest = []
for cls in CLASSES:
    outd = Path(SYNTH_DIR)/cls
    outd.mkdir(parents=True, exist_ok=True)

# Iterate ann_ids, generate as many as needed to reach target
for cls, ann_ids in class_to_train_ann_ids.items():
    need = max(0, target - len(ann_ids))
    if need == 0:
        print(f"{cls}: already at target, skipping generation.")
        continue
    print(f"{cls}: generating {need} images…")
    # round-robin over existing train patches
    k = 0
    while k < need:
        for ann_id in ann_ids:
            if k >= need: break
            src_patch = ann_to_patch[ann_id]
            mask_path = Path(MASKS_DIR)/cls/f"mask_{ann_id}.png"
            if not Path(src_patch).exists() or not mask_path.exists():
                continue
            rgb = Image.open(src_patch).convert('RGB').resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BICUBIC)
            msk = Image.open(mask_path).convert('L').resize((IMG_SIZE, IMG_SIZE), Image.Resampling.NEAREST)
            if MODE.upper().startswith("DIFF") and pipe is not None:
                out_img, meta = generate_with_controlnet(pipe, device, rgb, msk, PROMPTS[cls])
                method = "controlnet_seg"
                seed = meta["seed"]
            else:
                seed = 10_000 + k
                out_img = mock_generate_from_mask(msk, seed=seed)
                method = "mock_from_mask"
            out_name = f"{cls.lower()}_synth_{ann_id}_{k:04d}.png"
            out_path = Path(SYNTH_DIR)/cls/out_name
            out_img.save(out_path)
            manifest.append({
                "class": cls,
                "file": str(out_path),
                "source_patch": src_patch,
                "source_mask": str(mask_path),
                "method": method,
                "seed": seed,
                "img_size": IMG_SIZE,
            })
            k += 1

# Save manifest
man_path = Path(ARTIFACTS)/"synth_manifest.json"
with open(man_path, "w") as f:
    json.dump(manifest, f, indent=2)
print("Wrote manifest:", man_path, "| total synth:", len(manifest))

Train counts: {'Prophase': 282, 'Metaphase': 146, 'Anaphase': 69, 'Telophase': 59} => target per class: 282
Prophase: already at target, skipping generation.
Metaphase: generating 136 images…


100%|██████████| 24/24 [08:19<00:00, 20.82s/it]
100%|██████████| 24/24 [08:17<00:00, 20.73s/it]
100%|██████████| 24/24 [08:17<00:00, 20.72s/it]
100%|██████████| 24/24 [08:16<00:00, 20.69s/it]
100%|██████████| 24/24 [08:26<00:00, 21.11s/it]
100%|██████████| 24/24 [08:19<00:00, 20.80s/it]
100%|██████████| 24/24 [08:09<00:00, 20.42s/it]
100%|██████████| 24/24 [08:08<00:00, 20.35s/it]
100%|██████████| 24/24 [08:08<00:00, 20.36s/it]
100%|██████████| 24/24 [10:30<00:00, 26.29s/it]
100%|██████████| 24/24 [11:49<00:00, 29.57s/it]
100%|██████████| 24/24 [11:44<00:00, 29.34s/it]
100%|██████████| 24/24 [12:58<00:00, 32.44s/it]
100%|██████████| 24/24 [22:48<00:00, 57.03s/it]
 25%|██▌       | 6/24 [05:38<16:39, 55.52s/it]

In [None]:
# Cell 7 — Light de-dup (aHash) and class counts report
import numpy as np

def average_hash(img: Image.Image, hash_size=8):
    g = ImageOps.grayscale(img).resize((hash_size, hash_size), Image.Resampling.LANCZOS)
    arr = np.array(g).astype(np.float32)
    th = arr.mean()
    bits = (arr > th).astype(np.uint8).flatten()
    h = 0
    for b in bits:
        h = (h << 1) | int(b)
    return h

def hamming(a, b):
    return bin(a ^ b).count("1")

removed = 0
for cls in CLASSES:
    files = sorted((Path(SYNTH_DIR)/cls).glob("*.png"))
    seen = []
    for fp in files:
        img = Image.open(fp)
        ah = average_hash(img)
        if any(hamming(ah, s) <= 2 for s in seen):
            fp.unlink(missing_ok=True)
            removed += 1
        else:
            seen.append(ah)
print("Removed near-duplicates:", removed)

counts = {cls: len(list((Path(SYNTH_DIR)/cls).glob('*.png'))) for cls in CLASSES}
print("Synthetic counts (after dedup):", counts)

In [None]:
# Cell 8 — Preview grid (first few per class)
import matplotlib.pyplot as plt

fig, axes = plt.subplots(len(CLASSES), 4, figsize=(6, 6))
for r, cls in enumerate(CLASSES):
    files = sorted(glob.glob(str((Path(SYNTH_DIR) / cls / '*.png').resolve())))[:4]
    for c, f in enumerate(files):
        ax = axes[r, c]
        ax.imshow(Image.open(f))
        ax.set_title(f"{cls} synth", fontsize=8)
        ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Cell 9 — Zip for portability
import zipfile, os
zip_path = Path(ARTIFACTS) / "synthetic_patches.zip"
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for cls in CLASSES:
        for f in (Path(SYNTH_DIR)/cls).glob("*.png"):
            zf.write(f, arcname=f"synthetic_patches/{cls}/{f.name}")
    zf.write(Path(ARTIFACTS)/"synth_manifest.json", arcname="artifacts/synth_manifest.json")
print("Zipped:", zip_path)