In [None]:
# Install the Drive-download helper
!pip install --quiet gdown

# Replace FILE_ID below with your actual file ID from step 1
FILE_ID = "1DltzLdv2oKEYbeisgh9c4KkdCJSE-S_C"
ZIP_NAME = "Blastocyst_Dataset.zip"



In [None]:
# Download via gdown
import gdown
gdown.download(f"https://drive.google.com/uc?id={FILE_ID}", ZIP_NAME, quiet=False)
!unzip -q Blastocyst_Dataset.zip -d blastocyst_dataset

Downloading...
From (original): https://drive.google.com/uc?id=1DltzLdv2oKEYbeisgh9c4KkdCJSE-S_C
From (redirected): https://drive.google.com/uc?id=1DltzLdv2oKEYbeisgh9c4KkdCJSE-S_C&confirm=t&uuid=3754caaf-754a-4d4f-b54f-1420aeb81af2
To: /content/Blastocyst_Dataset.zip
100%|██████████| 625M/625M [00:06<00:00, 94.3MB/s]


In [None]:
import os

# Dataset root and the two annotation tables that live directly inside it
data_dir = "/content/blastocyst_dataset/Blastocyst_Dataset"
train_csv = data_dir + "/gardner_criteria_train.csv"
test_csv = data_dir + "/gardner_criteria_test.csv"

# Quick sanity check: number of entries in Images/ and the listing of the dataset root
image_entries = os.listdir(data_dir + "/Images")
print("Found images:", len(image_entries))
print("CSV files:", os.listdir(data_dir))

Found images: 2344
CSV files: ['gardner_criteria_train.csv', 'Gardner_test_gold.xlsx', 'evaluate_prediction_results.py', 'Clincial_annotations.csv', 'Images', 'prediction_xception.csv', 'Gardner_test_gold_onlyGardnerScores.csv', 'gardner_criteria_test.csv', '.DS_Store']


In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T


# Read both annotation tables. The training CSV is parsed with ';' as the
# delimiter, the test CSV with pandas' default ','.
train_df = pd.read_csv(train_csv, sep=';')
test_df = pd.read_csv(test_csv)

# Show the parsed column names, training table first
for annotation_table in (train_df, test_df):
    print(annotation_table.columns)

Index(['Image', 'EXP_label', 'ICM_label', 'TE_label'], dtype='object')
Index(['Image', 'EXP_label', 'ICM_label', 'TE_label', 'EXP_Agreement',
       'ICM_Agreement', 'TE_Agreement', 'EXP_Agreement_desc',
       'ICM_Agreement_desc', 'TE_Agreement_desc', 'Unnamed: 10'],
      dtype='object')


In [None]:

# Annotation columns used here: Image, EXP_label, ICM_label, TE_label.
# Hold out 10% of the training annotations as a validation set, stratified on the
# combination of the three Gardner labels.
from sklearn.model_selection import train_test_split

# Re-read the annotations instead of splitting `train_df` in place, so the cell
# yields the same split every time it runs. Splitting in place meant that each
# re-run split the already reduced `train_df` again.
annotations_df = pd.read_csv(train_csv, sep=';')

# 1. Combined label string per row, e.g. "3_A_B". Kept as a separate Series so no
#    helper column has to be added to (and later removed from) the frame.
combo_label = (
    annotations_df["EXP_label"].astype(str)
    + "_"
    + annotations_df["ICM_label"].astype(str)
    + "_"
    + annotations_df["TE_label"].astype(str)
)

# 2. Drop combos that appear only once (stratification needs >= 2 rows per class)
combo_counts = combo_label.value_counts()
valid_combos = combo_counts[combo_counts >= 2].index
keep_mask = combo_label.isin(valid_combos)
print(f"Dropping {int((~keep_mask).sum())} rows whose label combination occurs only once")
stratifiable_df = annotations_df[keep_mask].reset_index(drop=True)

# 3. Stratified 90/10 split on the 1-D combo label (fixed seed for reproducibility)
train_df, val_df = train_test_split(
    stratifiable_df,
    test_size=0.1,
    stratify=combo_label[keep_mask].to_numpy(),
    random_state=42,
)




# Image preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with the mean/std below (the widely used ImageNet channel statistics).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)
_preprocess_steps = [T.Resize((224, 224)), T.ToTensor(), normalize]
preprocess_transforms = T.Compose(_preprocess_steps)

