In [None]:
# Mount Google Drive so that paths under /content/drive become available
# to the rest of the notebook (Colab-only).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
# Root of the mounted Drive and the folder that holds the furniture dataset.
base_path = "/content/drive/MyDrive"
dataset_dir = os.path.join(base_path, "dataset furniture")

# os.listdir returns every entry directly inside dataset_dir (it does not recurse).
# NOTE(review): sub-folders created later in this notebook (train/, test/, ...)
# live inside dataset_dir too, so on a re-run they are counted here as well.
filenames = os.listdir(dataset_dir)
print(f"📁 Total files in folder: {len(filenames)}")

In [None]:
# Gather the IDs of all sketches ("_drawn") and all photos ("_original").
drawn_ids, original_ids = set(), set()

for fname in filenames:
    # "_drawn" is tested first, so a name containing both markers counts as a sketch.
    for marker, bucket in (("_drawn", drawn_ids), ("_original", original_ids)):
        if marker in fname:
            # Drop the "fimage_" prefix and keep whatever precedes the marker.
            image_id = fname.replace("fimage_", "").split(marker)[0]
            bucket.add(image_id)
            break

# IDs for which both a sketch and a photo exist.
paired_ids = drawn_ids.intersection(original_ids)

In [None]:
# Report how many distinct sketch IDs, photo IDs and complete pairs were found.
print(f"🖌️ Total sketch images: {len(drawn_ids)}")
print(f"🖼️ Total original images: {len(original_ids)}")
print(f"✅ Found {len(paired_ids)} sketch-original image pairs.")

#### Train/Test Split

In [None]:
# Step 1: map every image ID to its sketch ("drawn") and photo ("original") filename
file_mapping = {}

for fname in filenames:
    if fname.startswith("fimage"):
        # "drawn" is checked before "original", mirroring an if/elif chain.
        for kind in ("drawn", "original"):
            marker = "_" + kind
            if marker in fname:
                image_id = fname.replace("fimage", "").split(marker)[0]
                file_mapping.setdefault(image_id, {})[kind] = fname
                break

# Step 2: keep only IDs that have both halves of the pair (sorted for a stable order)
paired_ids = sorted(
    img_id
    for img_id, pair in file_mapping.items()
    if 'drawn' in pair and 'original' in pair
)

# Step 3: first 70% of the sorted IDs go to train, the remainder to test
split_idx = int(0.7 * len(paired_ids))
train_ids, test_ids = paired_ids[:split_idx], paired_ids[split_idx:]

# Step 4: flatten each split into [drawn, original, drawn, original, ...]
train_filenames = [
    file_mapping[image_id][kind]
    for image_id in train_ids
    for kind in ("drawn", "original")
]
test_filenames = [
    file_mapping[image_id][kind]
    for image_id in test_ids
    for kind in ("drawn", "original")
]

print(f"🛤️ Train set size: {len(train_filenames)} images")
print(f"🧪 Test set size: {len(test_filenames)} images")

In [None]:
import os
import shutil

# Destination folders for the split; both live inside the dataset folder itself
# (base_path/"dataset furniture"/train and .../test). Created if missing.
train_dir = os.path.join(base_path, "dataset furniture", "train")
test_dir = os.path.join(base_path, "dataset furniture", "test")
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Function to copy files if not already present
def copy_files(file_list, src_dir, dst_dir):
    """Copy each named file from src_dir to dst_dir unless it is already there.

    Args:
        file_list: Iterable of bare filenames (no directory part).
        src_dir: Directory the files are read from.
        dst_dir: Directory the files are written to; must already exist.

    A line is printed for every file, saying whether it was copied or skipped.
    Existing destination files are never overwritten.
    """
    for fname in file_list:
        target = os.path.join(dst_dir, fname)

        # Guard clause: leave files that are already in place untouched.
        if os.path.exists(target):
            print(f"Skipped (already exists): {fname}")  # Optional, for feedback
            continue

        shutil.copyfile(os.path.join(src_dir, fname), target)
        print(f"Copied: {fname}")  # You can remove this line if not needed for debugging

# Populate train/ from the raw dataset folder; files already present are skipped
copy_files(train_filenames, dataset_dir, train_dir)

# Populate test/ the same way
copy_files(test_filenames, dataset_dir, test_dir)

print("✅ Files successfully copied to train/ and test/ folders.")

#### Preprocessing

##### Resizing

In [None]:
from PIL import Image

# Function to get the size of an image
def get_image_size(image_path):
    """Return the (width, height) of the image stored at image_path.

    The file is opened with PIL and closed again before returning.
    """
    with Image.open(image_path) as opened:
        dimensions = opened.size  # PIL reports size as (width, height)
    return dimensions

# Inspect the size of one sample: the first training file, read from the raw dataset folder.
# This indexes train_filenames[0], so it raises IndexError when the training split is empty.
sample_image_path = os.path.join(dataset_dir, train_filenames[0])  # Use any file in the dataset
image_size = get_image_size(sample_image_path)
print(f"Sample image size: {image_size}")

##### Normalization

In [None]:
import numpy as np
from PIL import Image
import os
import random

# Output folders for the "normalized" copies of the train and test images.
# Both are created inside the dataset folder if they do not exist yet.
normalized_train_dir = os.path.join(dataset_dir, "normalized_train")
normalized_test_dir = os.path.join(dataset_dir, "normalized_test")

os.makedirs(normalized_train_dir, exist_ok=True)
os.makedirs(normalized_test_dir, exist_ok=True)

# Function to normalize image and save it
def normalize_image(image_path, save_path):
    """Round-trip an image through [-1, 1] scaling and save it as 8-bit.

    The pixel array is scaled to [-1, 1] (x / 127.5 - 1) and then mapped back
    to [0, 255] before saving, because the image is written as uint8.

    The mapping back is rounded to the nearest integer and clipped. A plain
    ``astype(np.uint8)`` truncates, so float32 rounding error (e.g. 2.999999
    instead of 3) would silently lower pixel values by one.

    Args:
        image_path: Path of the image to read.
        save_path: Path the processed image is written to.
    """
    with Image.open(image_path) as img:
        img_array = np.array(img).astype(np.float32)
        img_array = (img_array / 127.5) - 1.0

        # Back to the 8-bit range: round-to-nearest, then clip, then cast.
        restored = np.clip(np.rint((img_array + 1.0) * 127.5), 0, 255).astype(np.uint8)
        normalized_img = Image.fromarray(restored)
        normalized_img.save(save_path)
        print(f"Saved normalized image to {save_path}")

