In [None]:
import torch
import torchvision
import diffusers
import einops
import safetensors
import albumentations
import transformers
import matplotlib
import numpy
from torchvision import transforms


  from .autonotebook import tqdm as notebook_tqdm


Torch version: 2.5.1+cu121
Torchvision version: 0.20.1+cu121
Diffusers version: 0.31.0
Einops version: 0.8.1
Safetensors version: 0.5.3
Albumentations version: 2.0.4
Transformers version: 4.30.0
Matplotlib version: 3.10.0
NumPy version: 1.26.4


In [None]:
# Standard imports
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import os
# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Custom modules
from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image



In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load Sketch Simplifier

In [None]:
# Configure and load the sketch simplification network.
# map_location=device lets a checkpoint saved from GPU tensors load on a
# CPU-only machine (torch.load otherwise restores tensors to the device they
# were saved from).
sketch_simplifier = SketchSimplificationNetwork().to(device)
sketch_simplifier.load_state_dict(
    torch.load("models-checkpoints/model_gan.pth", map_location=device)
)

# Used for inference only: eval() switches layers such as dropout/batch-norm
# to inference behaviour, and gradients are disabled for all parameters.
sketch_simplifier.eval()
sketch_simplifier.requires_grad_(False)

# Load Stable Diffusion Model and Scheduler for Inference

In [None]:
# Hugging Face Hub repository id of the Stable Diffusion v1.5 weights used below.
stable_diffusion_1_5 = (
    "benjamin-paine/stable-diffusion-v1-5"
)

In [None]:
# Half precision halves GPU memory, but float16 support on the CPU is limited
# and slow in PyTorch, so only use it when running on CUDA.
sd_dtype = torch.float16 if device.type == "cuda" else torch.float32

stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=sd_dtype,
    safety_checker=None,  # Skip the safety checker if it's not required
)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

# These Stable Diffusion components are frozen: put them in eval mode and
# disable gradient tracking on their parameters.
for frozen_module in (vae, unet, text_encoder):
    frozen_module.eval()
    frozen_module.requires_grad_(False)

In [None]:
import numpy 
# DDIM sampler with a scaled-linear beta schedule (0.00085 -> 0.012 over 1000
# training timesteps) and no clipping of the predicted sample.
scheduler_settings = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "clip_sample": False,
}
noise_scheduler = DDIMScheduler(**scheduler_settings)

# U-Net Latent Edge Predictor (LEP) Model

In [None]:
# Load the pretrained U-Net latent edge predictor (LEP).
# The checkpoint is materialised on the CPU first; load_state_dict then copies
# its tensors into the parameters that already live on `device`.
lep_checkpoint_path = "models-checkpoints/unet_latent_edge_predictor_checkpoint.pt"
checkpoint = torch.load(lep_checkpoint_path, map_location=torch.device('cpu'))

# NOTE(review): constructor arguments (9320, 4, 9) must match the architecture
# the checkpoint was trained with — see UNETLatentEdgePredictor for their meaning.
LEP_UNET = UNETLatentEdgePredictor(9320, 4, 9).to(device)
LEP_UNET.load_state_dict(checkpoint["model_state_dict"])

# Training mode: the predictor is fine-tuned further below.
LEP_UNET.train()

# Set Up the Dataset


In [None]:
# Define function to convert images to VAE latent embeddings
def encode_image_to_latent(image_path, vae, device):
    """Encode the image at ``image_path`` into a VAE latent.

    The Stable Diffusion VAE encoder expects 3-channel images scaled to
    [-1, 1], in the same dtype as its weights. The image is therefore
    converted to RGB, rescaled from ToTensor's [0, 1] range, and cast to
    ``vae.dtype`` (which may be float16).

    Args:
        image_path: Path to an image file readable by PIL.
        vae: VAE whose ``encode`` result exposes ``latent_dist.sample()``.
        device: Device the VAE's weights live on.

    Returns:
        A latent tensor with a leading batch dimension of 1. It is *sampled*
        from the encoder's posterior, so repeated calls give different values.
    """
    # Converting to RGB replicates a grayscale sketch across the 3 input
    # channels the encoder requires. The context manager closes the file.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")

    image_tensor = transforms.ToTensor()(rgb_image)  # (3, H, W), values in [0, 1]
    image_tensor = image_tensor * 2.0 - 1.0  # rescale to [-1, 1]
    # Add batch dimension and match the VAE's device and dtype.
    image_tensor = image_tensor.unsqueeze(0).to(device=device, dtype=vae.dtype)

    with torch.no_grad():
        latent = vae.encode(image_tensor).latent_dist.sample()  # Get VAE latent
    return latent

In [None]:
# Define paths to dataset
sketch_dir = "Lego 256x256/sketches"
image_dir = "Lego 256x256/images"

# Get list of sketches. os.listdir also returns hidden entries (".DS_Store",
# ".ipynb_checkpoints") and sub-directories that PIL cannot open, so keep only
# visible regular files.
sketch_files = sorted(
    name
    for name in os.listdir(sketch_dir)
    if not name.startswith(".") and os.path.isfile(os.path.join(sketch_dir, name))
)
if not sketch_files:
    raise FileNotFoundError(f"No sketch files found in {sketch_dir!r}")

# Seeded generator so the random time embeddings are identical on every run.
time_embedding_generator = torch.Generator().manual_seed(0)

# Store latents for training
sketch_latents = []
lego_latents = []
time_embeddings = []

