In [1]:
# Standard imports
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler
from peft import get_peft_model, LoraConfig

# Custom modules

from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm



  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Sketch Simplifier

In [3]:
# Configure and load sketch simplification network 

sketch_simplifier = SketchSimplificationNetwork().to(device)
sketch_simplifier.load_state_dict(torch.load("models-checkpoints/model_gan.pth"))

sketch_simplifier.eval()
sketch_simplifier.requires_grad_(False)

  sketch_simplifier.load_state_dict(torch.load("models-checkpoints/model_gan.pth"))


SketchSimplificationNetwork(
  (0): Conv2d(1, 48, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  (1): ReLU()
  (2): Conv2d(48, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): ReLU()
  (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (7): ReLU()
  (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU()
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU()
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (13): ReLU()
  (14): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU()
  (16): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU()
  (18): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (19): ReLU()
  (20): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1

# Load Stable Diffusian Model and schdueler for Infernce 

In [4]:
# Load Stable Diffusion Pipeline
stable_diffusion_1_5 = "benjamin-paine/stable-diffusion-v1-5"

In [18]:
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("benjamin-paine/stable-diffusion-v1-5", use_auth_token=True)
pipeline.save_pretrained("./stable-diffusion-v1-5")

Keyword arguments {'use_auth_token': True} are not expected by StableDiffusionPipeline and will be ignored.
Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
Loading pipeline components...: 100%|██████████| 7/7 [00:05<00:00,  1.37it/s]


In [5]:
stable_diffusion=StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=torch.float16,
    safety_checker=None  # Skip the safety checker if it's not required
)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

vae.eval()
unet.eval()
text_encoder.eval()

text_encoder.requires_grad_(False)

Loading pipeline components...: 100%|██████████| 6/6 [00:10<00:00,  1.76s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

In [6]:
import numpy 
# Set Scheduler
noise_scheduler = DDIMScheduler(
        beta_start = 0.00085,
        beta_end = 0.012,
        beta_schedule = "scaled_linear",
        num_train_timesteps = 1000,
        clip_sample = False,
    )

# Unet Pipeline and model

In [7]:
# Load U-Net latent edge predictor
checkpoint = torch.load("models-checkpoints/unet_latent_edge_predictor_checkpoint.pt", map_location=torch.device('cpu'))
LEP_UNET = UNETLatentEdgePredictor(9320, 4, 9).to(device)
LEP_UNET.load_state_dict(checkpoint["model_state_dict"])
LEP_UNET.eval()
LEP_UNET.requires_grad_(False)

  checkpoint = torch.load("models-checkpoints/unet_latent_edge_predictor_checkpoint.pt", map_location=torch.device('cpu'))


UNETLatentEdgePredictor(
  (e1): encoder_block(
    (conv): convolutional_block(
      (conv1): Conv2d(9320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (e2): encoder_block(
    (conv): convolutional_block(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (

# apply LoRA to the U-Net

In [8]:
# Apply LoRA to the U-Net
lora_config = LoraConfig(
    r=16,  # Rank of LoRA matrix
    lora_alpha=32,  # Scaling factor
    target_modules=["to_q", "to_k", "to_v"],  # Target attention layers
    lora_dropout=0  # Dropout
)

unet = get_peft_model(unet, lora_config)

# Setting up dataset


In [9]:
# Define image transformations
image_size = 512
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])

In [10]:
import os
from torch.utils.data import Dataset
from PIL import Image

class SketchImageDataset(Dataset):
    def __init__(self, root_dir, image_transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "photos")
        self.sketch_dir = os.path.join(root_dir, "sketch")
        # Collect all image filenames with common image extensions.
        self.image_filenames = [f for f in os.listdir(self.image_dir) if f.endswith((".jpg", ".png"))]
        self.image_transform = image_transform  # transform for photos only

    def clean_filename(self, filename):
        """
        Cleans the filename for use as a text prompt.
          - Removes file extension,
          - Splits on "-" and uses the first part,
          - Replaces underscores with spaces.
        """
        name = os.path.splitext(filename)[0]
        name = name.split("-")[0]
        name = name.replace("_", " ")
        return name.strip()

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Load the photo image and convert to RGB.
        image_name = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)  # Expecting a tensor in [0,1]

        # For sketches, we want to keep them as PIL images so that our trainer’s image_to_latents can use numpy.
        # We assume each photo has a corresponding folder in the "sketch" directory.
        sketch_folder = os.path.join(self.sketch_dir, os.path.splitext(image_name + ".jpg")[0])
        if not os.path.exists(sketch_folder):
            raise FileNotFoundError(f"Sketch folder not found: {sketch_folder}")
        sketch_files = [f for f in os.listdir(sketch_folder) if f.endswith((".jpg", ".png"))]
        if not sketch_files:
            raise ValueError(f"No sketches found for {image_name}")
        sketch_path = os.path.join(sketch_folder, sketch_files[0])  # Use the first sketch
        sketch = Image.open(sketch_path).convert("RGB")  # Keep sketch as PIL image and convert them to grey scale

        # Clean the filename to generate a text prompt.
        text_prompt = self.clean_filename(image_name)

        return {"image": image, "sketch": sketch, "text_prompt": text_prompt, "filename": image_name}


In [11]:
# Initialize dataset and dataloader
photo_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])
def custom_collate_fn(batch):
    # Batch is a list of dictionaries.
    images = torch.stack([item["image"] for item in batch])  # stack photo tensors
    text_prompts = [item["text_prompt"] for item in batch]    # leave as list of strings
    filenames = [item["filename"] for item in batch]          # leave as list
    sketches = [item["sketch"] for item in batch]             # leave sketches as a list of PIL images
    return {"image": images, "sketch": sketches, "text_prompt": text_prompts, "filename": filenames}


# Initialize the dataset.
dataset = SketchImageDataset(root_dir="Lego_256x256", image_transform=photo_transform)

# Create the DataLoader.
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)


# Loss



In [12]:
import torch.nn as nn

# Define Mean Squared Error (MSE) loss function
criterion = nn.MSELoss()

# Function to compute loss
def compute_loss(pred, target):
    return criterion(pred, target)

# training 

In [13]:
# Noise Scheduler
noise_scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000, clip_sample=False
)

# Optimizer
optimizer = optim.AdamW(unet.parameters(), lr=1e-4)


In [14]:
def text_to_embeddings(text):
    """
    Generates text embeddings using the CLIP text encoder.
    """
    tokenized_text = tokenizer(
        text,
        padding="max_length",
        max_length=77,  # Standard max length for CLIP
        truncation=True,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        text_embeddings = text_encoder(tokenized_text.input_ids)[0]  # No `.half()` here

    return text_embeddings.float()  # ✅ Convert to float32 to match U-Net


In [15]:
import importlib
import pipeline
import TrainingPipeline
importlib.reload(TrainingPipeline)
from TrainingPipeline import SketchGuidedText2ImageTrainer  # Import again


# Initialize trainer
trainer = SketchGuidedText2ImageTrainer(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    lep_unet=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device
)


In [17]:
props = torch.cuda.get_device_properties(0)
print(f"Total GPU Memory: {props.total_memory / (1024**3):.2f} GiB")


Total GPU Memory: 8.00 GiB


In [16]:
# Define optimizer
optimizer = optim.AdamW(unet.parameters(), lr=1e-4)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images = batch["image"].to(device)
        sketches= batch["sketch"]
        text_prompts = batch["text_prompt"]

        # Run training step
        loss = trainer.train_step(images, sketches, text_prompts, optimizer, noise_scheduler)

    print(f"Epoch {epoch+1}: Loss {loss}")

  deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
Denoising:   0%|          | 0/50 [00:52<?, ?it/s]
Epoch 1/5:   0%|          | 0/11 [00:54<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.14 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 13.26 GiB is allocated by PyTorch, and 114.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Test  Model Interface


In [None]:
# Save fine-tuned model
unet.save_pretrained("fine_tuned_unet")

# Download the model in Jupyter Notebook
import shutil
shutil.make_archive("fine_tuned_unet", 'zip', "fine_tuned_unet")

# To download it locally
from IPython.display import FileLink
FileLink("fine_tuned_unet.zip")



In [None]:
# Initialize Text-guided Text-to-Image synthesis pipeline

# Reinitialize the pipeline with the fine-tuned model
pipeline = SketchGuidedText2Image(
    stable_diffusion_pipeline=stable_diffusion, 
    unet=unet, vae=vae, 
    text_encoder=text_encoder, 
    lep_unet=LEP_UNET, scheduler=noise_scheduler, 
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device
)

In [None]:
edge_maps = [Image.open("example-sketches/home.jpg")]
seed = 1000

inverse_diffusion = pipeline.Inference(
    prompt=[" Snail in its Shell in the street with many cars "],
    num_images_per_prompt=1,
    edge_maps=edge_maps,
    negative_prompt=None,
    num_inference_timesteps=50,
    classifier_guidance_strength=8,
    sketch_guidance_strength=1.6,
    seed=seed,
    simplify_edge_maps=True,
    guidance_steps_perc=0.5,
)

In [None]:
for edge_map, image in zip(edge_maps, inverse_diffusion["generated_image"]):
    fig, axs = plt.subplots(1, 2, figsize = (10, 5))
    axs[0].imshow(edge_map)
    axs[1].imshow(image)
    axs[0].axis("off")
    axs[1].axis("off")
    axs[0].set_title("Input Sketch")
    axs[1].set_title("Synthesized Image")