# Normalize and save all images in the training and test sets
# for fname in train_filenames:
#     image_path = os.path.join(train_dir, fname)
#     save_path = os.path.join(normalized_train_dir, fname)
#     normalize_image(image_path, save_path)

# for fname in test_filenames:
#     image_path = os.path.join(test_dir, fname)
#     save_path = os.path.join(normalized_test_dir, fname)
#     normalize_image(image_path, save_path)

def check_random_normalized_images(file_list, normalized_dir, num_samples=2):
    """Spot-check a random sample of saved images for pixel values outside [0, 255].

    Args:
        file_list: Filenames to sample from.
        normalized_dir: Directory in which those filenames are looked up.
        num_samples: Maximum number of files to inspect (capped at len(file_list)).

    Prints one line per inspected file plus an overall verdict.

    NOTE(review): if the images load as 8-bit arrays, the range test can
    presumably never fail, making this an "opens without error" check only.
    """
    picked = random.sample(file_list, min(num_samples, len(file_list)))
    failures = 0

    for fname in picked:
        with Image.open(os.path.join(normalized_dir, fname)) as img:
            pixels = np.array(img)
            out_of_range = pixels.max() > 255 or pixels.min() < 0

        if out_of_range:
            print(f"❌ Image {fname} has incorrect normalization values!")
            failures += 1
        else:
            print(f"✅ Image {fname} appears correctly normalized.")

    if failures == 0:
        print("✅ All images normalized and saved successfully!")
    else:
        print(f"⚠️ Some images in {normalized_dir} may not be normalized correctly.")

# Spot-check the previously saved normalized images. The normalization loops above
# are commented out, so these calls rely on the folders having been filled earlier.
check_random_normalized_images(train_filenames, normalized_train_dir)
check_random_normalized_images(test_filenames, normalized_test_dir)

In [None]:
from sklearn.model_selection import train_test_split

# Step 1: recover the image IDs of the training split from its sketch filenames
train_ids = [
    fname.replace("fimage", "").split("_drawn")[0]
    for fname in train_filenames
    if "_drawn" in fname
]

# Step 2: hold out 20% of the training IDs for validation (fixed seed)
train_ids_split, val_ids_split = train_test_split(train_ids, test_size=0.2, random_state=42)

# Step 3: expand each ID back into its [drawn, original] filename pair
new_train_filenames = [
    file_mapping[image_id][kind]
    for image_id in train_ids_split
    if image_id in file_mapping
    for kind in ("drawn", "original")
]
val_filenames = [
    file_mapping[image_id][kind]
    for image_id in val_ids_split
    if image_id in file_mapping
    for kind in ("drawn", "original")
]

# ✅ Summary
print(f"🛤️ Train set: {len(new_train_filenames)} images")
print(f"🧪 Validation set: {len(val_filenames)} images")
print(f"🧪 Test set (unchanged): {len(test_filenames)} images")

In [None]:
import os
import shutil

# Folder that receives the validation images (created inside the dataset folder).
# NOTE(review): the folder is called "normalized_val", but the copy below reads
# from train_dir, i.e. the un-normalized training copies — confirm this is intended.
val_dir = os.path.join(dataset_dir, "normalized_val")
os.makedirs(val_dir, exist_ok=True)

# Function to copy validation images
def copy_files_if_needed(file_list, source_dir, destination_dir, expected_file_count):
    """Copy files into destination_dir unless it already looks fully populated.

    Args:
        file_list: Bare filenames to copy.
        source_dir: Directory the files are read from.
        destination_dir: Existing directory the files are written to.
        expected_file_count: If destination_dir already holds at least this many
            regular files, nothing is copied at all.

    Missing source files produce a warning, existing destination files are
    skipped, everything else is copied. Each outcome is printed.
    """
    # Only regular files count towards the "already populated" check.
    existing_files = sum(
        1
        for entry in os.listdir(destination_dir)
        if os.path.isfile(os.path.join(destination_dir, entry))
    )

    if existing_files >= expected_file_count:
        print(f"⏩ Skipped copying: {existing_files} files already exist in {destination_dir}.")
        return

    for fname in file_list:
        src_path = os.path.join(source_dir, fname)
        dst_path = os.path.join(destination_dir, fname)

        if not os.path.exists(src_path):
            print(f"⚠️ Warning: Source file not found: {src_path}")
            continue

        if os.path.exists(dst_path):
            print(f"⏩ Skipped (already exists): {fname}")
            continue

        shutil.copy(src_path, dst_path)
        print(f"✅ Copied: {fname}")

# Expected number of files in the validation set. Derived from the split itself
# rather than hard-coded, so the "already populated" check stays correct when the
# dataset (and therefore the split size) changes.
expected_val_file_count = len(val_filenames)

# Copy validation files (read from train_dir, where the training split was copied)
copy_files_if_needed(val_filenames, train_dir, val_dir, expected_val_file_count)

print("✅ Validation files checked and copied if needed.")

#### Data Preparation for training

In [None]:
import os
import shutil

# Function to count the files in the subfolders
def count_files_in_subfolders(base_dir):
    """Count the entries in base_dir/original and base_dir/drawn.

    Args:
        base_dir: Directory expected to contain "original" and "drawn" sub-folders.

    Returns:
        Tuple ``(num_original_files, num_drawn_files)``. A sub-folder that does
        not exist contributes 0.
    """
    # Fix: look inside the directory passed in, not the global dataset folder.
    original_dir = os.path.join(base_dir, "original")
    drawn_dir = os.path.join(base_dir, "drawn")

    num_original_files = len(os.listdir(original_dir)) if os.path.isdir(original_dir) else 0
    num_drawn_files = len(os.listdir(drawn_dir)) if os.path.isdir(drawn_dir) else 0

    return num_original_files, num_drawn_files