for file in sketch_files:
    sketch_path = os.path.join(sketch_dir, file)
    image_path = os.path.join(image_dir, file)  # Sketch and target image share a filename
    if not os.path.isfile(image_path):
        raise FileNotFoundError(
            f"No matching image for sketch {file!r}: expected {image_path!r}"
        )

    # Encode sketch and Lego image into latent space. Cast to float32 so the
    # stored latents match LEP_UNET's default-dtype (float32) parameters even
    # when the VAE runs in float16.
    sketch_latent = encode_image_to_latent(sketch_path, vae, device).float()
    lego_latent = encode_image_to_latent(image_path, vae, device).float()

    # Random time embedding (for diffusion guidance) of length sketch_latent.shape[-1]
    time_embedding = torch.rand(
        1, sketch_latent.shape[-1], generator=time_embedding_generator
    ).to(device)

    # Store latents
    sketch_latents.append(sketch_latent)
    lego_latents.append(lego_latent)
    time_embeddings.append(time_embedding)

# Convert lists to tensors. Each latent already carries a leading batch dim of
# 1, so torch.stack adds a new leading sample axis in front of it. The training
# cell slices along dim 0 and flattens the rest, so this layout is kept as-is.
sketch_latents = torch.stack(sketch_latents)
lego_latents = torch.stack(lego_latents)
time_embeddings = torch.stack(time_embeddings)

print(f"Loaded {len(sketch_latents)} samples")


# Loss



In [None]:
import torch.nn as nn

# Mean-squared-error criterion (default reduction: mean over all elements).
criterion = nn.MSELoss()


def compute_loss(pred, target):
    """Return the MSE loss between the predicted and the target latents."""
    loss_value = criterion(pred, target)
    return loss_value

# Training

In [None]:
import torch.optim as optim

# Set training parameters
num_epochs = 5
learning_rate = 1e-4
batch_size = 8

# Initialize optimizer
optimizer = optim.AdamW(LEP_UNET.parameters(), lr=learning_rate)

# The inference cell below calls LEP_UNET.eval(); re-assert training mode so
# re-running this cell after it still trains with train-mode behaviour.
LEP_UNET.train()

num_samples = len(sketch_latents)

# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        # Get batch
        batch_sketch = sketch_latents[start:stop].to(device)
        batch_lego = lego_latents[start:stop].to(device)
        batch_time = time_embeddings[start:stop].to(device)

        # Flatten each sample's sketch latent and time embedding and join them
        input_embedding = torch.cat(
            [batch_sketch.flatten(start_dim=1), batch_time.flatten(start_dim=1)],
            dim=1,
        )

        # Forward pass through LEP UNet
        optimizer.zero_grad()
        output_latent = LEP_UNET(input_embedding)  # Predict latent

        # Compute loss and update weights
        loss = compute_loss(output_latent, batch_lego)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # compute_loss returns a per-batch mean, so average over batches, not
    # samples (dividing by the sample count underreports by ~batch_size).
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / max(num_batches, 1):.4f}")

# Save fine-tuned model
torch.save(LEP_UNET.state_dict(), "LEP_UNET_finetuned.pth")
print("Model fine-tuned and saved!")


# Test Model Inference


In [None]:
# Reload the fine-tuned weights written by the training cell.
# map_location keeps this working if the file is opened on a different device
# than it was saved from. weights_only=True is sufficient because the file holds
# only a state_dict, and it avoids unpickling arbitrary objects.
LEP_UNET.load_state_dict(
    torch.load("LEP_UNET_finetuned.pth", map_location=device, weights_only=True)
)
LEP_UNET.eval()

print("Fine-tuned model loaded for inference.")


In [None]:
# Assemble the sketch-guided text-to-image pipeline from the components above.
pipeline_components = dict(
    stable_diffusion_pipeline=stable_diffusion,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    vae=vae,
    scheduler=noise_scheduler,
    lep_unet=LEP_UNET,
    sketch_simplifier=sketch_simplifier,
    device=device,
)
pipeline = SketchGuidedText2Image(**pipeline_components)

In [None]:
# Fixed seed so the generation can be reproduced.
seed = 1000
edge_maps = [Image.open("example-sketches/home.jpg")]

# NOTE(review): argument semantics are defined by SketchGuidedText2Image.Inference;
# e.g. guidance_steps_perc presumably limits sketch guidance to a fraction of steps — confirm.
generation_settings = {
    "prompt": [" Snail in its Shell in the street with many cars "],
    "negative_prompt": None,
    "num_images_per_prompt": 1,
    "edge_maps": edge_maps,
    "simplify_edge_maps": True,
    "seed": seed,
    "num_inference_timesteps": 50,
    "classifier_guidance_strength": 8,
    "sketch_guidance_strength": 1.6,
    "guidance_steps_perc": 0.5,
}
inverse_diffusion = pipeline.Inference(**generation_settings)

In [None]:
# Show each input sketch next to the image synthesized from it.
for edge_map, image in zip(edge_maps, inverse_diffusion["generated_image"]):
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    # cmap="gray" only affects single-channel (grayscale) sketches; RGB images
    # keep their own colours. cmap=None is matplotlib's default.
    panels = (
        (axs[0], edge_map, "Input Sketch", "gray"),
        (axs[1], image, "Synthesized Image", None),
    )
    for ax, picture, title, cmap in panels:
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    # Render and release this figure now instead of accumulating open figures.
    plt.show()