# Flux Microscopy Image Enhancement Demo

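This notebook uses the FLUX.1 Kontext [dev] image-editing model (through Hugging Face `diffusers`, running on a CUDA GPU) to enhance cellular microscopy images by improving their clarity and contrast. It then shows each original image side by side with its enhanced version. Because the enhanced images are produced by a generative diffusion model and may contain invented detail, use them for visual inspection rather than quantitative measurement.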

In [None]:
# Import necessary libraries
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
import matplotlib.pyplot as plt
import os
import random
import glob

# Load the FLUX.1 Kontext [dev] weights in bfloat16 (downloaded from the
# Hugging Face Hub or read from the local cache) and place the pipeline on
# the CUDA device; `.to()` returns the pipeline itself, so it can be chained.
print("Loading Flux Kontext model...")
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
print("✅ Model loaded successfully")

In [None]:
# Image selection and parameter configuration

def get_random_images(base_path, num_images, seed=None, extensions=None):
    """Randomly select up to ``num_images`` image files from ``base_path``.

    Only regular files directly inside ``base_path`` are considered (no
    recursion). Extensions are matched case-insensitively, so ``IMG.TIF`` is
    found as well as ``img.tif``. Hidden files (names starting with ".", e.g.
    macOS "._x.tif" resource forks) are ignored, as the original glob-based
    version also ignored them. Directories whose names look like images are
    ignored too.

    Args:
        base_path: Directory to search.
        num_images: Number of images to select.
        seed: Optional seed for a private ``random.Random`` instance, making
            the selection reproducible without touching the global ``random``
            state. When None, the module-level ``random`` functions are used.
        extensions: Iterable of accepted extensions, including the leading
            dot. Defaults to .jpg/.jpeg/.png/.tif/.tiff.

    Returns:
        A list of paths of the form ``os.path.join(base_path, name)``. If
        fewer than ``num_images`` images exist, a warning is printed and all
        of them are returned (sorted). A missing or unreadable directory
        yields an empty list.
    """
    if extensions is None:
        extensions = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
    wanted = {ext.lower() for ext in extensions}

    all_images = []
    try:
        with os.scandir(base_path) as entries:
            for entry in entries:
                # glob's "*" never matched dot-files; keep excluding them
                if entry.name.startswith("."):
                    continue
                if os.path.splitext(entry.name)[1].lower() in wanted and entry.is_file():
                    all_images.append(os.path.join(base_path, entry.name))
    except OSError:
        # Mirror glob.glob, which silently yields nothing for a missing or
        # unreadable directory
        all_images = []

    # Directory listing order is filesystem-dependent; sort so that a given
    # seed always selects the same files
    all_images.sort()

    if len(all_images) < num_images:
        print(f"Warning: Only {len(all_images)} images found in {base_path}, less than requested {num_images} images")
        return all_images

    rng = random if seed is None else random.Random(seed)
    return rng.sample(all_images, num_images)

# Set paths and parameters
base_dir = "../cell_datasets/confocal_new"
trainA_path = os.path.join(base_dir, "trainA")
trainB_path = os.path.join(base_dir, "trainB")
num_images = 5  # number of images sampled from EACH of trainA and trainB

# Check directories and select images. Use isdir rather than exists so that a
# plain file with the expected name is reported here instead of silently
# yielding zero images.
if not os.path.isdir(trainA_path):
    raise FileNotFoundError(f"Directory does not exist: {trainA_path}")
if not os.path.isdir(trainB_path):
    raise FileNotFoundError(f"Directory does not exist: {trainB_path}")

trainA_images = get_random_images(trainA_path, num_images)
trainB_images = get_random_images(trainB_path, num_images)
image_paths = trainA_images + trainB_images

# Fail fast: with no inputs, the later cells would only report 0 processed images
if not image_paths:
    raise FileNotFoundError(f"No supported images found in {trainA_path} or {trainB_path}")

print("Selected images:")
for path in image_paths:
    subfolder = os.path.basename(os.path.dirname(path))
    filename = os.path.basename(path)
    print(f"  {subfolder}/{filename}")

# Define enhancement prompts.
# Note: FluxKontextPipeline applies `negative_prompt` only when the pipeline is
# called with true_cfg_scale > 1 (true classifier-free guidance); otherwise
# the negative prompt has no effect.
prompt = "enhance microscopy image clarity and contrast selectively on cellular structures only, sharpen cell boundaries and organelles while preserving smooth background areas, improve definition of membranes and internal structures, increase visibility of fine cellular details, maintain clean background regions, preserve original colors exactly, selective enhancement of biological structures only, professional microscopy enhancement with noise reduction"

negative_prompt = "background noise amplification, grain enhancement, speckle artifacts, noise boost, texture noise, random pixel variation, color distortion, color shift, oversaturation, artificial coloring, excessive noise, over-processing artifacts, loss of cellular structure, blurred details, plastic appearance, unnatural smoothing, halo effects, ringing artifacts, background texture enhancement"

print(f"\n✅ Configuration completed, will process {len(image_paths)} images")

In [None]:
# Batch image enhancement processing

# FluxKontextPipeline uses `negative_prompt` only when true classifier-free
# guidance is enabled (true_cfg_scale > 1). At the library default of 1.0, the
# negative prompt is ignored. True CFG adds a second transformer pass per
# denoising step, roughly doubling run time. 1.5 is a starting value to tune.
# Set it to 1.0 to skip the negative prompt and restore the faster behaviour.
true_cfg_scale = 1.5

results = []  # (original image, enhanced image, "subfolder/filename") tuples
total = len(image_paths)

print("Starting image enhancement processing...")
for i, img_path in enumerate(image_paths, 1):
    subfolder = os.path.basename(os.path.dirname(img_path))
    filename = os.path.basename(img_path)
    label = f"{subfolder}/{filename}"

    try:
        # NOTE(review): diffusers' load_image converts to RGB; high-bit-depth
        # TIFFs may lose dynamic range in that conversion — confirm for this data.
        input_image = load_image(img_path)
    except (OSError, ValueError) as err:
        # Skip unreadable or corrupt files instead of aborting the whole batch
        print(f"⚠️ [{i}/{total}] Skipped unreadable image {label}: {err}")
        continue

    # Use Flux for image enhancement
    enhanced_image = pipe(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        true_cfg_scale=true_cfg_scale,
        guidance_scale=2.5,
        num_inference_steps=40,
    ).images[0]

    results.append((input_image, enhanced_image, label))
    print(f"✅ [{i}/{total}] Enhancement completed: {label}")

print(f"\n🎉 All image enhancement completed! Processed {len(results)} images in total")

In [None]:
# Compare and display enhancement results

if results:
    rows = len(results)
    # squeeze=False keeps `axes` two-dimensional even with a single result row
    fig, axes = plt.subplots(rows, 2, figsize=(12, 3 * rows), squeeze=False)

    # One row per image: original in the left column, enhanced in the right
    for (orig, enhanced, label), (left_ax, right_ax) in zip(results, axes):
        left_ax.set_title(f"Original - {label}", fontsize=10)
        left_ax.imshow(orig)
        left_ax.axis("off")

        right_ax.set_title(f"Enhanced - {label}", fontsize=10)
        right_ax.imshow(enhanced)
        right_ax.axis("off")

    plt.tight_layout()
    plt.show()
    print(f"📊 Display completed, compared {len(results)} image pairs")
else:
    print("⚠️ No processing results found, please run the image enhancement cell above first")