In [None]:
# Mount Google Drive into the Colab runtime; the dataset paths used below
# live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
# Root of the mounted Drive and the folder that holds the raw sketch/photo files.
base_path = "/content/drive/MyDrive"
dataset_dir = os.path.join(base_path, "dataset furniture")

# os.listdir returns every entry in the folder, sub-directory names included,
# so the count printed below is "entries", not strictly "files".
filenames = os.listdir(dataset_dir)
print(f"📁 Total files in folder: {len(filenames)}")

In [None]:
# Collect the image IDs seen for each role. An ID is the filename with the
# "fimage_" prefix removed, cut at the role marker.
drawn_ids, original_ids = set(), set()

for fname in filenames:
    stem = fname.replace("fimage_", "")
    # "_drawn" is tested first, so a name containing both markers counts as a sketch.
    for marker, id_bucket in (("_drawn", drawn_ids), ("_original", original_ids)):
        if marker in fname:
            id_bucket.add(stem.split(marker)[0])
            break

# IDs for which both a sketch and an original were found.
paired_ids = drawn_ids & original_ids

In [None]:
# Report how many distinct IDs were seen per role and how many have both halves.
print(f"🖌️ Total sketch images: {len(drawn_ids)}")
print(f"🖼️ Total original images: {len(original_ids)}")
print(f"✅ Found {len(paired_ids)} sketch-original image pairs.")

#### Train/Test Split

In [None]:
# Step 1: build {image_id: {"drawn": fname, "original": fname}} from "fimage*" entries only.
file_mapping = {}

for fname in filenames:
    if fname.startswith("fimage"):
        stem = fname.replace("fimage", "")
        # "_drawn" wins when a name carries both markers.
        for marker, role in (("_drawn", "drawn"), ("_original", "original")):
            if marker in fname:
                file_mapping.setdefault(stem.split(marker)[0], {})[role] = fname
                break

# Step 2: keep only IDs that have both halves of the pair, in sorted (string) order.
paired_ids = sorted(
    img_id
    for img_id, pair in file_mapping.items()
    if {"drawn", "original"} <= pair.keys()
)

# Step 3: the first 70 % of the sorted IDs form the train split, the rest the test split.
split_idx = int(0.7 * len(paired_ids))
train_ids, test_ids = paired_ids[:split_idx], paired_ids[split_idx:]

# Step 4: flatten each split into [drawn, original, drawn, original, ...].
train_filenames = [
    file_mapping[image_id][role]
    for image_id in train_ids
    for role in ("drawn", "original")
]
test_filenames = [
    file_mapping[image_id][role]
    for image_id in test_ids
    for role in ("drawn", "original")
]

print(f"🛤️ Train set size: {len(train_filenames)} images")
print(f"🧪 Test set size: {len(test_filenames)} images")

In [None]:
import os
import shutil

# Destination folders for the two splits, created inside the dataset folder.
train_dir, test_dir = (
    os.path.join(base_path, "dataset furniture", split_name)
    for split_name in ("train", "test")
)
for split_dir in (train_dir, test_dir):
    os.makedirs(split_dir, exist_ok=True)

def copy_files(file_list, src_dir, dst_dir):
    """Copy each named file from ``src_dir`` to ``dst_dir`` unless it is already there.

    Args:
        file_list: Iterable of bare filenames (no directory component).
        src_dir: Directory the files are read from.
        dst_dir: Directory the files are written to; it must already exist.

    A line is printed per file saying whether it was copied or skipped.
    Existing destination files are never overwritten.
    """
    for fname in file_list:
        dst_path = os.path.join(dst_dir, fname)

        # Guard clause: anything already at the destination is left untouched.
        if os.path.exists(dst_path):
            print(f"Skipped (already exists): {fname}")
            continue

        shutil.copyfile(os.path.join(src_dir, fname), dst_path)
        print(f"Copied: {fname}")

# Populate train/ from the raw dataset folder; copy_files skips any file that
# is already present at the destination, so re-running this cell is safe.
copy_files(train_filenames, dataset_dir, train_dir)

# Same for test/.
copy_files(test_filenames, dataset_dir, test_dir)

print("✅ Files successfully copied to train/ and test/ folders.")

#### Preprocessing

##### Resizing (inspect the current image size)

In [None]:
from PIL import Image

def get_image_size(image_path):
    """Open the image at ``image_path`` and return its ``(width, height)``."""
    with Image.open(image_path) as opened:
        width, height = opened.size
    return (width, height)

# Inspect the size of one image. The path is built from the raw dataset folder
# and the first training filename, so train_filenames must be non-empty.
sample_image_path = os.path.join(dataset_dir, train_filenames[0])  # Use any file in the dataset
image_size = get_image_size(sample_image_path)
print(f"Sample image size: {image_size}")

##### Normalization

In [None]:
import numpy as np
from PIL import Image
import os
import random

# Folders that receive the normalized copies of the train and test images.
# Both are created inside the raw dataset folder.
normalized_train_dir = os.path.join(dataset_dir, "normalized_train")
normalized_test_dir = os.path.join(dataset_dir, "normalized_test")

os.makedirs(normalized_train_dir, exist_ok=True)
os.makedirs(normalized_test_dir, exist_ok=True)

def normalize_image(image_path, save_path):
    """Round-trip an image through the [-1, 1] range and save it as 8-bit.

    The pixels are scaled to [-1, 1] in float32 and immediately mapped back to
    [0, 255] so the result can be stored in a regular image file. The saved
    file therefore holds ordinary 8-bit values, not [-1, 1] floats.

    Args:
        image_path: Path of the image to read.
        save_path: Path the processed image is written to (format inferred by
            PIL from the extension).
    """
    with Image.open(image_path) as img:
        img_array = np.array(img).astype(np.float32)
        img_array = (img_array / 127.5) - 1.0

        # Map back to the 8-bit range. Round to nearest and clamp before the
        # uint8 cast: a bare astype() truncates toward zero, so float32
        # rounding error in the round trip (a value landing just below its
        # original integer) would silently darken that pixel by one level.
        restored = (img_array + 1.0) * 127.5
        restored = np.clip(np.rint(restored), 0, 255).astype(np.uint8)

        normalized_img = Image.fromarray(restored)
        normalized_img.save(save_path)
        print(f"Saved normalized image to {save_path}")

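# Normalize and save all images in the training and test sets.
# NOTE: the two loops below are commented out. The check further down opens
# files from the normalized folders, so they must already contain the images.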
# for fname in train_filenames:
#     image_path = os.path.join(train_dir, fname)
#     save_path = os.path.join(normalized_train_dir, fname)
#     normalize_image(image_path, save_path)

# for fname in test_filenames:
#     image_path = os.path.join(test_dir, fname)
#     save_path = os.path.join(normalized_test_dir, fname)
#     normalize_image(image_path, save_path)

def check_random_normalized_images(file_list, normalized_dir, num_samples=2):
    """Spot-check a random handful of saved images for out-of-range pixel values.

    Args:
        file_list: Filenames to sample from.
        normalized_dir: Directory in which the sampled files are opened.
        num_samples: Upper bound on how many files are sampled.

    A per-file verdict and an overall summary are printed; nothing is returned.
    A sampled file that is missing from ``normalized_dir`` makes ``Image.open``
    raise.
    """
    picked = random.sample(file_list, min(num_samples, len(file_list)))
    failures = 0

    for fname in picked:
        with Image.open(os.path.join(normalized_dir, fname)) as img:
            pixels = np.array(img)

        # A pixel outside [0, 255] marks the file as badly normalized.
        if pixels.max() > 255 or pixels.min() < 0:
            print(f"❌ Image {fname} has incorrect normalization values!")
            failures += 1
        else:
            print(f"✅ Image {fname} appears correctly normalized.")

    if failures == 0:
        print("✅ All images normalized and saved successfully!")
    else:
        print(f"⚠️ Some images in {normalized_dir} may not be normalized correctly.")

# Spot-check two random files per split. The sampled names come from the
# train/test filename lists and are opened inside the normalized folders.
check_random_normalized_images(train_filenames, normalized_train_dir)
check_random_normalized_images(test_filenames, normalized_test_dir)

In [None]:
from sklearn.model_selection import train_test_split

# Step 1: recover the training image IDs from the sketch filenames of the train split.
train_ids = []
for fname in train_filenames:
    if "_drawn" in fname:
        train_ids.append(fname.replace("fimage", "").split("_drawn")[0])

# Step 2: hold out 20 % of those IDs for validation (fixed seed for repeatability).
train_ids_split, val_ids_split = train_test_split(train_ids, test_size=0.2, random_state=42)

# Step 3: expand each ID back into its [drawn, original] filenames; IDs that
# are not in file_mapping are skipped.
new_train_filenames = [
    file_mapping[image_id][role]
    for image_id in train_ids_split
    if image_id in file_mapping
    for role in ("drawn", "original")
]
val_filenames = [
    file_mapping[image_id][role]
    for image_id in val_ids_split
    if image_id in file_mapping
    for role in ("drawn", "original")
]

# ✅ Summary
print(f"🛤️ Train set: {len(new_train_filenames)} images")
print(f"🧪 Validation set: {len(val_filenames)} images")
print(f"🧪 Test set (unchanged): {len(test_filenames)} images")

In [None]:
import os
import shutil

# Validation folder, created next to normalized_train/ and normalized_test/
# inside the dataset folder.
val_dir = os.path.join(dataset_dir, "normalized_val")
os.makedirs(val_dir, exist_ok=True)

def copy_files_if_needed(file_list, source_dir, destination_dir, expected_file_count):
    """Copy ``file_list`` into ``destination_dir`` unless it already looks complete.

    Args:
        file_list: Bare filenames to copy.
        source_dir: Directory the files are read from.
        destination_dir: Existing directory the files are written to.
        expected_file_count: If the destination already holds at least this many
            regular files, nothing is copied at all.

    Missing source files only produce a warning. Files already present at the
    destination are skipped. One status line is printed per decision.
    """
    # Count regular files only; sub-directories of the destination are ignored.
    existing_files = sum(
        1
        for entry in os.listdir(destination_dir)
        if os.path.isfile(os.path.join(destination_dir, entry))
    )

    if existing_files >= expected_file_count:
        print(f"⏩ Skipped copying: {existing_files} files already exist in {destination_dir}.")
        return

    for fname in file_list:
        src_path = os.path.join(source_dir, fname)
        dst_path = os.path.join(destination_dir, fname)

        if not os.path.exists(src_path):
            print(f"⚠️ Warning: Source file not found: {src_path}")
            continue

        if os.path.exists(dst_path):
            print(f"⏩ Skipped (already exists): {fname}")
            continue

        shutil.copy(src_path, dst_path)
        print(f"✅ Copied: {fname}")

# Expected number of files in the validation set. Derived from the list that is
# about to be copied (sketch + original per validation pair) instead of a
# hard-coded figure, so the skip check stays correct if the split changes.
expected_val_file_count = len(val_filenames)

# Copy validation files out of train/, where every training-split file was placed.
copy_files_if_needed(val_filenames, train_dir, val_dir, expected_val_file_count)

print("✅ Validation files checked and copied if needed.")

#### Data Preparation for training

In [None]:
import os
import shutil

def count_files_in_subfolders(base_dir):
    """Count the entries in ``base_dir/original`` and ``base_dir/drawn``.

    Args:
        base_dir: Directory whose ``original`` and ``drawn`` sub-folders are inspected.

    Returns:
        ``(num_original_files, num_drawn_files)``; a sub-folder that does not
        exist contributes 0.
    """
    def _entry_count(subfolder):
        # Look under the directory that was passed in, not a module-level path.
        path = os.path.join(base_dir, subfolder)
        return len(os.listdir(path)) if os.path.isdir(path) else 0

    return _entry_count("original"), _entry_count("drawn")

def create_subfolders_and_copy_files(base_dir, expected_original_files, expected_drawn_files):
    """Sort the images directly inside ``base_dir`` into ``original/`` and ``drawn/``.

    Args:
        base_dir: Directory holding the mixed sketch/original files. The two
            sub-folders are created inside it if missing.
        expected_original_files: Skip all work when ``original/`` already holds
            at least this many entries (and ``drawn/`` meets its own threshold).
        expected_drawn_files: Same threshold for ``drawn/``.

    Files are copied, never moved, and existing destination files are never
    overwritten. One line is printed per copied file.
    """
    original_dir = os.path.join(base_dir, "original")
    drawn_dir = os.path.join(base_dir, "drawn")

    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(drawn_dir, exist_ok=True)

    num_original_files, num_drawn_files = count_files_in_subfolders(base_dir)

    if num_original_files >= expected_original_files and num_drawn_files >= expected_drawn_files:
        print(f"⏩ Skipping folder: {base_dir}, already contains enough original and drawn images.")
        return

    # Loop through all files in the base directory (sub-folders are skipped).
    for fname in os.listdir(base_dir):
        file_path = os.path.join(base_dir, fname)

        if not os.path.isfile(file_path):
            continue

        if "_drawn" in fname:
            # Copy the sketch itself.
            dst_drawn_path = os.path.join(drawn_dir, fname)
            if not os.path.exists(dst_drawn_path):
                shutil.copy(file_path, dst_drawn_path)
                print(f"✅ Copied: {fname} to {drawn_dir}")

            # Derive the partner's name from the sketch's own name so the
            # prefix and extension always match the files on disk. Rebuilding
            # it as "image_<id>_original.png" misses files that use another
            # prefix (the rest of this notebook filters on "fimage").
            original_fname = fname.replace("_drawn", "_original", 1)
            src_original_path = os.path.join(base_dir, original_fname)
            dst_original_path = os.path.join(original_dir, original_fname)
            if os.path.exists(src_original_path) and not os.path.exists(dst_original_path):
                shutil.copy(src_original_path, dst_original_path)
                print(f"✅ Copied: {original_fname} to {original_dir}")

        # Originals without a sketch, or not yet copied through the branch above.
        elif "_original" in fname:
            dst_original_path = os.path.join(original_dir, fname)
            if not os.path.exists(dst_original_path):
                shutil.copy(file_path, dst_original_path)
                print(f"✅ Copied: {fname} to {original_dir}")

In [None]:
# Skip thresholds per split: a folder is left alone once its original/ and
# drawn/ sub-folders each hold at least this many entries.
# NOTE(review): presumably these are the pair counts of the train / val / test
# splits produced above — confirm against the split sizes before reusing the
# notebook on a different dataset.
expected_train_original = 853
expected_train_drawn = 853
expected_val_original = 214
expected_val_drawn = 214
expected_test_original = 458
expected_test_drawn = 458

# Sort each split folder into original/ and drawn/ sub-folders.
create_subfolders_and_copy_files(normalized_train_dir, expected_train_original, expected_train_drawn)
create_subfolders_and_copy_files(val_dir, expected_val_original, expected_val_drawn)
create_subfolders_and_copy_files(normalized_test_dir, expected_test_original, expected_test_drawn)

print("✅ All files have been checked and copied if needed.")

In [None]:
def get_image_paths(dir_path):
    """Return the full paths of the image files in ``dir_path``, sorted by filename.

    Only names ending in .png, .jpg or .jpeg (case-insensitive) are kept.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    paths = []
    for fname in sorted(os.listdir(dir_path)):
        if fname.lower().endswith(image_suffixes):
            paths.append(os.path.join(dir_path, fname))
    return paths


# For every split, list the sketch paths and the original paths. Both lists
# are sorted by filename, which is what pairs them up by position later.
train_drawn_paths, train_original_paths = (
    get_image_paths(os.path.join(normalized_train_dir, role))
    for role in ("drawn", "original")
)

val_drawn_paths, val_original_paths = (
    get_image_paths(os.path.join(val_dir, role))
    for role in ("drawn", "original")
)

test_drawn_paths, test_original_paths = (
    get_image_paths(os.path.join(normalized_test_dir, role))
    for role in ("drawn", "original")
)

In [None]:
# Sanity check: within each split the drawn and original counts should match,
# because build_dataset asserts equal lengths before pairing the lists.
print(f"Train drawn: {len(train_drawn_paths)}")
print(f"Train original: {len(train_original_paths)}")

print(f"Val drawn: {len(val_drawn_paths)}")
print(f"Val original: {len(val_original_paths)}")

print(f"Test drawn: {len(test_drawn_paths)}")
print(f"Test original: {len(test_original_paths)}")

In [None]:
import tensorflow as tf

def load_image_pair(drawn_path, original_path, augment=True):
    """Load one sketch/original pair as float tensors scaled to [-1, 1].

    Args:
        drawn_path: Path of the sketch image.
        original_path: Path of the matching original image.
        augment: When True (the default, matching the previous behaviour), apply
            a random horizontal flip to both images and random
            brightness/contrast jitter to the sketch only. Pass False for
            deterministic loading, e.g. for validation or test data.

    Returns:
        ``(drawn_image, original_image)``, each resized to 512x512 with 3 channels.
    """
    def _decode_and_preprocess(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)  # [0,1] instead of casting
        image = tf.image.resize(image, [512, 512])
        return image

    drawn_image = _decode_and_preprocess(drawn_path)
    original_image = _decode_and_preprocess(original_path)

    if augment:
        # The flip is applied to both images so the pair stays aligned.
        if tf.random.uniform([]) > 0.5:
            drawn_image = tf.image.flip_left_right(drawn_image)
            original_image = tf.image.flip_left_right(original_image)

        # Photometric jitter on the sketch only.
        if tf.random.uniform([]) > 0.7:
            drawn_image = tf.image.random_brightness(drawn_image, max_delta=0.1)
            drawn_image = tf.image.random_contrast(drawn_image, lower=0.9, upper=1.1)

    # Brightness/contrast jitter can push values outside [0, 1]; clamp so the
    # scaling below really lands in [-1, 1].
    drawn_image = tf.clip_by_value(drawn_image, 0.0, 1.0)

    # Normalize to [-1, 1]
    drawn_image = (drawn_image * 2.0) - 1.0
    original_image = (original_image * 2.0) - 1.0

    return drawn_image, original_image

In [None]:
def augment_data(input_image, target_image):
    """Randomly augment a sketch/target pair while keeping the two aligned.

    Args:
        input_image: Sketch tensor of shape (height, width, 3).
        target_image: Target tensor of shape (height, width, 3).

    Returns:
        ``(input_image, target_image)`` with the same spatial size as the inputs.

    Geometric changes (flip, crop) are applied to both images through a stacked
    6-channel tensor. The brightness change touches the sketch only.
    """
    # Stack along channels so one geometric op transforms both images identically.
    stacked = tf.concat([input_image, target_image], axis=2)

    # Random flip
    if tf.random.uniform(()) > 0.5:
        stacked = tf.image.flip_left_right(stacked)

    # Split back
    input_image = stacked[:, :, :3]
    target_image = stacked[:, :, 3:]

    # Adjust brightness for sketch only
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.random_brightness(input_image, 0.2)

    # Optional: random square crop followed by a resize back to the input size
    if tf.random.uniform(()) > 0.5:
        height = tf.shape(input_image)[0]
        width = tf.shape(input_image)[1]
        shorter_side = tf.minimum(height, width)

        # Crop side between 80 % and 100 % of the shorter side. maxval is
        # exclusive, hence the +1 to make the full size reachable.
        crop_size = tf.random.uniform(
            [],
            minval=tf.cast(0.8 * tf.cast(shorter_side, tf.float32), tf.int32),
            maxval=shorter_side + 1,
            dtype=tf.int32
        )

        # Apply the same crop to both images
        stacked = tf.concat([input_image, target_image], axis=2)
        cropped = tf.image.random_crop(stacked, [crop_size, crop_size, 6])

        # Resize back to the ORIGINAL dimensions. A fixed 256x256 here would
        # give this branch a different output shape from the no-crop branch.
        input_image = tf.image.resize(cropped[:, :, :3], [height, width])
        target_image = tf.image.resize(cropped[:, :, 3:], [height, width])

    return input_image, target_image

In [None]:
# Constants read by build_dataset below.
BATCH_SIZE = 8  # number of image pairs per batch
SHUFFLE_BUFFER = 1000  # buffer size passed to Dataset.shuffle for the training split

def build_dataset(drawn_paths, original_paths, training=True):
    """Build a batched ``tf.data.Dataset`` of (sketch, original) image pairs.

    Args:
        drawn_paths: Sketch file paths; element i is paired with
            ``original_paths[i]``.
        original_paths: Original image file paths, same length as ``drawn_paths``.
        training: When True the pairs are shuffled (reshuffled every epoch).

    Returns:
        A dataset yielding ``(drawn_batch, original_batch)`` tuples of
        ``BATCH_SIZE`` pairs each; the final batch may be smaller.

    Raises:
        AssertionError: If the two path lists differ in length.

    NOTE(review): ``load_image_pair`` applies its random augmentation to every
    split by default, regardless of ``training`` — confirm whether validation
    and test data should be loaded without it.
    """
    # Ensure paths are in list of strings
    drawn_paths = [str(p) for p in drawn_paths]
    original_paths = [str(p) for p in original_paths]

    # Ensure the paths are the same length
    assert len(drawn_paths) == len(original_paths), "Drawn and original paths must have the same length"

    # Create tf.data.Dataset from the file paths
    dataset = tf.data.Dataset.from_tensor_slices((drawn_paths, original_paths))

    # Shuffle the path strings BEFORE decoding. Shuffling after map() would keep
    # up to SHUFFLE_BUFFER decoded 512x512 float image pairs in memory; here the
    # buffer only holds short strings and the resulting order is equally random.
    if training:
        dataset = dataset.shuffle(SHUFFLE_BUFFER)

    # Map the image loading function
    dataset = dataset.map(load_image_pair, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch the dataset
    dataset = dataset.batch(BATCH_SIZE)

    # Prefetch for performance
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

In [None]:
# One dataset per split; only the training split is shuffled.
train_dataset = build_dataset(train_drawn_paths, train_original_paths, training=True)
val_dataset = build_dataset(val_drawn_paths, val_original_paths, training=False)
test_dataset = build_dataset(test_drawn_paths, test_original_paths, training=False)

# Smoke test: pull a single batch and report the tensor shapes.
for drawn_images, original_images in train_dataset.take(1):
    print(f"Drawn image batch shape: {drawn_images.shape}")
    print(f"Original image batch shape: {original_images.shape}")

print("✅ Datasets created for train, validation, and test.")

#### Model Definition

In [None]:
import os
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from diffusers import (
    ControlNetModel,
    UNet2DConditionModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    get_scheduler
)
from transformers import CLIPTokenizer, CLIPTextModel
from accelerate import Accelerator

In [None]:
# Silence UserWarning messages only; other warning categories still show.
warnings.filterwarnings("ignore", category=UserWarning)

# Seed the global torch and NumPy generators for reproducibility.
# Python's built-in `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Training configuration. All values are class attributes and are read through
# the module-level `config` instance created below.
class Config:
    pretrained_model_name = "runwayml/stable-diffusion-v1-5"  # Base Stable Diffusion model
    output_dir = "sketch_to_image_model"  # root for samples/ and checkpoints/ (see create_directories)
    resolution = 512  # side length both sketch and target are resized to in SketchToImageDataset
    train_batch_size = 8  # batch size of the training DataLoader
    val_batch_size = 8  # batch size of the validation DataLoader
    num_train_epochs = 20
    gradient_accumulation_steps = 2  # passed to Accelerator
    learning_rate = 1e-4  # AdamW learning rate for the ControlNet parameters
    lr_scheduler = "constant"  # scheduler name handed to diffusers.get_scheduler
    lr_warmup_steps = 0
    # NOTE(review): train_controlnet passes the literal "no" to Accelerator
    # rather than reading this field — confirm which one is meant to win.
    mixed_precision = "no"
    # NOTE(review): how the three values below are consumed is not visible in
    # this part of the notebook; presumably epoch/step intervals — TODO confirm.
    save_images_epochs = 1
    save_model_epochs = 1
    validation_steps = 50

config = Config()

In [None]:
def add_error_handling():
    """Patch ``torch.nn.functional.conv2d`` to print tensor dtypes on type errors.

    The patched function behaves like the original, except that a
    ``RuntimeError`` whose message mentions "type" first prints the dtype of
    every tensor argument. The error is always re-raised unchanged.

    The patch is global (it replaces the attribute on ``torch.nn.functional``)
    and idempotent: calling this function again leaves an existing patch alone
    instead of stacking another wrapper on top of it.
    """
    import functools

    current_conv2d = torch.nn.functional.conv2d
    # Already patched by an earlier call: wrapping again would only add another
    # layer of try/except and duplicate the diagnostics.
    if getattr(current_conv2d, "_dtype_diagnostics", False):
        return

    def handle_tensor_error(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RuntimeError as e:
                # "expected scalar type ..." messages also contain "type", so
                # this single check covers both of the original conditions.
                if "type" in str(e):
                    print(f"Data type mismatch error detected: {e}")
                    print("Input tensor types:")
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            print(f"Arg {i}: {arg.dtype}")
                    for k, v in kwargs.items():
                        if isinstance(v, torch.Tensor):
                            print(f"Kwarg {k}: {v.dtype}")
                raise

        # Marker checked by the idempotence guard above.
        wrapper._dtype_diagnostics = True
        return wrapper

    # Patch key functions to add error handling
    torch.nn.functional.conv2d = handle_tensor_error(current_conv2d)

In [None]:
# Custom Dataset with improved prompts and consistent preprocessing
class SketchToImageDataset(Dataset):
    """Paired sketch/target dataset that also supplies a tokenized text prompt.

    Item ``idx`` pairs ``sketch_paths[idx]`` with ``image_paths[idx]``. Both
    images go through the same transform: resize to
    ``config.resolution`` x ``config.resolution``, convert to a tensor, and
    normalize to [-1, 1].

    Args:
        sketch_paths: Sequence of sketch file paths; its length is the dataset length.
        image_paths: Sequence of target image paths, index-aligned with ``sketch_paths``.
        tokenizer: Callable tokenizer; must expose ``model_max_length`` and return
            an object with ``input_ids``.
        prompt_engineering: When True a prompt is drawn at random from
            ``furniture_prompts`` per item; otherwise ``default_prompt`` is used.
    """

    def __init__(self, sketch_paths, image_paths, tokenizer, prompt_engineering=True):
        self.sketch_paths = sketch_paths
        self.image_paths = image_paths
        self.tokenizer = tokenizer
        self.prompt_engineering = prompt_engineering

        # Consistent transform for both sketch and image
        self.transform = transforms.Compose([
            transforms.Resize((config.resolution, config.resolution), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
        ])

        # Furniture-specific prompt collection for variety
        self.furniture_prompts = [
            "a detailed, high-quality photograph of furniture",
            "a professional photograph of home furniture in natural lighting",
            "a realistic photograph of furniture with clean details",
            "a high-resolution image of furniture from sketch",
            "a photorealistic rendering of furniture design",
            "a detailed furniture photograph with accurate textures and materials",
            "a professional product photograph of furniture piece",
            "a clear, detailed image of furniture with realistic details",
        ]

        self.default_prompt = "a detailed, high-quality photograph of furniture generated from sketch"

    def __len__(self):
        return len(self.sketch_paths)

    def _select_prompt(self):
        """Return the prompt for one item: a random one, or the fixed default."""
        if not self.prompt_engineering:
            return self.default_prompt
        # Draw from torch's generator rather than NumPy's. DataLoader workers
        # get distinct torch seeds, whereas NumPy's global state can be
        # duplicated across workers and yield identical "random" picks.
        choice = torch.randint(len(self.furniture_prompts), (1,)).item()
        return self.furniture_prompts[choice]

    def __getitem__(self, idx):
        # Load sketch and target image
        sketch_path = self.sketch_paths[idx]
        image_path = self.image_paths[idx]

        try:
            # convert() returns a new in-memory image, so the file handles can
            # be closed straight away instead of lingering until collection.
            with Image.open(sketch_path) as sketch_file:
                sketch = sketch_file.convert("RGB")
            with Image.open(image_path) as image_file:
                target_image = image_file.convert("RGB")
        except Exception as e:
            print(f"Error loading images: {e}")
            print(f"Sketch path: {sketch_path}")
            print(f"Image path: {image_path}")
            # Fall back to a blank white pair so one unreadable file does not
            # abort the whole epoch.
            sketch = Image.new("RGB", (config.resolution, config.resolution), color="white")
            target_image = Image.new("RGB", (config.resolution, config.resolution), color="white")

        # Apply transformations
        sketch_tensor = self.transform(sketch)
        target_tensor = self.transform(target_image)

        # Select a prompt (either fixed or random from collection)
        prompt = self._select_prompt()

        # Encode the text prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        text_input_ids = text_inputs.input_ids[0]

        return {
            "sketch": sketch_tensor,
            "target": target_tensor,
            "input_ids": text_input_ids,
            "file_path": image_path
        }

def create_directories():
    """Make sure the output folder and its samples/ and checkpoints/ sub-folders exist."""
    for sub_path in ((), ("samples",), ("checkpoints",)):
        os.makedirs(os.path.join(config.output_dir, *sub_path), exist_ok=True)

In [None]:
def save_samples(controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler, sketch_batch, epoch, device, show_images=True):
    """Generate images for a few sketches and save a sketch/generated/target grid.

    Args:
        controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler: Components
            assembled into a ``StableDiffusionControlNetPipeline`` (no safety checker).
        sketch_batch: Mapping with "sketch" and "target" tensors; at most the first
            four items are used.
        epoch: Used in the output filename ``sample_epoch_{epoch}.png``.
        device: Device the pipeline and the batch tensors are moved to.
        show_images: When True the figure is also shown via ``plt.show()``.

    Returns:
        The list of generated images returned by the pipeline.
    """
    pipeline = StableDiffusionControlNetPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False
    ).to(device)

    # Move the batch to the device, then keep at most four samples.
    sketches = sketch_batch["sketch"].to(device)
    targets = sketch_batch["target"].to(device)
    num_samples = min(4, len(sketches))
    sketches, targets = sketches[:num_samples], targets[:num_samples]

    prompt = ["a detailed, high-quality photograph generated from a sketch"] * num_samples

    with torch.no_grad():
        images = pipeline(
            prompt=prompt,
            image=sketches,
            num_inference_steps=100,
            guidance_scale=10.0
        ).images

    def _as_uint8_image(chw_tensor):
        # Undo the 0.5/0.5 normalization and go from CHW to HWC for matplotlib.
        pixels = (chw_tensor.cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5) * 255
        return pixels.clip(0, 255).astype(np.uint8)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so one
    # code path serves every sample count.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    for row in range(num_samples):
        panels = (
            ("Sketch Input", _as_uint8_image(sketches[row])),
            ("Generated Image", images[row]),
            ("Target Image", _as_uint8_image(targets[row])),
        )
        for ax, (title, picture) in zip(axes[row], panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(config.output_dir, "samples", f"sample_epoch_{epoch}.png"))

    if show_images:
        plt.show()

    plt.close()

    return images

In [None]:
def train_controlnet(train_sketch_paths, train_image_paths, val_sketch_paths, val_image_paths):
    create_directories()

    add_error_handling()

    try:
        import gc
        # Force garbage collection to clean up any existing accelerator
        gc.collect()
        try:
            from accelerate.state import AcceleratorState
            if hasattr(AcceleratorState, '_state'):
                if AcceleratorState._state is not None:
                    AcceleratorState._state = None
            elif hasattr(AcceleratorState, 'reset_state'):
                AcceleratorState.reset_state()
        except (ImportError, AttributeError):
            pass

        import os
        os.environ["ACCELERATE_STATE_INITIALIZED"] = "0"
    except Exception as e:
        print(f"Note: Could not reset accelerator state: {e}")
        print("Proceeding with initialization anyway...")

    # Initialize accelerator with fallback options
    try:
        accelerator = Accelerator(
            mixed_precision="no",
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=os.path.join(config.output_dir, "logs")
        )

        # Initialize logging
        accelerator.init_trackers("controlnet_training")
    except TypeError as e:
        print(f"Warning: Error initializing accelerator with full config: {e}")
        print("Trying with simplified configuration...")
        accelerator = Accelerator(
            mixed_precision="no",
            gradient_accumulation_steps=config.gradient_accumulation_steps
        )
        print("Accelerator initialized with basic configuration.")

    # Load the tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_name, subfolder="tokenizer")

    # Create datasets
    train_dataset = SketchToImageDataset(train_sketch_paths, train_image_paths, tokenizer)
    val_dataset = SketchToImageDataset(val_sketch_paths, val_image_paths, tokenizer)

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=4
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.val_batch_size,
        shuffle=False,
        num_workers=4
    )

    # Load models
    noise_scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name, subfolder="scheduler")

    # Load the VAE component
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(config.pretrained_model_name, subfolder="vae")

    text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_name, subfolder="text_encoder")

    # Load the UNet
    unet = UNet2DConditionModel.from_pretrained(config.pretrained_model_name, subfolder="unet")

    # Create a ControlNet model from the UNet
    controlnet = ControlNetModel.from_unet(unet)

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)  # We only train ControlNet

    # Optimizer for ControlNet with weight decay
    optimizer = torch.optim.AdamW(
        controlnet.parameters(),
        lr=config.learning_rate,
        weight_decay=1e-2  # Added weight decay to improve generalization
    )

    # Calculate total number of training steps
    total_train_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    num_update_steps_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps
    max_train_steps = config.num_train_epochs * num_update_steps_per_epoch

    # Learning rate scheduler
    lr_scheduler = get_scheduler(
        name=config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=max_train_steps,
    )

    # Prepare everything with accelerator with fallback
    try:
        controlnet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            controlnet, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )
    except Exception as e:
        print(f"Warning: Error in accelerator.prepare: {e}")
        print("Using manual device management instead.")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        controlnet.to(device)

        # Define a simple wrapper class for backward compatibility
        class SimpleAccelerator:
            def __init__(self, device):
                self.device = device
                self.sync_gradients = True
                self.num_processes = 1

            def backward(self, loss):
                loss.backward()

            def clip_grad_norm_(self, params, max_norm):
                torch.nn.utils.clip_grad_norm_(params, max_norm)

            def save_model(self, model, output_dir):
                model.save_pretrained(output_dir)

            def print(self, *args, **kwargs):
                print(*args, **kwargs)

            def accumulate(self, model):
                class NoOpContextManager:
                    def __enter__(self): return None
                    def __exit__(self, *args): return None
                return NoOpContextManager()

            def end_training(self):
                pass

        if 'accelerator' not in locals() or not hasattr(accelerator, 'backward'):
            print("Creating simple accelerator replacement")
            accelerator = SimpleAccelerator(device)

    # Set up device and precision
    weight_dtype = torch.float32

    # For mixed precision training, we cast the models to the appropriate precision
    if hasattr(accelerator, 'mixed_precision'):
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    # Determine device - use accelerator's device or fall back to CUDA/CPU
    device = accelerator.device if hasattr(accelerator, 'device') else None
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Manually selected device: {device}")

    # Cast all models to the correct dtype and device
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)
    controlnet.to(device, dtype=weight_dtype)

    # We need to keep vae, unet and text_encoder in eval mode
    vae.eval()
    text_encoder.eval()
    unet.eval()

    # Set controlnet to train mode
    controlnet.train()

    # Keep track of losses
    global_step = 0
    best_loss = float('inf')
    best_model_path = None

    for epoch in range(config.num_train_epochs):
        controlnet.train()
        total_loss = 0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                # Move inputs to device
                sketch = batch["sketch"].to(accelerator.device, dtype=weight_dtype)
                target = batch["target"].to(accelerator.device, dtype=weight_dtype)
                input_ids = batch["input_ids"].to(accelerator.device)

                # Encode target image to latent space
                latents = vae.encode(target).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise and add to latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Encode text
                encoder_hidden_states = text_encoder(input_ids)[0]

                # Predict noise with ControlNet conditioning - with flexible output handling
                try:
                    controlnet_output = controlnet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states,
                        sketch,
                        return_dict=True
                    )

                    # Check if return_dict worked or if we got a tuple/list
                    if hasattr(controlnet_output, 'down_block_res_samples'):
                        down_block_res_samples = controlnet_output.down_block_res_samples
                        mid_block_res_sample = controlnet_output.mid_block_res_sample
                    else:
                        # Handle tuple output format for older versions
                        print("Handling tuple output format from ControlNet")
                        down_block_res_samples = controlnet_output[0]
                        mid_block_res_sample = controlnet_output[1]

                except Exception as e:
                    print(f"Error in ControlNet forward pass: {e}")
                    print("Trying alternative calling convention...")
                    controlnet_output = controlnet(
                        noisy_latents,
                        timesteps,
                        conditioning=sketch,
                        encoder_hidden_states=encoder_hidden_states
                    )
                    # Extract components based on output type
                    if isinstance(controlnet_output, tuple):
                        down_block_res_samples = controlnet_output[0]
                        mid_block_res_sample = controlnet_output[1]
                    else:
                        down_block_res_samples = controlnet_output.down_block_res_samples
                        mid_block_res_sample = controlnet_output.mid_block_res_sample

                # Get UNet prediction with ControlNet residuals - with flexible output handling
                try:
                    unet_output = unet(
                        sample=noisy_latents,
                        timestep=timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=True
                    )

                    # Extract prediction based on output type
                    if hasattr(unet_output, 'sample'):
                        model_pred = unet_output.sample
                    else:
                        model_pred = unet_output[0]

                except Exception as e:
                    print(f"Error in UNet forward pass: {e}")
                    print("Trying alternative calling convention...")
                    unet_output = unet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states
                    )
                    # Handle output format
                    if isinstance(unet_output, tuple):
                        model_pred = unet_output[0]
                    else:
                        model_pred = unet_output.sample

                # Compute loss
                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            global_step += 1
            total_loss += loss.item()

            if global_step % config.validation_steps == 0:
                accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

            if global_step % 100 == 0:  # Every a100 steps, log the loss
                print(f"Epoch [{epoch + 1}/{config.num_train_epochs}], Step [{global_step}], Loss: {total_loss / (step + 1)}")

        avg_loss = total_loss / len(train_dataloader)
        accelerator.print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

        # Save the best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_model_path = os.path.join(config.output_dir, "checkpoints", f"best_model_epoch_{epoch}")
            accelerator.save_model(controlnet, best_model_path)

            # Save the entire model with its configuration
            controlnet.save_pretrained(best_model_path)  # This will create the config.json

            print(f"New best model saved at epoch {epoch} with loss {best_loss:.4f}")


        # Save images and model
        if epoch % config.save_images_epochs == 0:
            val_batch = next(iter(val_dataloader))
            save_samples(controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler, val_batch, epoch, accelerator.device)

        if epoch % config.save_model_epochs == 0:
            model_path = os.path.join(config.output_dir, "checkpoints", f"controlnet_epoch_{epoch}")
            accelerator.save_model(controlnet, model_path)

    accelerator.end_training()

    # Return the path to the best model
    if best_model_path is None:
        best_model_path = os.path.join(config.output_dir, "checkpoints", f"controlnet_epoch_{config.num_train_epochs-1}")

    return best_model_path

In [None]:
def run_inference(model_path, test_sketch_paths, test_image_paths, num_samples=5):
    """Generate images from randomly chosen test sketches and save comparisons.

    Loads a ControlNet from ``model_path``, builds a Stable Diffusion ControlNet
    pipeline around ``config.pretrained_model_name``, runs it on up to
    ``num_samples`` sketches and writes, into ``inference_results/``:

    * ``sample_{i}_sketch.png``, ``sample_{i}_generated.png``, ``sample_{i}_target.png``
    * ``comparison_{i}.png`` (one row: sketch / generated / ground truth)
    * ``all_comparisons.png`` (one row per processed sample)

    Relies on notebook-level names: ``config`` (reads ``pretrained_model_name`` and
    ``resolution``), ``torch``, ``np``, ``plt``, ``Image`` and ``DDPMScheduler``.

    Args:
        model_path: Directory containing the saved ControlNet.
        test_sketch_paths: Sequence of sketch image paths.
        test_image_paths: Sequence of target image paths, indexed in parallel with
            ``test_sketch_paths``.
        num_samples: Maximum number of sketches to sample (without replacement).

    Returns:
        Path of the summary image (``inference_results/all_comparisons.png``).

    Raises:
        ValueError: If there is nothing to sample (empty sketch list or
            ``num_samples`` < 1).
        RuntimeError: If the ControlNet cannot be loaded from ``model_path``.
    """
    import os

    # These imports must run before the names are first used. An import anywhere
    # inside a function makes the name local to the whole function, so importing
    # only inside the ``except`` branches (as before) made the primary
    # ``from_pretrained`` calls fail with UnboundLocalError and silently forced
    # the fallback path every time.
    from diffusers import (
        AutoencoderKL,
        ControlNetModel,
        DDIMScheduler,
        StableDiffusionControlNetPipeline,
        UNet2DConditionModel,
    )
    from transformers import CLIPTextModel, CLIPTokenizer

    # Fail fast, before loading multi-gigabyte weights, if nothing can be sampled.
    sample_count = min(num_samples, len(test_sketch_paths))
    if sample_count < 1:
        raise ValueError(
            "run_inference needs at least one test sketch and num_samples >= 1 "
            f"(got {len(test_sketch_paths)} sketches, num_samples={num_samples})"
        )

    print(f"Running inference with model from: {model_path}")

    # Load the trained ControlNet model
    try:
        controlnet = ControlNetModel.from_pretrained(model_path)
    except Exception as e:
        print(f"Error loading ControlNet model: {e}")
        print("Attempting to load with a different method...")
        try:
            controlnet = ControlNetModel.from_pretrained(model_path, local_files_only=True)
        except Exception as sub_e:
            print(f"Second loading attempt failed: {sub_e}")
            # Only list the directory if it exists, so this diagnostic cannot
            # raise and hide the real loading error.
            if os.path.isdir(model_path):
                print("Checking directory contents:")
                print(os.listdir(model_path))
            raise RuntimeError(f"Could not load model from {model_path}") from sub_e

    # Create the pipeline with device detection
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    try:
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            config.pretrained_model_name,
            controlnet=controlnet,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)
    except Exception as e:
        print(f"Error creating pipeline: {e}")
        print("Trying alternative pipeline configuration...")

        # Fallback: load each component explicitly and assemble the pipeline by hand
        vae = AutoencoderKL.from_pretrained(
            config.pretrained_model_name, subfolder="vae"
        )
        unet = UNet2DConditionModel.from_pretrained(
            config.pretrained_model_name, subfolder="unet"
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            config.pretrained_model_name, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            config.pretrained_model_name, subfolder="text_encoder"
        )
        scheduler = DDIMScheduler.from_pretrained(
            config.pretrained_model_name, subfolder="scheduler"
        )

        pipe = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        ).to(device)

    # Replace whichever scheduler the pipeline was built with by a DDPM scheduler.
    # NOTE(review): the previous comment called this "deterministic generation",
    # but no seed/generator is passed to the pipeline, so outputs vary run to run.
    pipe.scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name, subfolder="scheduler")

    # Create a directory for the results
    results_dir = "inference_results"
    os.makedirs(results_dir, exist_ok=True)

    # Randomly sample test sketches (without replacement)
    indices = np.random.choice(len(test_sketch_paths), sample_count, replace=False)

    size = (config.resolution, config.resolution)
    prompt = "a detailed, high-quality photograph generated from a sketch"
    samples = []  # (sketch, generated, target) per processed sample, for the summary grid

    # Generate images for each sample
    for i, idx in enumerate(indices):
        sketch_path = test_sketch_paths[idx]
        target_path = test_image_paths[idx]

        print(f"Processing sketch: {sketch_path}")
        print(f"Target image: {target_path}")

        # Load as RGB and resize to the model resolution; the context managers
        # close the underlying files promptly.
        with Image.open(sketch_path) as raw_sketch:
            sketch = raw_sketch.convert("RGB").resize(size)
        with Image.open(target_path) as raw_target:
            target = raw_target.convert("RGB").resize(size)

        # Generate image
        image = pipe(
            prompt,
            image=sketch,
            num_inference_steps=75,  # More steps for better quality
            guidance_scale=7.5
        ).images[0]

        # Save individual images for comparison
        sketch.save(os.path.join(results_dir, f"sample_{i}_sketch.png"))
        image.save(os.path.join(results_dir, f"sample_{i}_generated.png"))
        target.save(os.path.join(results_dir, f"sample_{i}_target.png"))
        samples.append((sketch, image, target))

        # Save the side-by-side comparison for this sample
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, img, title in zip(axes, (sketch, image, target),
                                  ("Input Sketch", "Generated Image", "Ground Truth")):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, f"comparison_{i}.png"))
        plt.close(fig)

        print(f"Saved comparison image: comparison_{i}.png")

    # Create a summary image with all comparisons. Rows follow the number of
    # samples actually processed (not the requested num_samples), and
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    n_rows = len(samples)
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    for row, triple in enumerate(samples):
        suffix = f" {row + 1}" if n_rows > 1 else ""
        for ax, img, label in zip(axes[row], triple, ("Sketch", "Generated", "Target")):
            ax.imshow(img)
            ax.set_title(f"{label}{suffix}")
            ax.axis("off")

    summary_path = os.path.join(results_dir, "all_comparisons.png")
    plt.tight_layout()
    plt.savefig(summary_path)
    plt.close(fig)

    print(f"All inference results saved to: {results_dir}")
    print(f"Summary image saved as: {summary_path}")

    return summary_path

In [None]:
if __name__ == "__main__":
    # Expose the drawn/original path lists under the sketch/image names that the
    # training and inference functions take.
    train_sketch_paths, train_image_paths = train_drawn_paths, train_original_paths
    val_sketch_paths, val_image_paths = val_drawn_paths, val_original_paths
    test_sketch_paths, test_image_paths = test_drawn_paths, test_original_paths

    # Train the ControlNet and keep the returned checkpoint path.
    best_model_path = train_controlnet(
        train_sketch_paths,
        train_image_paths,
        val_sketch_paths,
        val_image_paths,
    )

    # Run inference on the test split with that checkpoint.
    run_inference(best_model_path, test_sketch_paths, test_image_paths, num_samples=5)