# Function to create subfolders and copy files (if not already copied)
def create_subfolders_and_copy_files(base_dir, expected_original_files, expected_drawn_files):
    """Sort the images lying directly in base_dir into "drawn/" and "original/" sub-folders.

    Args:
        base_dir: Directory whose top-level files are sorted; the two sub-folders
            are created inside it.
        expected_original_files: If "original/" already holds at least this many
            entries (and "drawn/" enough too), the function returns without copying.
        expected_drawn_files: Same threshold for "drawn/".

    Files are copied, never moved, and existing destination files are left alone.
    A name containing "_drawn" goes to "drawn/"; otherwise a name containing
    "_original" goes to "original/".
    """
    original_dir = os.path.join(base_dir, "original")
    drawn_dir = os.path.join(base_dir, "drawn")

    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(drawn_dir, exist_ok=True)

    num_original_files, num_drawn_files = count_files_in_subfolders(base_dir)

    if num_original_files >= expected_original_files and num_drawn_files >= expected_drawn_files:
        print(f"⏩ Skipping folder: {base_dir}, already contains enough original and drawn images.")
        return

    # Loop through all files in the base directory (not subfolders)
    for fname in os.listdir(base_dir):
        file_path = os.path.join(base_dir, fname)

        if not os.path.isfile(file_path):
            continue  # Skip if not a file

        # Process drawn images
        if "_drawn" in fname:
            # Derive the partner name from the sketch name itself, so that prefix
            # and extension are kept. The previous version stripped "image_" and
            # hard-coded ".png", which never matched "fimage..." filenames.
            original_fname = fname.replace("_drawn", "_original", 1)

            # Copy drawn image if not already copied
            dst_drawn_path = os.path.join(drawn_dir, fname)
            if not os.path.exists(dst_drawn_path):
                shutil.copy(file_path, dst_drawn_path)
                print(f"✅ Copied: {fname} to {drawn_dir}")

            # Copy corresponding original image if present and not already copied
            src_original_path = os.path.join(base_dir, original_fname)
            dst_original_path = os.path.join(original_dir, original_fname)
            if os.path.isfile(src_original_path) and not os.path.exists(dst_original_path):
                shutil.copy(src_original_path, dst_original_path)
                print(f"✅ Copied: {original_fname} to {original_dir}")

        # Process standalone original images
        elif "_original" in fname:
            dst_original_path = os.path.join(original_dir, fname)
            if not os.path.exists(dst_original_path):
                shutil.copy(file_path, dst_original_path)
                print(f"✅ Copied: {fname} to {original_dir}")

In [None]:
# Expected number of images per sub-folder. These values only feed the
# "already populated, skip" check inside create_subfolders_and_copy_files.
# NOTE(review): the counts are hard-coded for one dataset snapshot; they must be
# updated by hand if the dataset or the split ratios change.
# NOTE(review): normalized_train/ was presumably filled from the full training
# split (see the commented-out normalization loop), which would also contain the
# images that were later copied to the validation folder — confirm there is no
# train/validation overlap.
expected_train_original = 853
expected_train_drawn = 853
expected_val_original = 214
expected_val_drawn = 214
expected_test_original = 458
expected_test_drawn = 458

# Sort each split folder into drawn/ and original/ sub-folders
create_subfolders_and_copy_files(normalized_train_dir, expected_train_original, expected_train_drawn)
create_subfolders_and_copy_files(val_dir, expected_val_original, expected_val_drawn)
create_subfolders_and_copy_files(normalized_test_dir, expected_test_original, expected_test_drawn)

print("✅ All files have been checked and copied if needed.")