def crop_to_embryo(pil_img):
    """Return `pil_img` cropped to the bounding box of its largest thresholded region.

    The image is converted to grayscale, binarised with Otsu's method, and the
    external contours of the resulting mask are extracted. The contour with the
    largest area defines the crop box. When no contour is found, the input image
    is returned unchanged.
    """
    gray = np.array(pil_img.convert("L"))
    # The threshold argument (0) is a placeholder: THRESH_OTSU chooses the value
    _, binary_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    found, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return pil_img
    biggest = max(found, key=cv2.contourArea)
    left, top, width, height = cv2.boundingRect(biggest)
    return pil_img.crop((left, top, left + width, top + height))

class BlastocystDataset(Dataset):
    """Dataset yielding `(image, exp, icm, te)` for each row of an annotation table.

    Args:
        df: DataFrame with the columns `Image`, `EXP_label`, `ICM_label` and
            `TE_label`. Its index is reset, so any row order/index is accepted.
        img_dir: Directory that the `Image` file names are resolved against.
        transform: Optional callable applied to the loaded RGB PIL image.
        crop: When True, each image is passed through `crop_to_embryo` before
            `transform` is applied.
    """

    # Grade -> class index. Defined once on the class instead of being rebuilt
    # for every sample that is fetched.
    ICM_MAP = {'A': 0, 'B': 1, 'C': 2, 'Not defined': 3}
    TE_MAP = {'A': 0, 'B': 1, 'C': 2, 'Not defined': 3}
    # Class index given to any ICM/TE grade that is not a key of the maps above.
    # NOTE(review): unknown values (typos, NaN, other encodings) silently end up
    # in this class too - confirm the CSV only contains the four mapped grades.
    UNDEFINED_CLASS = 3

    def __init__(self, df, img_dir, transform=None, crop=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.crop = crop

    def __len__(self):
        """Number of annotated rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load row `idx` and return `(image, exp, icm, te)`.

        `image` is the (optionally cropped and transformed) RGB image; `exp` is
        `EXP_label - 1`; `icm` and `te` are the class indices of the ICM/TE grades.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['Image'])
        image = Image.open(img_path).convert("RGB")
        # Crop to embryo region
        if self.crop:
            image = crop_to_embryo(image)
        # Apply preprocessing transforms
        if self.transform:
            image = self.transform(image)
        # EXP is shifted by one to make it 0-indexed (1 -> 0, ..., 5 -> 4)
        exp = row['EXP_label'] - 1
        icm = self.ICM_MAP.get(row['ICM_label'], self.UNDEFINED_CLASS)
        te = self.TE_MAP.get(row['TE_label'], self.UNDEFINED_CLASS)
        return image, exp, icm, te

# Build one Dataset per split; all three read from the same image folder and
# share the same preprocessing pipeline.
_images_root = os.path.join(data_dir, "Images")
train_data, val_data, test_data = (
    BlastocystDataset(split_df, img_dir=_images_root, transform=preprocess_transforms)
    for split_df in (train_df, val_df, test_df)
)

# Wrap each Dataset in a DataLoader; only the training loader shuffles.
_loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_data, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **_loader_kwargs)

In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128

# 1) Installs (if not already)
!pip install -q diffusers accelerate transformers

# 2) Imports
import os, gc, math, time
import torch
import torchvision.transforms as T
from diffusers import UNet2DModel, DDPMScheduler
from google.colab import drive

# 3) Mount Drive (if not already)
drive.mount('/content/drive', force_remount=True)
checkpoint_dir = '/content/drive/MyDrive/diffusion_checkpoints_without_crop'
os.makedirs(checkpoint_dir, exist_ok=True)

# 4) Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.cuda.empty_cache(); gc.collect()

# 5) Memory helper
def print_mem(prefix=""):
    """Print CUDA memory counters in MiB after `prefix`, or "no cuda" without a GPU.

    Reports torch's currently allocated, reserved and peak allocated memory.
    """
    if not torch.cuda.is_available():
        print(prefix, "no cuda")
        return

    bytes_per_mib = 1024**2
    allocated = torch.cuda.memory_allocated() / bytes_per_mib
    reserved = torch.cuda.memory_reserved() / bytes_per_mib
    peak = torch.cuda.max_memory_allocated() / bytes_per_mib
    print(prefix,
          "alloc:", f"{allocated:.1f} MiB;",
          "reserved:", f"{reserved:.1f} MiB;",
          "max_alloc:", f"{peak:.1f} MiB")

# 6) Reduced resolution and smaller UNet to save memory
image_size = 224
# Batch size used for diffusion training. The loop used to iterate `train_loader`
# (built with batch_size=32), so the former value of 2 was printed but never
# applied; 32 keeps the effective batch size unchanged. Lower it on CUDA OOM.
batch_size = 32

model = UNet2DModel(
    sample_size=image_size,
    in_channels=3, out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64, 128, 128),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")
).to(device)

noise_scheduler = DDPMScheduler(num_train_timesteps=500)  # fewer timesteps -> faster, less mem ideally

# 7) Diffusers memory helpers (each one is only enabled when the model offers it)
# attention slicing reduces peak memory at the cost of time
if hasattr(model, "enable_attention_slicing"):
    try:
        model.enable_attention_slicing()
        print("Enabled attention slicing")
    except Exception as e:
        print("enable_attention_slicing failed:", e)
else:
    print("enable_attention_slicing not available on", type(model).__name__)

# gradient checkpointing if supported by this UNet implementation (saves activations by recomputing)
if hasattr(model, "enable_gradient_checkpointing"):
    try:
        model.enable_gradient_checkpointing()
        print("Enabled model gradient checkpointing")
    except Exception as e:
        print("enable_gradient_checkpointing failed:", e)
else:
    print("No gradient checkpointing API; consider torch.utils.checkpoint for custom modules")

# 8) Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 9) Mixed precision: only active on CUDA. Uses the torch.amp API; the
#    torch.cuda.amp spellings are deprecated and emit a FutureWarning.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# 10) Data loader dedicated to diffusion training.
# The classifier pipeline (`preprocess_transforms`) applies ImageNet mean/std
# normalisation. DDPMPipeline converts samples back to images with
# (x / 2 + 0.5).clamp(0, 1), i.e. it expects the model to have been trained on
# inputs scaled to [-1, 1], so the diffusion model gets its own transform.
transform_resize = T.Resize((image_size, image_size))
diffusion_transforms = T.Compose([
    transform_resize,
    T.ToTensor(),                                             # [0, 1]
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),   # -> [-1, 1]
])
diffusion_data = BlastocystDataset(
    train_df,
    img_dir=os.path.join(data_dir, "Images"),
    transform=diffusion_transforms,
)
diffusion_loader = DataLoader(diffusion_data, batch_size=batch_size, shuffle=True, num_workers=2)

print("Using batch_size:", batch_size)
print_mem("Before training:")

# 11) Training loop (mixed precision + stepping)
num_epochs = 50   # reduce epochs for quick test; increase later
model.train()
for epoch in range(num_epochs):
    epoch_start = time.time()
    running_loss = 0.0
    n_batches = 0
    # Each batch is (images, exp, icm, te); the labels are not used here
    for imgs, *_ in diffusion_loader:
        imgs = imgs.to(device, dtype=torch.float32)

        # sample timesteps & add noise
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (imgs.size(0),), device=device).long()
        noise = torch.randn_like(imgs)

        with torch.autocast(device_type=device.type, enabled=use_amp):   # mixed precision
            noisy_imgs = noise_scheduler.add_noise(imgs, noise, timesteps)
            out = model(noisy_imgs, timesteps)  # diffusers UNet returns a ModelOutput-like with .sample
            # handle different diffusers versions: out may be a tuple or ModelOutput
            noise_pred = out.sample if hasattr(out, "sample") else out[0]
            loss = torch.nn.functional.mse_loss(noise_pred, noise)

        scaler.scale(loss).backward()
        # unscale first so the clipping threshold applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        running_loss += loss.item()
        n_batches += 1

        # print occasional memory usage
        if n_batches % 50 == 0:
            print_mem(f"Epoch {epoch+1} batch {n_batches} ->")
    epoch_time = time.time() - epoch_start
    avg_loss = running_loss / max(1, n_batches)
    print(f"Epoch {epoch+1}/{num_epochs} done. avg_loss={avg_loss:.5f}, time={epoch_time:.1f}s")
    # Every 5 epochs, overwrite the checkpoint in checkpoint_dir with the current
    # weights. A failed save is reported but does not abort training.
    if (epoch+1) % 5 == 0:
        try:
            model.save_pretrained(checkpoint_dir)          # diffusers save helper
            noise_scheduler.save_pretrained(checkpoint_dir)
            print(f"Saved checkpoint at epoch {epoch+1} to {checkpoint_dir}")
        except Exception as e:
            print("Checkpoint save failed:", e)
    # free some memory periodically
    torch.cuda.empty_cache()
    gc.collect()

print("Training finished.")
print_mem("After training:")
# Final save
model.save_pretrained(checkpoint_dir)
noise_scheduler.save_pretrained(checkpoint_dir)
print(f"✅ Saved final model and scheduler to {checkpoint_dir}")

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
Mounted at /content/drive
Using device: cuda
enable_attention_slicing not available: 'UNet2DModel' object has no attribute 'enable_attention_slicing'
Enabled model gradient checkpointing
Using batch_size: 2
Before training: alloc: 24.9 MiB; reserved: 40.0 MiB; max_alloc: 24.9 MiB


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():   # mixed precision


Epoch 1 batch 50 -> alloc: 173.7 MiB; reserved: 26078.0 MiB; max_alloc: 7550.1 MiB
Epoch 1/50 done. avg_loss=0.52846, time=55.4s
Epoch 2 batch 50 -> alloc: 173.7 MiB; reserved: 14038.0 MiB; max_alloc: 7550.1 MiB
Epoch 2/50 done. avg_loss=0.17305, time=53.2s
Epoch 3 batch 50 -> alloc: 173.7 MiB; reserved: 22138.0 MiB; max_alloc: 7550.1 MiB
Epoch 3/50 done. avg_loss=0.11341, time=53.1s
Epoch 4 batch 50 -> alloc: 173.7 MiB; reserved: 30278.0 MiB; max_alloc: 7550.1 MiB
Epoch 4/50 done. avg_loss=0.09129, time=52.9s
Epoch 5 batch 50 -> alloc: 173.7 MiB; reserved: 6658.0 MiB; max_alloc: 7550.1 MiB
Epoch 5/50 done. avg_loss=0.07998, time=53.0s
Saved checkpoint at epoch 5 to /content/drive/MyDrive/diffusion_checkpoints_without_crop
Epoch 6 batch 50 -> alloc: 173.7 MiB; reserved: 20318.0 MiB; max_alloc: 7550.1 MiB
Epoch 6/50 done. avg_loss=0.06765, time=52.9s
Epoch 7 batch 50 -> alloc: 173.7 MiB; reserved: 13218.0 MiB; max_alloc: 7550.1 MiB
Epoch 7/50 done. avg_loss=0.06484, time=53.2s
Epoch 8 b

In [None]:
# Generate & save diffusion images one-by-one (memory-friendly)
# Assumes:
# - checkpoint_dir contains your pretrained UNet + scheduler
# - diffusers and torch are already installed
# Drive is mounted below and every import used by this cell is repeated here.

import os, time, gc
from PIL import Image
import numpy as np
import torch
from google.colab import drive
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline

# --- Config ---
drive.mount('/content/drive', force_remount=True)
checkpoint_dir = '/content/drive/MyDrive/diffusion_checkpoints_without_crop'   # your checkpoint folder
OUT_DIR = "/content/drive/MyDrive/diffusion_generated_images_latest"           # where to save images
os.makedirs(OUT_DIR, exist_ok=True)

NUM_IMAGES = 500        # total images to generate
STEPS = 500             # num_inference_steps for sampling (50-100 is typically ok). Set to 500 if you want highest quality.
PRINT_EVERY = 10        # progress print frequency
SLEEP_AFTER = 0.05      # small pause to let GPU driver breathe (optional)
RETRY_ON_ERR = 2        # retry attempts for transient errors

# --- Device & model load ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model = UNet2DModel.from_pretrained(checkpoint_dir).to(device)
noise_scheduler = DDPMScheduler.from_pretrained(checkpoint_dir)
pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler).to(device)
# The pipeline is called once per image, so its own progress bar would add one
# widget per image to the cell output. Progress is reported every PRINT_EVERY
# images instead.
pipeline.set_progress_bar_config(disable=True)

# --- Helper: convert output into PIL.Image if needed ---
def _channels_last(arr):
    """Return `arr` laid out as (H, W) or (H, W, C), ready for `Image.fromarray`.

    Takes the first item of a 4-D batch, moves a leading 1- or 3-sized channel
    axis to the end, and squeezes a single channel away ((H, W, 1) is not a
    layout `Image.fromarray` accepts).
    """
    if arr.ndim == 4:
        arr = arr[0]  # take first in batch
    if arr.ndim == 3 and arr.shape[0] in (1, 3):
        arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    return arr


def _minmax_to_uint8(arr):
    """Linearly rescale `arr` so its min maps to 0 and its max to 255, as uint8."""
    mn, mx = arr.min(), arr.max()
    arr = (arr - mn) / (mx - mn + 1e-8)
    return (255.0 * arr).astype(np.uint8)


def to_pil(img_obj):
    """Convert a pipeline output item to a `PIL.Image`.

    Accepted inputs:
      * `PIL.Image` - returned unchanged.
      * `torch.Tensor` - (H, W), (C, H, W) or (B, C, H, W); only the first batch
        item is used. Floating-point data is min-max rescaled to 0..255; a
        (nearly) constant float image is scaled by 255 and clipped instead.
        Half-precision tensors are converted to float32 first.
      * `numpy.ndarray` - same layouts; any dtype other than uint8 is min-max
        rescaled to 0..255.
      * anything else - passed through `np.array` and `Image.fromarray`.
    """
    # If already PIL
    if isinstance(img_obj, Image.Image):
        return img_obj
    if isinstance(img_obj, torch.Tensor):
        tensor = img_obj.detach().cpu()
        # fp16/bf16 would skip the float rescaling below (bf16 cannot even be
        # converted to numpy), so promote them to float32 first
        if tensor.is_floating_point() and tensor.dtype not in (torch.float32, torch.float64):
            tensor = tensor.float()
        arr = _channels_last(tensor.numpy())
        # Normalize floats -> 0..255
        if arr.dtype in (np.float32, np.float64):
            mn, mx = arr.min(), arr.max()
            if mx - mn < 1e-6:
                # constant image: min-max scaling would divide by ~0
                arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
            else:
                arr = _minmax_to_uint8(arr)
        return Image.fromarray(arr)
    if isinstance(img_obj, np.ndarray):
        arr = _channels_last(img_obj)
        if arr.dtype != np.uint8:
            arr = _minmax_to_uint8(arr)
        return Image.fromarray(arr)
    # fallback: try converting via PIL
    return Image.fromarray(np.array(img_obj))

# --- Resume logic: generate exactly the indices 1..NUM_IMAGES that have no file yet ---
# Working from the missing file names (instead of "number of images in the folder
# + 1") means gaps left by skipped images are filled, unrelated images in OUT_DIR
# are ignored, and an existing file is never overwritten.
present = set(os.listdir(OUT_DIR))
existing = sorted(
    name for name in (f"synthetic_{i:04d}.png" for i in range(1, NUM_IMAGES + 1))
    if name in present
)
pending = [i for i in range(1, NUM_IMAGES + 1) if f"synthetic_{i:04d}.png" not in present]
print(f"Found {len(existing)} existing images in OUT_DIR, {len(pending)} left to generate")

# --- Generation loop (one by one) ---
max_attempts = RETRY_ON_ERR + 1   # first try + RETRY_ON_ERR retries
start_time = time.time()
generated = 0
for i in pending:
    fname = os.path.join(OUT_DIR, f"synthetic_{i:04d}.png")
    part_name = fname + ".part"
    for attempt in range(1, max_attempts + 1):
        try:
            out = pipeline(num_inference_steps=STEPS)  # generate one image
            pil_img = to_pil(out.images[0])            # usually already a PIL.Image
            # Write to a temporary name and rename afterwards, so an interrupted
            # save never leaves a truncated PNG that a later run counts as done
            pil_img.save(part_name, format="PNG")
            os.replace(part_name, fname)
        except Exception as e:
            print(f"Error generating image {i} (attempt {attempt}/{max_attempts}): {e}")
            # small recovery: clear caches, then retry unless attempts are used up
            torch.cuda.empty_cache()
            gc.collect()
            if attempt == max_attempts:
                print(f"Skipping image {i} after {RETRY_ON_ERR} retries.")
                break
            time.sleep(1.0)
            continue

        generated += 1

        # logging
        elapsed = time.time() - start_time
        avg = elapsed / generated
        if generated % PRINT_EVERY == 0 or generated == len(pending):
            print(f"[{generated}/{len(pending)}] Saved {fname}  (avg {avg:.2f}s/img, elapsed {elapsed:.1f}s)")

        # small memory housekeeping
        del out, pil_img
        torch.cuda.empty_cache()
        gc.collect()
        time.sleep(SLEEP_AFTER)
        break  # success -> stop retrying this index

total_time = time.time() - start_time
print(f"\n✅ Done. Generated {generated} new images. Saved to: {OUT_DIR}")
print(f"Total time: {total_time:.1f}s, avg {total_time/max(1,generated):.2f}s/img")

Mounted at /content/drive
Device: cuda
Found 0 existing images in OUT_DIR, will start at index 1


[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[10/500] Saved /content/drive/MyDrive/diffusion_generated_images_latest/synthetic_0010.png  (avg 29.74s/img, elapsed 297.4s)


[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[20/500] Saved /content/drive/MyDrive/diffusion_generated_images_latest/synthetic_0020.png  (avg 29.74s/img, elapsed 594.9s)


[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[30/500] Saved /content/drive/MyDrive/diffusion_generated_images_latest/synthetic_0030.png  (avg 29.75s/img, elapsed 892.5s)


[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[40/500] Saved /content/drive/MyDrive/diffusion_generated_images_latest/synthetic_0040.png  (avg 29.76s/img, elapsed 1190.4s)


[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]

[Removed a Jupyter widget output for compatibility]