In [None]:
def get_image_paths(dir_path):
    """Return the sorted full paths of all PNG/JPEG files directly inside dir_path.

    The extension match is case-insensitive; sorting is by filename.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")  # handles common formats
    paths = []
    for entry in sorted(os.listdir(dir_path)):
        if entry.lower().endswith(image_suffixes):
            paths.append(os.path.join(dir_path, entry))
    return paths


# Sorted image paths per split and per kind. Sketch and photo lists are later
# paired by position, which relies on both folders sorting into the same order.
train_drawn_paths = get_image_paths(os.path.join(normalized_train_dir, "drawn"))
train_original_paths = get_image_paths(os.path.join(normalized_train_dir, "original"))

val_drawn_paths = get_image_paths(os.path.join(val_dir, "drawn"))
val_original_paths = get_image_paths(os.path.join(val_dir, "original"))

test_drawn_paths = get_image_paths(os.path.join(normalized_test_dir, "drawn"))
test_original_paths = get_image_paths(os.path.join(normalized_test_dir, "original"))

In [None]:
# Sanity check: within each split the sketch and photo counts should match.
print(f"Train drawn: {len(train_drawn_paths)}")
print(f"Train original: {len(train_original_paths)}")

print(f"Val drawn: {len(val_drawn_paths)}")
print(f"Val original: {len(val_original_paths)}")

print(f"Test drawn: {len(test_drawn_paths)}")
print(f"Test original: {len(test_original_paths)}")

In [None]:
import tensorflow as tf

def load_image_pair(drawn_path, original_path, augment=True):
    """Load a sketch/photo pair and return both as 512x512x3 float tensors in [-1, 1].

    Args:
        drawn_path: Path (string tensor) of the sketch image.
        original_path: Path (string tensor) of the matching photo.
        augment: When True (the default, matching the previous behaviour) random
            augmentation is applied: a horizontal flip shared by both images and
            a brightness/contrast jitter on the sketch only. Pass False for
            deterministic loading, e.g. for validation or test data.

    Returns:
        Tuple ``(drawn_image, original_image)``.
    """
    def _decode_and_preprocess(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)  # [0,1] instead of casting
        image = tf.image.resize(image, [512, 512])
        return image

    drawn_image = _decode_and_preprocess(drawn_path)
    original_image = _decode_and_preprocess(original_path)

    if augment:
        # Flip both images together so the pair stays aligned.
        if tf.random.uniform([]) > 0.5:
            drawn_image = tf.image.flip_left_right(drawn_image)
            original_image = tf.image.flip_left_right(original_image)

        if tf.random.uniform([]) > 0.7:
            drawn_image = tf.image.random_brightness(drawn_image, max_delta=0.1)
            drawn_image = tf.image.random_contrast(drawn_image, lower=0.9, upper=1.1)
            # Brightness/contrast shifts can leave [0, 1]; clip so the final
            # output really stays inside [-1, 1].
            drawn_image = tf.clip_by_value(drawn_image, 0.0, 1.0)

    # Normalize to [-1, 1]
    drawn_image = (drawn_image * 2.0) - 1.0
    original_image = (original_image * 2.0) - 1.0

    return drawn_image, original_image

In [None]:
def augment_data(input_image, target_image):
    """Randomly augment a sketch/photo pair while keeping the two aligned.

    Both images are expected to be HxWx3 tensors of the same height and width
    (the code concatenates them along the channel axis and splits at channel 3).

    Augmentations, each applied with probability 0.5:
      * horizontal flip of both images,
      * brightness jitter of the sketch (input) only,
      * a shared random square crop, resized back to the input height and width.

    Returns:
        Tuple ``(input_image, target_image)`` with the same spatial size as the input.
    """
    # Stack along channels so geometric transforms hit both images identically
    stacked = tf.concat([input_image, target_image], axis=2)

    # Random flip
    if tf.random.uniform(()) > 0.5:
        stacked = tf.image.flip_left_right(stacked)

    # Split back
    input_image = stacked[:, :, :3]
    target_image = stacked[:, :, 3:]

    # Adjust brightness for sketch only
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.random_brightness(input_image, 0.2)

    # Optional: random cropping with resize
    if tf.random.uniform(()) > 0.5:
        # Get dimensions
        height = tf.shape(input_image)[0]
        width = tf.shape(input_image)[1]

        # Random crop size (from 80% of the shorter side up to, but excluding, the full side)
        crop_size = tf.random.uniform(
            [],
            minval=tf.cast(0.8 * tf.cast(tf.minimum(height, width), tf.float32), tf.int32),
            maxval=tf.cast(tf.minimum(height, width), tf.int32),
            dtype=tf.int32
        )

        # Apply the same crop to both images
        stacked = tf.concat([input_image, target_image], axis=2)
        cropped = tf.image.random_crop(stacked, [crop_size, crop_size, 6])

        # Resize back to the original dimensions. The previous hard-coded
        # [256, 256] made the output size depend on whether this branch ran.
        output_size = tf.stack([height, width])
        input_image = tf.image.resize(cropped[:, :, :3], output_size)
        target_image = tf.image.resize(cropped[:, :, 3:], output_size)

    return input_image, target_image

In [None]:
# Define constants
BATCH_SIZE = 8  # number of sketch/photo pairs per batch produced by build_dataset
SHUFFLE_BUFFER = 1000  # buffer size passed to Dataset.shuffle for the training set

# Function to create dataset from file paths
def build_dataset(drawn_paths, original_paths, training=True):
    """Build a batched tf.data pipeline of (sketch, photo) image pairs.

    Args:
        drawn_paths: Sketch image paths.
        original_paths: Photo image paths, paired with drawn_paths by position.
        training: When True the pairs are shuffled (reshuffled on every iteration).

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(drawn_images, original_images)``
        as produced by ``load_image_pair``, batched by ``BATCH_SIZE`` and prefetched.

    NOTE(review): ``load_image_pair`` is mapped with its default arguments for
    every split, so validation/test data go through the same random augmentation
    as training data — confirm whether that is intended.
    """
    # Ensure paths are in list of strings
    drawn_paths = [str(p) for p in drawn_paths]
    original_paths = [str(p) for p in original_paths]

    # Ensure the paths are the same length
    assert len(drawn_paths) == len(original_paths), "Drawn and original paths must have the same length"

    # Build explicit string tensors so that empty lists keep a string dtype
    dataset = tf.data.Dataset.from_tensor_slices((
        tf.constant(drawn_paths, dtype=tf.string),
        tf.constant(original_paths, dtype=tf.string),
    ))

    # Shuffle the cheap path strings BEFORE decoding. Shuffling after the map
    # would keep up to SHUFFLE_BUFFER decoded 512x512 float image pairs in memory.
    if training:
        dataset = dataset.shuffle(SHUFFLE_BUFFER)

    # Map the image loading function
    dataset = dataset.map(load_image_pair, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch the dataset
    dataset = dataset.batch(BATCH_SIZE)

    # Prefetch for performance
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

In [None]:
# Build datasets for train, validation, and test (only the training set is shuffled)
train_dataset = build_dataset(train_drawn_paths, train_original_paths, training=True)
val_dataset = build_dataset(val_drawn_paths, val_original_paths, training=False)
test_dataset = build_dataset(test_drawn_paths, test_original_paths, training=False)

# Smoke test: pull a single batch and print the shapes of its two image tensors
for images in train_dataset.take(1):
    drawn_images, original_images = images
    print(f"Drawn image batch shape: {drawn_images.shape}")
    print(f"Original image batch shape: {original_images.shape}")

print("✅ Datasets created for train, validation, and test.")

#### Model Definition

In [None]:
import os
import torch
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from accelerate import Accelerator
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import warnings

In [None]:
# Suppress some warnings (all UserWarning-category warnings are hidden from here on)
warnings.filterwarnings("ignore", category=UserWarning)

# Set seeds for reproducibility (torch and numpy global generators).
# NOTE(review): Python's built-in `random` module, used earlier via random.sample,
# is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

# Configuration
class Config:
    # Central collection of training settings; read through the module-level
    # `config` instance created below.
    pretrained_model_name = "runwayml/stable-diffusion-v1-5"  # Base Stable Diffusion model
    output_dir = "sketch_to_image_model"  # root for samples/, checkpoints/ and logs/
    resolution = 512  # side length images are resized to by the default dataset transform
    train_batch_size = 8  # batch size of the training DataLoader
    val_batch_size = 8  # batch size of the validation DataLoader
    num_train_epochs = 18
    gradient_accumulation_steps = 2  # passed to Accelerator
    learning_rate = 1e-4  # AdamW learning rate for the ControlNet
    lr_scheduler = "constant"  # scheduler name passed to diffusers' get_scheduler
    lr_warmup_steps = 0
    mixed_precision = "no"  # Using full precision to avoid dtype issues
    save_images_epochs = 1  # sample images are generated every N epochs
    save_model_epochs = 1  # a checkpoint is written every N epochs
    validation_steps = 50  # the current loss is printed every N training steps

config = Config()

In [None]:
def add_error_handling():
    """Monkey-patch ``torch.nn.functional.conv2d`` to print tensor dtypes on type errors.

    The patched function behaves like the original. When it raises a
    ``RuntimeError`` whose message mentions "type", the dtypes of all tensor
    arguments are printed before the same exception is re-raised.

    Calling this function more than once is safe: an already patched ``conv2d``
    is detected and left alone, instead of being wrapped again on every call.
    """
    import functools

    current_conv2d = torch.nn.functional.conv2d

    # Already patched by an earlier call: nothing to do.
    if getattr(current_conv2d, "_dtype_debug_wrapper", False):
        return

    def handle_tensor_error(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RuntimeError as e:
                if "expected scalar type" in str(e) or "type" in str(e):
                    print(f"Data type mismatch error detected: {e}")
                    print("Input tensor types:")
                    for i, arg in enumerate(args):
                        if isinstance(arg, torch.Tensor):
                            print(f"Arg {i}: {arg.dtype}")
                    for k, v in kwargs.items():
                        if isinstance(v, torch.Tensor):
                            print(f"Kwarg {k}: {v.dtype}")
                # The error is always propagated; the branch above only adds diagnostics.
                raise

        # Marker used by the idempotency check above.
        wrapper._dtype_debug_wrapper = True
        return wrapper

    # Patch key functions to add error handling
    torch.nn.functional.conv2d = handle_tensor_error(current_conv2d)

In [None]:
# Create dataset class with improved prompts
class SketchToImageDataset(Dataset):
    """Dataset of (sketch, target photo, tokenized prompt) training examples.

    Sketch and photo paths are paired by index. Both images go through the same
    transform; by default that is resize to ``config.resolution``, conversion to
    a tensor, and scaling to [-1, 1].
    """

    def __init__(self, sketch_paths, image_paths, tokenizer, transform=None, prompt_engineering=True):
        """Store the paths and tokenizer and set up the transform and prompt pool.

        Args:
            sketch_paths: Paths of the sketch images.
            image_paths: Paths of the target photos, aligned with sketch_paths.
            tokenizer: Tokenizer used to encode the text prompt.
            transform: Optional transform applied to both images; a default
                pipeline is built when this is falsy.
            prompt_engineering: If True a random prompt is drawn per item,
                otherwise the fixed default prompt is used.
        """
        self.sketch_paths = sketch_paths
        self.image_paths = image_paths
        self.tokenizer = tokenizer
        self.prompt_engineering = prompt_engineering

        # Caller-supplied transform wins; otherwise resize -> tensor -> [-1, 1].
        self.transform = transform or transforms.Compose([
            transforms.Resize((config.resolution, config.resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
        ])

        # Prompt used when prompt_engineering is switched off
        self.default_prompt = "a detailed, high-quality photograph of furniture generated from sketch"

        # Pool of prompts sampled from when prompt_engineering is switched on
        self.furniture_prompts = [
            "a detailed, high-quality photograph of furniture",
            "a professional photograph of home furniture in natural lighting",
            "a realistic photograph of furniture with clean details",
            "a high-resolution image of furniture from sketch",
            "a photorealistic rendering of furniture design",
            "a detailed furniture photograph with accurate textures and materials",
            "a professional product photograph of furniture piece",
            "a clear, detailed image of furniture with realistic details",
        ]

    def __len__(self):
        """Number of examples, taken from the sketch path list."""
        return len(self.sketch_paths)

    def __getitem__(self, idx):
        """Load, transform and tokenize the example at position idx.

        Returns:
            Dict with keys "sketch", "target", "input_ids" and "file_path"
            (the path of the target photo).
        """
        image_path = self.image_paths[idx]

        # Both images are forced to RGB and pushed through the same transform.
        sketch_tensor = self.transform(Image.open(self.sketch_paths[idx]).convert("RGB"))
        target_tensor = self.transform(Image.open(image_path).convert("RGB"))

        # Random prompt from the pool, or the fixed default.
        prompt = (
            np.random.choice(self.furniture_prompts)
            if self.prompt_engineering
            else self.default_prompt
        )

        # Tokenize to a fixed length so that items can be batched.
        encoded = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )

        return {
            "sketch": sketch_tensor,
            "target": target_tensor,
            "input_ids": encoded.input_ids[0],
            "file_path": image_path
        }

In [None]:
# Helper function to create directories if they don't exist
def create_directories():
    """Create the output directory and its samples/ and checkpoints/ sub-folders.

    Existing directories are left untouched.
    """
    os.makedirs(config.output_dir, exist_ok=True)
    for subfolder in ("samples", "checkpoints"):
        os.makedirs(os.path.join(config.output_dir, subfolder), exist_ok=True)

# Function to save sample images during training
def save_samples(controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler, sketch_batch, epoch, device, show_images=True):
    """Generate images for a few sketches and save a sketch/generated/target comparison grid.

    Args:
        controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler: Components
            assembled into a StableDiffusionControlNetPipeline for inference.
        sketch_batch: Batch dict with "sketch" and "target" image tensors; at most
            the first four items are used.
        epoch: Used in the output filename ``sample_epoch_{epoch}.png``.
        device: Device the pipeline and the tensors are moved to.
        show_images: When True the figure is also shown via ``plt.show()``.

    Returns:
        The list of generated images returned by the pipeline.
    """
    def _to_displayable(chw_tensor):
        # CHW tensor -> HWC uint8 array; x * 0.5 + 0.5 maps [-1, 1] to [0, 1]
        hwc = chw_tensor.cpu().permute(1, 2, 0).numpy()
        return ((hwc * 0.5 + 0.5) * 255).clip(0, 255).astype(np.uint8)

    # Assemble an inference pipeline from the current models and move it to the device
    pipeline = StableDiffusionControlNetPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False
    ).to(device)

    # Use at most four examples from the batch
    num_samples = min(4, len(sketch_batch["sketch"]))
    sketches = sketch_batch["sketch"].to(device)[:num_samples]
    targets = sketch_batch["target"].to(device)[:num_samples]

    # One identical prompt per sample
    prompt = ["a detailed, high-quality photograph generated from a sketch"] * num_samples

    # Generate images
    with torch.no_grad():
        images = pipeline(
            prompt=prompt,
            image=sketches,
            num_inference_steps=100,
            guidance_scale=10.0
        ).images

    # squeeze=False keeps `axes` two-dimensional even for a single row, so one
    # code path serves every sample count.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    for i in range(num_samples):
        panels = (
            (_to_displayable(sketches[i]), "Sketch Input"),
            (images[i], "Generated Image"),
            (_to_displayable(targets[i]), "Target Image"),
        )
        for ax, (content, title) in zip(axes[i], panels):
            ax.imshow(content)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(config.output_dir, "samples", f"sample_epoch_{epoch}.png"))

    if show_images:
        plt.show()

    plt.close()

    return images

In [None]:
# Modified train_controlnet function with fixed model prediction handling
def train_controlnet(train_sketch_paths, train_image_paths, val_sketch_paths, val_image_paths):
    """Train a ControlNet (initialised from the Stable Diffusion UNet) on sketch/photo pairs.

    The VAE, text encoder and UNet are frozen; only the ControlNet is optimised
    with the usual noise-prediction MSE loss. The validation data is used only
    to render sample images, not to compute a validation loss.

    Args:
        train_sketch_paths: Sketch paths for training.
        train_image_paths: Target photo paths, aligned with train_sketch_paths.
        val_sketch_paths: Sketch paths used for sample generation.
        val_image_paths: Target photo paths, aligned with val_sketch_paths.

    Returns:
        Path of the checkpoint directory with the lowest average *training* loss.
        Falls back to the last per-epoch checkpoint path if no epoch improved on
        the initial best loss.

    Checkpoints are written with ``save_pretrained`` (weights plus config.json),
    so they can be reloaded with ``ControlNetModel.from_pretrained``.
    """
    create_directories()

    # Add custom error handling
    add_error_handling()

    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs")
    )

    # Initialize logging
    accelerator.init_trackers("controlnet_training")

    # Load the tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_name, subfolder="tokenizer")

    # Create datasets
    train_dataset = SketchToImageDataset(train_sketch_paths, train_image_paths, tokenizer)
    val_dataset = SketchToImageDataset(val_sketch_paths, val_image_paths, tokenizer)

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=4
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.val_batch_size,
        shuffle=False,
        num_workers=4
    )

    # Load models
    noise_scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name, subfolder="scheduler")

    # Load the VAE component
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(config.pretrained_model_name, subfolder="vae")

    text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_name, subfolder="text_encoder")

    # Load the UNet
    unet = UNet2DConditionModel.from_pretrained(config.pretrained_model_name, subfolder="unet")

    # Create a ControlNet model from the UNet
    controlnet = ControlNetModel.from_unet(unet)

    # Freeze everything except the ControlNet
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)  # We only train ControlNet

    # Optimizer for ControlNet with weight decay
    optimizer = torch.optim.AdamW(
        controlnet.parameters(),
        lr=config.learning_rate,
        weight_decay=1e-2
    )

    # Optimizer updates per epoch: ceiling division, because a trailing partial
    # accumulation window at the end of the dataloader still triggers an update.
    num_update_steps_per_epoch = -(-len(train_dataloader) // config.gradient_accumulation_steps)
    max_train_steps = config.num_train_epochs * num_update_steps_per_epoch

    # Learning rate scheduler
    lr_scheduler = get_scheduler(
        name=config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=max_train_steps,
    )

    # Prepare everything with accelerator
    controlnet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        controlnet, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    def _save_controlnet(path):
        # save_pretrained writes the weights together with config.json, which
        # ControlNetModel.from_pretrained needs in order to rebuild the model.
        if accelerator.is_main_process:
            accelerator.unwrap_model(controlnet).save_pretrained(path)

    # For mixed precision training, the frozen models are cast to the lower precision
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only the frozen models are cast. The trainable ControlNet keeps its original
    # precision; with mixed_precision="no" this is identical to the previous behaviour.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    # We need to keep vae, unet and text_encoder in eval mode
    vae.eval()
    text_encoder.eval()
    unet.eval()

    # Set controlnet to train mode
    controlnet.train()

    # Keep track of losses
    global_step = 0
    best_loss = float('inf')
    best_model_path = None

    for epoch in range(config.num_train_epochs):
        controlnet.train()
        total_loss = 0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                sketch = batch["sketch"].to(accelerator.device, dtype=weight_dtype)
                target = batch["target"].to(accelerator.device, dtype=weight_dtype)
                input_ids = batch["input_ids"].to(accelerator.device)

                # Encode target image to latent space, scaled by the VAE's own
                # scaling factor instead of a hard-coded constant
                latents = vae.encode(target).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise and add to latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Encode text
                encoder_hidden_states = text_encoder(input_ids)[0]

                # ControlNet turns the sketch into residuals that steer the frozen UNet
                controlnet_output = controlnet(noisy_latents, timesteps, encoder_hidden_states, sketch, return_dict=True)
                unet_output = unet(
                    sample=noisy_latents,
                    timestep=timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=controlnet_output.down_block_res_samples,
                    mid_block_additional_residual=controlnet_output.mid_block_res_sample,
                    return_dict=True)

                # Extract the noise prediction from the model output
                model_pred = unet_output.sample

                # For debugging
                if step == 0 and epoch == 0:
                    print(f"Model output type: {type(unet_output)}")
                    print(f"Model prediction type: {type(model_pred)}")
                    print(f"Model prediction shape: {model_pred.shape}")
                    print(f"Noise shape: {noise.shape}")

                # Compute loss
                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            global_step += 1
            total_loss += loss.item()

            if global_step % config.validation_steps == 0:
                accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

            if global_step % 100 == 0:  # Every 100 steps, log the loss
                print(f"Epoch [{epoch + 1}/{config.num_train_epochs}], Step [{global_step}], Loss: {total_loss / (step + 1)}")

        avg_loss = total_loss / len(train_dataloader)
        accelerator.print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

        # Save the best model (judged by average training loss)
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_model_path = os.path.join(config.output_dir, "checkpoints", f"best_model_epoch_{epoch}")
            _save_controlnet(best_model_path)
            print(f"New best model saved at epoch {epoch} with loss {best_loss:.4f}")

        # Save images and model
        if epoch % config.save_images_epochs == 0:
            val_batch = next(iter(val_dataloader))
            save_samples(controlnet, unet, vae, text_encoder, tokenizer, noise_scheduler, val_batch, epoch, accelerator.device)

            # Computed before the try block so the fallback message below can use
            # it even when the IPython import itself fails.
            sample_path = os.path.join(config.output_dir, "samples", f"sample_epoch_{epoch}.png")

            # Inline image display
            try:
                from IPython.display import Image as IPyImage, display
                display(IPyImage(filename=sample_path))
            except ImportError:
                print(f"Sample image saved at {sample_path}")

        if epoch % config.save_model_epochs == 0:
            model_path = os.path.join(config.output_dir, "checkpoints", f"controlnet_epoch_{epoch}")
            _save_controlnet(model_path)

    accelerator.end_training()

    # Return the path to the best model
    if best_model_path is None:
        best_model_path = os.path.join(config.output_dir, "checkpoints", f"controlnet_epoch_{config.num_train_epochs-1}")

    return best_model_path

In [None]:
import traceback

# Function to run inference with trained model
def run_inference(model_path, test_sketch_paths, test_image_paths, num_samples=5):
    """Generate images from test sketches with a trained ControlNet and save comparisons.

    Builds a ``StableDiffusionControlNetPipeline`` from ``Config.pretrained_model_name``,
    attaches ControlNet weights found at ``model_path``, randomly samples up to
    ``num_samples`` sketch/target pairs, and writes per-sample PNGs, a side-by-side
    comparison per sample, and one summary grid into ``inference_results/``.

    ControlNet weights are resolved in this order:
      1. ``ControlNetModel.from_pretrained(model_path)``.
      2. A raw checkpoint file inside ``model_path`` (``.bin``/``.pt`` first, otherwise
         ``.safetensors``) loaded into a ControlNet created from the pipeline's UNet.
      3. The public ``lllyasviel/sd-controlnet-canny`` weights as a last resort
         (results then do NOT reflect the trained model).

    Args:
        model_path: Directory containing the saved ControlNet weights.
        test_sketch_paths: Sequence of sketch image paths (conditioning input).
        test_image_paths: Sequence of ground-truth image paths, index-aligned with
            ``test_sketch_paths``.
        num_samples: Maximum number of pairs to sample without replacement.

    Returns:
        The results directory path (``"inference_results"``).

    Per-sample generation errors and summary-figure errors are caught and printed
    rather than raised.
    """
    print(f"Running inference with model from: {model_path}")

    # Use one dtype for every component so the pipeline does not mix precisions.
    torch_dtype = torch.float16

    # Base pipeline without a ControlNet; the ControlNet is attached below.
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        Config.pretrained_model_name,
        safety_checker=None,
        requires_safety_checker=False,
        controlnet=None,
        torch_dtype=torch_dtype,
    )

    try:
        controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=torch_dtype)
        pipe.controlnet = controlnet
    except (OSError, ValueError) as e:
        print(f"Could not load model directly: {e}")
        print("Attempting to load model weights manually...")

        # Start from a ControlNet shaped like the pipeline's UNet and overwrite its
        # weights with the checkpoint, for both pickle and safetensors checkpoints.
        # NOTE(review): assumes the trained ControlNet has this same architecture - confirm
        # against the training cell.
        controlnet = ControlNetModel.from_unet(pipe.unet)

        try:
            state_dict, success_message = _load_checkpoint_state_dict(model_path)

            # strict=False tolerates extra/missing keys; report them so a silent
            # partial load is visible in the notebook output.
            missing, unexpected = controlnet.load_state_dict(state_dict, strict=False)
            if missing:
                print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
            if unexpected:
                print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

            # Explicitly convert to the pipeline dtype before attaching.
            pipe.controlnet = controlnet.to(torch_dtype)
            print(success_message)
        except Exception as e:
            print(f"Failed to load model weights manually: {e}")
            print("Falling back to using base ControlNet model...")
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-canny",  # Use a pretrained ControlNet as fallback
                torch_dtype=torch_dtype,
            )
            pipe.controlnet = controlnet

    # Move model to GPU
    pipe = pipe.to("cuda")

    # Swap in a DDPM scheduler built from the base model's scheduler config.
    pipe.scheduler = DDPMScheduler.from_pretrained(Config.pretrained_model_name, subfolder="scheduler")

    results_dir = "inference_results"
    os.makedirs(results_dir, exist_ok=True)

    # Only sample indices valid for BOTH lists so a length mismatch cannot raise IndexError.
    num_pairs = min(len(test_sketch_paths), len(test_image_paths))
    indices = np.random.choice(num_pairs, min(num_samples, num_pairs), replace=False)

    image_size = (config.resolution, config.resolution)
    prompt = "a detailed, high-quality photograph generated from a sketch"

    # Files are numbered by success count, so the summary below can iterate 0..n-1.
    successful_samples = 0

    for i, idx in enumerate(indices):
        sketch_path = test_sketch_paths[idx]
        target_path = test_image_paths[idx]

        print(f"Processing sketch: {sketch_path}")
        print(f"Target image: {target_path}")

        # Load and resize both images to the working resolution.
        sketch = _open_rgb(sketch_path).resize(image_size)
        target = _open_rgb(target_path).resize(image_size)

        try:
            image = pipe(
                prompt,
                image=sketch,
                num_inference_steps=50,
                guidance_scale=7.5,
            ).images[0]

            # Save individual images for comparison
            sketch.save(os.path.join(results_dir, f"sample_{successful_samples}_sketch.png"))
            image.save(os.path.join(results_dir, f"sample_{successful_samples}_generated.png"))
            target.save(os.path.join(results_dir, f"sample_{successful_samples}_target.png"))

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            try:
                _show_triplet(
                    axes,
                    (sketch, image, target),
                    ("Input Sketch", "Generated Image", "Ground Truth"),
                )
                fig.tight_layout()
                fig.savefig(os.path.join(results_dir, f"comparison_{successful_samples}.png"))
            finally:
                # Always release the figure, even if saving failed.
                plt.close(fig)

            successful_samples += 1

            print(f"Saved comparison image: comparison_{successful_samples-1}.png")
        except Exception as e:
            print(f"Error generating image {i}: {e}")
            print(f"Detailed error: {traceback.format_exc()}")

    # Create a summary image with all comparisons
    if successful_samples > 0:
        fig = None
        try:
            # squeeze=False keeps `axes` 2-D even for a single row.
            fig, axes = plt.subplots(
                successful_samples, 3, figsize=(15, 5 * successful_samples), squeeze=False
            )
            # Titles are numbered only when there is more than one row.
            numbered = successful_samples > 1

            for i in range(successful_samples):
                try:
                    images = tuple(
                        _open_rgb(os.path.join(results_dir, f"sample_{i}_{kind}.png"))
                        for kind in ("sketch", "generated", "target")
                    )
                    suffix = f" {i+1}" if numbered else ""
                    _show_triplet(
                        axes[i],
                        images,
                        (f"Sketch{suffix}", f"Generated{suffix}", f"Target{suffix}"),
                    )
                except Exception as e:
                    print(f"Could not include image {i} in summary: {e}")

            fig.tight_layout()
            fig.savefig(os.path.join(results_dir, "all_comparisons.png"))

            print(f"All inference results saved to: {results_dir}")
            print(f"Summary image saved as: {os.path.join(results_dir, 'all_comparisons.png')}")
        except Exception as e:
            print(f"Error creating summary image: {e}")
            print(f"Detailed error: {traceback.format_exc()}")
        finally:
            if fig is not None:
                plt.close(fig)
    else:
        print("No successful image generations to create summary.")

    return results_dir


def _load_checkpoint_state_dict(model_path):
    """Read raw ControlNet weights from ``model_path``; return ``(state_dict, success_message)``.

    ``.bin``/``.pt`` files take precedence over ``.safetensors``; only the first
    matching file (in sorted order) is read. Raises ``ValueError`` if none is found.
    """
    entries = sorted(os.listdir(model_path))

    pickle_files = [f for f in entries if f.endswith('.bin') or f.endswith('.pt')]
    if pickle_files:
        print(f"Found model files: {pickle_files}")
        # NOTE: torch.load unpickles arbitrary objects - only use on checkpoints you produced.
        checkpoint = torch.load(os.path.join(model_path, pickle_files[0]), map_location="cpu")
        return _extract_controlnet_state_dict(checkpoint), "Successfully loaded model weights manually."

    print(f"No model files found in {model_path}. Looking for safetensors...")
    safetensors_files = [f for f in entries if f.endswith('.safetensors')]
    if not safetensors_files:
        raise ValueError(f"No model files found in {model_path}")

    print(f"Found safetensors files: {safetensors_files}")
    from safetensors import safe_open
    safetensors_path = os.path.join(model_path, safetensors_files[0])
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        tensors = {k: f.get_tensor(k) for k in f.keys()}
    return _extract_controlnet_state_dict(tensors), "Successfully loaded model from safetensors."


def _extract_controlnet_state_dict(checkpoint):
    """Normalize a loaded checkpoint object into a plain ``{parameter_name: tensor}`` dict."""
    # Accept a bare state dict, a {"state_dict": ...} wrapper, or a pickled module.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("state_dict", checkpoint)
    elif hasattr(checkpoint, "state_dict"):
        state_dict = checkpoint.state_dict()
    else:
        raise ValueError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

    prefix = "controlnet."
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # Strip the wrapper prefix so the key matches the bare ControlNet module;
            # names like "controlnet_cond_embedding" do not match "controlnet." and are kept as-is.
            cleaned[key[len(prefix):]] = value
        elif not key.startswith("model."):
            cleaned[key] = value
    return cleaned


def _open_rgb(path):
    """Open an image file, return an in-memory RGB copy, and close the file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _show_triplet(axes_row, images, titles):
    """Draw three images on the three axes of ``axes_row`` with titles and hidden axis frames."""
    for ax, img, title in zip(axes_row, images, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")

In [None]:
if __name__ == "__main__":
    # Expose the drawn/original splits under the sketch/image names used by the
    # training and inference entry points.
    train_sketch_paths, train_image_paths = train_drawn_paths, train_original_paths
    val_sketch_paths, val_image_paths = val_drawn_paths, val_original_paths
    test_sketch_paths, test_image_paths = test_drawn_paths, test_original_paths

    # Train first; train_controlnet returns the checkpoint path to use for inference.
    best_model_path = train_controlnet(
        train_sketch_paths,
        train_image_paths,
        val_sketch_paths,
        val_image_paths,
    )

    # Evaluate on the held-out test pairs using the path returned by training.
    run_inference(best_model_path, test_sketch_paths, test_image_paths, num_samples=5)