In [None]:
"""
!pip uninstall -y numpy
!pip install numpy==1.26.4
import os
os.kill(os.getpid(), 9)

import numpy as np
import torch
print("PyTorch:", torch.__version__)
print("NumPy:", np.__version__)
"""

In [None]:
!pip install numpy==1.26.4
!pip install torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu121
!pip install diffusers==0.31.0 transformers==4.30.0 peft==0.5.0 einops safetensors>=0.3.1 albumentations matplotlib kornia

In [None]:
# Mount Google Drive at /content/drive. Later cells read a checkpoint and the
# dataset from paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Create the models-checkpoints directory if it doesn't exist.
os.makedirs("models-checkpoints", exist_ok=True)

# Install gdown if it's not already installed.
!pip install --upgrade gdown

# Download the Unet Latent Edge Predictor checkpoint.
# This command downloads the checkpoint into the models-checkpoints directory with the filename "unet_latent_edge_predictor_checkpoint.pt"
!gdown https://drive.google.com/uc?id=1w7eimErXnnRrjY6TY8yXrmno-hvcZ94_ -O models-checkpoints/unet_latent_edge_predictor_checkpoint.pt

# Download the Sketch Simplification Network checkpoint.
# Replace <URL_SKETCH> with the actual URL for model_gan.pth.
!cp "/content/drive/MyDrive/models-checkpoints/model_gan.pth" "./models-checkpoints/model_gan.pth"

print("Environment setup complete. Checkpoints downloaded to 'models-checkpoints'.")

In [None]:
# Standard imports
import os
# The three environment variables below are assigned before the `import torch`
# that follows in this cell.
# NOTE(review): CUDA_LAUNCH_BLOCKING and TORCH_USE_CUDA_DSA are debugging aids;
# synchronous kernel launches usually slow training down — confirm they are still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"
# Option string for PyTorch's CUDA caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import random_split
import numpy
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler
from peft import get_peft_model, LoraConfig, PeftModel

# Custom modules

from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm



In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Load Sketch Simplifier

In [None]:
# Configure and load the sketch simplification network.

sketch_simplifier = SketchSimplificationNetwork().to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# that was saved from a GPU also loads on a CPU-only runtime.
sketch_simplifier.load_state_dict(
    torch.load("models-checkpoints/model_gan.pth", map_location=device)
)

# The network is never trained in this notebook: eval mode, gradients off.
sketch_simplifier.eval()
sketch_simplifier.requires_grad_(False)

# Load Stable Diffusion Model and Scheduler for Inference

In [None]:
import os

# python-dotenv is optional here (this notebook does not install it). When it is
# available, a local .env file may provide HF_TOKEN.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass  # fall back to whatever is already in the process environment

#!huggingface-cli login
# Load Stable Diffusion Pipeline
# The token is read from the environment instead of a placeholder literal: a
# made-up token string is sent as a credential, while None requests anonymous access.
token = os.environ.get("HF_TOKEN")
# `token=` is the keyword documented for from_pretrained in current diffusers releases.
pipeline = StableDiffusionPipeline.from_pretrained("benjamin-paine/stable-diffusion-v1-5", token=token)

In [None]:
# Load Stable Diffusion Pipeline
stable_diffusion_1_5 = "benjamin-paine/stable-diffusion-v1-5"
# Half precision is only requested on CUDA; on a CPU fallback the weights stay in
# float32 because CPU support for float16 operators is incomplete.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=model_dtype,
    safety_checker=None  # Skip the safety checker if it's not required
)
# Pull the individual components out of the pipeline and move them to the device.
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

vae.eval()
unet.eval()
text_encoder.eval()

# Only U-Net parameters are ever handed to the optimizer in this notebook, so
# gradients are disabled for the VAE as well as the text encoder.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Unet Pipeline and model

In [None]:
# Build the latent edge predictor on the target device, then restore its weights.
LEP_UNET = UNETLatentEdgePredictor(9320, 4, 9).to(device)
# The checkpoint is deserialized onto the CPU; load_state_dict copies the values
# into the already-placed model.
checkpoint = torch.load(
    "models-checkpoints/unet_latent_edge_predictor_checkpoint.pt",
    map_location=torch.device('cpu'),
)
LEP_UNET.load_state_dict(checkpoint["model_state_dict"])
# Used as a fixed network: eval mode and no gradients.
LEP_UNET.eval()
LEP_UNET.requires_grad_(False)

# Apply LoRA Fine-tuning

In [None]:
#for name, module in unet.named_modules():
#    print(name)


In [None]:
# Wrap the U-Net with LoRA adapters on its output convolution and on the
# q/k/v projections of one self-attention block.
_attn1_prefix = "up_blocks.2.attentions.0.transformer_blocks.0.attn1"
lora_config = LoraConfig(
    r=16,              # rank of the LoRA matrices
    lora_alpha=32,     # scaling factor
    lora_dropout=0.1,  # dropout probability used by the LoRA layers
    target_modules=(
        ["conv_out"]  # final U-Net output
        + [f"{_attn1_prefix}.{projection}" for projection in ("to_q", "to_k", "to_v")]
    ),
)

unet = get_peft_model(unet, lora_config)

In [None]:

# List every parameter that will receive gradient updates.
trainable_parameters = [
    (param_name, tensor)
    for param_name, tensor in unet.named_parameters()
    if tensor.requires_grad
]
for param_name, tensor in trainable_parameters:
    print(f"LoRA is fine-tuning: {param_name} | Shape: {tensor.shape}")



In [None]:
# Freeze everything whose parameter name does not contain "lora".
for param_name, tensor in unet.named_parameters():
    if "lora" in param_name:
        continue  # adapter weight: leave its requires_grad flag untouched
    tensor.requires_grad = False

# Setting up dataset


In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image

class SketchImageDataset(Dataset):
    """Pairs each photo under ``root_dir/photos`` with its sketches and a text prompt.

    Layout read by this class::

        root_dir/photos/<name>.jpg or <name>.png
        root_dir/sketch/<name>.jpg/        (a directory holding that photo's sketches)

    Args:
        root_dir: Dataset root containing the ``photos`` and ``sketch`` directories.
        mapping_file: Optional text file with one ``key: value`` entry per line.
            The key is a photo base name and the value is the phrase inserted
            into the prompt. Missing or ``None`` means base names are used.
        image_transform: Optional callable applied to the photo only. Sketches
            are always returned as untransformed RGB PIL images.
    """

    # Extensions accepted for both photos and sketches (case-sensitive).
    IMAGE_EXTENSIONS = (".jpg", ".png")
    # A sketch file is used when its name contains any of these substrings.
    SKETCH_MARKERS = ("sketchs32strokes", "sketchs20strokes")

    def __init__(self, root_dir, mapping_file=None, image_transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "photos")
        self.sketch_dir = os.path.join(root_dir, "sketch")
        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices stable between runs, which index-based splits depend on.
        self.image_filenames = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith(self.IMAGE_EXTENSIONS)
        )
        self.image_transform = image_transform  # transform for photos only

        self.mapping = {}
        if mapping_file is None:
            # The default must not touch the filesystem: os.path.exists(None) raises TypeError.
            return
        if not os.path.exists(mapping_file):
            print(f"Mapping file not found: {mapping_file}. Prompts will use file base names.")
            return
        with open(mapping_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ":" not in line:
                    continue  # Skip empty lines and lines without a "key: value" separator.
                # Split on the first colon only so values may contain colons.
                key, value = line.split(":", 1)
                self.mapping[key.strip()] = value.strip()

    def clean_filename(self, filename):
        """Return ``filename`` without its final extension."""
        return os.path.splitext(filename)[0]

    def __len__(self):
        """Number of photos found in the photos directory."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load one photo together with all of its matching sketches.

        Returns:
            dict with keys ``image`` (transformed photo, or an RGB PIL image when
            no transform was given), ``sketches`` (list of RGB PIL images),
            ``text_prompt`` (str) and ``filename`` (photo file name).

        Raises:
            FileNotFoundError: the photo has no sketch directory.
            ValueError: the sketch directory holds no file matching SKETCH_MARKERS.
        """
        image_name = self.image_filenames[idx]
        base_name = self.clean_filename(image_name)

        # convert() returns an independent copy, so the file handle can be closed right away.
        with Image.open(os.path.join(self.image_dir, image_name)) as photo:
            image = photo.convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)

        # Each photo has a sketch *directory* named "<base_name>.jpg".
        sketch_folder = os.path.join(self.sketch_dir, base_name + ".jpg")
        if not os.path.isdir(sketch_folder):
            raise FileNotFoundError(f"Sketch folder not found: {sketch_folder}")

        # Sorted so the per-photo sketch order is stable between runs.
        sketch_files = sorted(
            f for f in os.listdir(sketch_folder)
            if f.endswith(self.IMAGE_EXTENSIONS)
            and any(marker in f for marker in self.SKETCH_MARKERS)
        )
        if not sketch_files:
            raise ValueError(
                f"No sketches with 'sketchs32strokes' or 'sketchs20strokes' found for {image_name}"
            )

        sketches = []
        for sketch_name in sketch_files:
            with Image.open(os.path.join(sketch_folder, sketch_name)) as sketch:
                sketches.append(sketch.convert("RGB"))

        # Fall back to the base name when the mapping has no entry for this photo.
        text_prompt = "Lego " + self.mapping.get(base_name, base_name) + " with a white background"

        return {"image": image, "sketches": sketches, "text_prompt": text_prompt, "filename": image_name}




In [None]:

from torch.utils.data import random_split, DataLoader
# Photo preprocessing: resize to a square of side `image`, convert to a tensor,
# then normalize with mean 0.5 and std 0.5.
image = 512
photo_transform = transforms.Compose(
    [
        transforms.Resize(size=(image, image)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

def custom_collate_fn(batch):
    """Merge dataset samples into one batch dictionary.

    Photos are stacked into a single tensor along a new leading dimension.
    Sketch lists, prompts and filenames are gathered into plain Python lists,
    one entry per sample. Each sample's ``"sketches"`` entry is returned under
    the batch key ``"sketch"``.
    """
    photo_tensors, sketch_lists, prompts, names = [], [], [], []
    for sample in batch:
        photo_tensors.append(sample["image"])
        prompts.append(sample["text_prompt"])
        names.append(sample["filename"])
        sketch_lists.append(sample["sketches"])  # left as a list of PIL images

    return {
        "image": torch.stack(photo_tensors),
        "sketch": sketch_lists,
        "text_prompt": prompts,
        "filename": names,
    }


# Initialize the dataset.
dataset = SketchImageDataset(
    root_dir="/content/drive/MyDrive/Lego_256x256/Lego_256x256",
    mapping_file="mappings.txt",
    image_transform=photo_transform
)


# 80% train, 10% validation, remainder test.
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = int(0.1 * dataset_size)
test_size = dataset_size - train_size - val_size
# A seeded generator makes the split repeatable. Without it every execution of
# this cell reshuffles, so a later evaluation run could score samples that were
# used for training. (Repeatability also relies on the dataset listing its files
# in a stable order.)
split_generator = torch.Generator().manual_seed(1000)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
# batch_size=1; only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
print(f"📊 Dataset split sizes: Train={train_size}, Validation={val_size}, Test={test_size}")



# Noise Scheduler / Optimizer


In [None]:
# Collect the parameters that are not LoRA adapter weights, then switch their
# gradients off so only adapter weights remain trainable.
non_lora_parameters = [
    tensor for param_name, tensor in unet.named_parameters() if "lora" not in param_name
]
for tensor in non_lora_parameters:
    tensor.requires_grad = False

In [None]:
import numpy as np
import torch

# Show the PyTorch version first, then the NumPy version, one per line.
for library in (torch, np):
    print(library.__version__)

In [None]:
# Noise Scheduler
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    clip_sample=False,
)


# Optimizer
# Trainable parameters are sorted into two buckets by name: anything containing
# "conv_out" goes to the first, every other trainable parameter to the second.
conv_out_parameters = []
remaining_parameters = []
for name, param in unet.named_parameters():
    if not param.requires_grad:
        continue
    bucket = conv_out_parameters if "conv_out" in name else remaining_parameters
    bucket.append(param)

# Each group carries its own learning rate.
param_groups = [
    {'params': conv_out_parameters, 'lr': 1e-3},   # higher LR for conv_out
    {'params': remaining_parameters, 'lr': 5e-4},  # lower LR for the other LoRA layers
]

# AdamW over both groups; weight decay and epsilon are shared.
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01, eps=1e-8)

# Pipeline Setup

In [None]:
import importlib
import pipeline
import FinalTrainingPipeline
# Reload the module so edits made to it after its first import take effect
# without restarting the kernel; the class is then re-imported from the fresh module.
importlib.reload(FinalTrainingPipeline)
from FinalTrainingPipeline import SketchDiffusionTrainer  # Import again
# NOTE(review): `pipeline` is imported above but neither reloaded nor referenced in this cell.


# Initialize trainer
# NOTE(review): text_encoder is not passed; presumably the trainer obtains it
# from stable_diffusion_pipeline — confirm against FinalTrainingPipeline.
trainer = SketchDiffusionTrainer(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    sketch_adapter=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    device=device,
    # presumably the probability of replacing the prompt with a null prompt — TODO confirm
    null_prob=0.2
)


#  Training

In [None]:
# Sketch preprocessing: resize to 512x512, convert to a (C, H, W) tensor in
# [0, 1], then shift and scale each of the three channels to [-1, 1].
sketch_transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
import csv
import time

best_loss = float('inf')
best_val_loss = float('inf')
best_runtime = None          # 1-based epoch of the best adapter checkpoint saved so far
start_time = time.time()
accumulation_steps=10        # micro-steps accumulated before each optimizer update
micro_step = 0               # counts train_step calls across all epochs
num_epochs=20
seed=1000
# Folder that is zipped and reloaded later in the notebook.
checkpoint_dir = "3.4MultipleSketchaccum10epoch18"
sampler_edge_maps = [Image.open("/content/drive/MyDrive/Lego_256x256/Lego_256x256/sketch/Car-1.jpg/sketchs32strokes_Car-1.jpg")]

# Output directories must exist before anything is written into them.
os.makedirs("samples", exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Open your CSV once
with open("training_log.csv","w",newline="") as log_f:
    writer = csv.writer(log_f)
    writer.writerow(["epoch","micro_step","sketch_id","train_loss","lr","elapsed_s"])

    for epoch in range(num_epochs):
        optimizer.zero_grad()              # ensure grads start at zero
        for batch in train_dataloader:
            images    = batch["image"]      # [B, C, H, W]
            sketch_ls = batch["sketch"]     # list of lists of PILs
            prompts   = batch["text_prompt"]

            # Loop over each sample in the batch (often B=1)
            for img, sketches, prompt in zip(images, sketch_ls, prompts):
                img = img.unsqueeze(0).to(device)

                # Every sketch of the photo is one micro-step.
                for sketch_id, sketch_pil in enumerate(sketches):

                    # Forward + backward.
                    # NOTE(review): train_step also receives the optimizer; this loop
                    # assumes it only accumulates gradients and does not step — confirm.
                    loss = trainer.train_step(
                        img, [sketch_pil], [prompt], optimizer
                    )
                    micro_step += 1

                    # Logging (lr shown is that of the first parameter group).
                    elapsed = time.time() - start_time
                    lr = optimizer.param_groups[0]["lr"]
                    writer.writerow([
                        epoch+1,
                        micro_step,
                        sketch_id,
                        f"{loss:.4f}",
                        f"{lr:.2e}",
                        f"{elapsed:.1f}"
                    ])

                    # Update once every accumulation_steps
                    if micro_step % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

        # Flush leftover gradients when the epoch did not end on an update boundary.
        if micro_step % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # --- 1) Single-step MSE eval on the first validation batch ---
        val_batch = next(iter(val_dataloader))
        val_images = val_batch["image"].to(device)
        # The collate function stores sketch lists under "sketch" (not "sketches").
        val_sketches = [sk_list[0] for sk_list in val_batch["sketch"]]
        # preprocess to tensor
        val_sketch_tensors = torch.stack([
            sketch_transform(sk).to(device) for sk in val_sketches
        ])
        val_prompts = val_batch["text_prompt"]

        val_mse = trainer.eval_step(val_images, val_sketch_tensors, val_prompts)
        print(f"Epoch {epoch+1} ► val MSE: {val_mse:.4f}")

        # log it
        with open("eval_log.csv", "a", newline="") as eval_f:
            eval_writer = csv.writer(eval_f)
            eval_writer.writerow([epoch+1, f"{val_mse:.4f}", "", "", ""])

        # Keep the adapter weights of the best epoch. The path pattern matches the
        # one the reload cell reads: <checkpoint_dir>/best_unet_epoch<best_runtime>.
        if val_mse < best_val_loss:
            best_val_loss = float(val_mse)
            best_runtime = epoch + 1
            unet.save_pretrained(f"{checkpoint_dir}/best_unet_epoch{best_runtime}")

        # --- 2) Full sampling every 5 epochs (inside the epoch loop) ---
        if (epoch + 1) % 5 == 0:
            sampler = SketchGuidedText2Image(
                stable_diffusion_pipeline=stable_diffusion,
                unet=unet,
                vae=vae,
                text_encoder=text_encoder,
                lep_unet=LEP_UNET,
                scheduler=noise_scheduler,
                tokenizer=tokenizer,
                sketch_simplifier=sketch_simplifier,
                device=device
            )
            samples = sampler.Inference(
                num_images_per_prompt=1,
                edge_maps=sampler_edge_maps,
                negative_prompt=None,
                num_inference_timesteps=50,
                classifier_guidance_strength=8,
                sketch_guidance_strength=1.6,
                seed= seed,
                simplify_edge_maps=False,
                guidance_steps_perc=0.5,
            )["generated_image"]

            # save & log each
            for sketch_id, out_img in enumerate(samples):
                fname = f"samples/epoch{epoch+1}_sketch{sketch_id}.png"
                out_img.save(fname)
                with open("eval_log.csv", "a", newline="") as eval_f:
                    eval_writer = csv.writer(eval_f)
                    eval_writer.writerow([epoch+1, "", epoch+1, sketch_id, fname])





1: Steps=20, Cls=6, Sketch=1.6, Perc=0.5, Image Cosine=0.909, Edge Cosine=0.909, Avg Cosine=0.909

2: Steps=20, Cls=6, Sketch=1.6, Perc=0.6, Image Cosine=0.898, Edge Cosine=0.898, Avg Cosine=0.898

3: Steps=20, Cls=6, Sketch=1.8, Perc=0.5, Image Cosine=0.895, Edge Cosine=0.895, Avg Cosine=0.895

4: Steps=20, Cls=4, Sketch=1.8, Perc=0.5, Image Cosine=0.892, Edge Cosine=0.892, Avg Cosine=0.892

5: Steps=20, Cls=4, Sketch=1.0, Perc=0.6, Image Cosine=0.888, Edge Cosine=0.888, Avg Cosine=0.888

6: Steps=20, Cls=4, Sketch=1.8, Perc=0.4, Image Cosine=0.888, Edge Cosine=0.888, Avg Cosine=0.888

7: Steps=20, Cls=4, Sketch=1.0, Perc=0.7, Image Cosine=0.885, Edge Cosine=0.885, Avg Cosine=0.885

8: Steps=20, Cls=6, Sketch=1.0, Perc=0.7, Image Cosine=0.883, Edge Cosine=0.883, Avg Cosine=0.883

9: Steps=20, Cls=6, Sketch=1.6, Perc=0.7, Image Cosine=0.880, Edge Cosine=0.880, Avg Cosine=0.880

10: Steps=20, Cls=6, Sketch=1.0, Perc=0.6, Image Cosine=0.879, Edge Cosine=0.879, Avg Cosine=0.879


# Save model


# Download Folder

In [None]:
# Recursively archive the run folder into a single zip file.
!zip -r 3.4MultipleSketchaccum10epoch18.zip 3.4MultipleSketchaccum10epoch18

In [None]:
from google.colab import files
# Send the archive produced by the previous cell to the browser as a download.
files.download("3.4MultipleSketchaccum10epoch18.zip")

In [None]:
#unet.save_pretrained(f"{runtime}/best_unet_epoch{epoch+1}")
import os
from peft import PeftModel

# Rebuild the base pipeline so the adapters are attached to an unmodified U-Net.
# Half precision is only requested on CUDA; CPU support for float16 is incomplete.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=model_dtype,
    safety_checker=None
)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

#Load LoRA adapters
adapter_dir = f"./3.4MultipleSketchaccum10epoch18/best_unet_epoch{best_runtime}"
# Fail with an explicit message when no adapter was saved (for example when
# best_runtime is still None) instead of an opaque loader error.
if not os.path.isdir(adapter_dir):
    raise FileNotFoundError(
        f"LoRA adapter directory not found: {adapter_dir} (best_runtime={best_runtime!r})"
    )
unet = PeftModel.from_pretrained(unet, adapter_dir)


# Evaluate  Model


In [None]:
import importlib
import pipeline
import TrainingPipelineTest
# Reload the module so edits made to it after its first import take effect
# without restarting the kernel; the class is then re-imported from the fresh module.
importlib.reload(TrainingPipelineTest)
from TrainingPipelineTest import SketchGuidedText2ImageTrainer  # Import again


# Initialize trainer
# NOTE(review): binding this object to the name `eval` shadows Python's built-in
# eval() for the rest of the kernel session.
# NOTE(review): the evaluation cell that follows calls `trainer.eval(...)`, not
# this object — confirm which of the two is meant to run the evaluation.
eval = SketchGuidedText2ImageTrainer(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    lep_unet=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device
)

In [None]:
import os
import torch
import csv

# Evaluation CSV Logging
# The run folder is named explicitly (the previous `runtime = runtime` raised
# NameError on a fresh kernel). It is the same folder that the zip cells archive.
runtime = "3.4MultipleSketchaccum10epoch18"
os.makedirs(f"/content/{runtime}", exist_ok=True)
os.makedirs("output_images", exist_ok=True)  # target directory for the generated images

# NOTE(review): the three names below were used here but never defined anywhere
# in the notebook. 50 matches the sampler call in the training cell; the two loss
# weights are placeholders — confirm the intended values.
num_inference_timesteps = 50
msew = 1.0
vggw = 1.0

# Running sums of the per-image averages.
total_val_mse = 0.0
total_val_vgg = 0.0
total_val_white = 0.0
total_val_ssim = 0.0
total_val_edge = 0.0
total_samples = 0
unet.eval()

with torch.no_grad():  # No gradients computed for validation
    for batch in val_dataloader:
        images = batch["image"].to(device)

        # One iteration per photo with its own sketch list, prompt and filename.
        for img, sketch_list, prompt, filename in zip(
            images, batch["sketch"], batch["text_prompt"], batch["filename"]
        ):
            img = img.unsqueeze(0)  # Add batch dimension for single image
            mse_list, vgg_list, white_list, ssim_list, edge_list = [], [], [], [], []

            # Evaluate every sketch individually; `iteration` numbers the saved images.
            for iteration, sketch in enumerate(sketch_list):
                # NOTE(review): this calls `trainer` (the object built for training);
                # the evaluator created in the preceding cell is bound to `eval` — confirm.
                (val_loss, _guidance_loss, ssim_loss, edge_loss,
                 whiteloss, vgg, generated_image, _edgemap) = trainer.eval(
                    images=img,
                    sketches=[sketch],  # Pass single sketch as a list
                    text_prompts=[prompt],
                    optimizer=None,  # No optimizer step during eval
                    noise_scheduler=noise_scheduler,
                    num_inference_timesteps=num_inference_timesteps,
                    classifier_guidance_strength=8,
                    sketch_guidance_strength=1.8,
                    guidance_steps_perc=0.5,
                    seed=None
                )
                mse_list.append(val_loss.item())
                vgg_list.append(vgg.item())
                white_list.append(whiteloss.item())
                ssim_list.append(ssim_loss.item())
                edge_list.append(edge_loss.item())
                generated_image.save(f'output_images/sketch: {iteration} name: {filename}.png')

            # Average each metric over this photo's sketches, then accumulate.
            total_val_mse += sum(mse_list) / len(mse_list)
            total_val_vgg += sum(vgg_list) / len(vgg_list)
            total_val_white += sum(white_list) / len(white_list)
            total_val_ssim += sum(ssim_list) / len(ssim_list)
            total_val_edge += sum(edge_list) / len(edge_list)
            total_samples += 1

# Summary statistics are computed once, after all batches have been processed.
if total_samples == 0:
    raise RuntimeError("Validation produced no samples; nothing to summarize.")

avg_val_mse = total_val_mse / total_samples
avg_val_vgg = total_val_vgg / total_samples
avg_val_white = total_val_white / total_samples
avg_val_ssim = total_val_ssim / total_samples
avg_val_edge = total_val_edge / total_samples
# Weighted total: MSE and VGG use the weights above, the white loss a fixed 0.1.
validation_total_loss = (
    avg_val_mse*msew +
    vggw * avg_val_vgg+ 0.1 * avg_val_white
)

# No epoch number is printed: this cell runs outside the training loop.
print(f"Total Validation Loss: {validation_total_loss:.4f} | "
      f"MSE: {avg_val_mse:.4f} |VGG Loss: {avg_val_vgg:.4f} | SSIM: {avg_val_ssim:.4f} | "
      f"Edge Loss: {avg_val_edge:.4f} ")


# Log epoch summary clearly
with open(f"/content/{runtime}/epoch_eval_summary.csv", "w", newline='') as epoch_summary_file:
    summary_writer = csv.writer(epoch_summary_file)
    summary_writer.writerow(["Epoch", "MSE Loss", "VGG Loss", "White Loss", "Total Loss", "SSIM", "Edge"])
    summary_writer.writerow([1,avg_val_mse, avg_val_vgg, avg_val_white,
            validation_total_loss, avg_val_ssim, avg_val_edge
        ])


print(f"🎯 Testing completed ")



In [None]:
# Archive the run folder again under a second file name, after the evaluation cell has run.
!zip -r 3.4MultipleSketchaccum10epoch18_2.zip 3.4MultipleSketchaccum10epoch18

In [None]:
from google.colab import files
# Send the second archive to the browser as a download.
files.download("3.4MultipleSketchaccum10epoch18_2.zip")

In [None]:
from google.colab import runtime
# Ask Colab to release this runtime; kernel state and local files that were not
# downloaded or written to Drive are presumably gone afterwards — run this last.
runtime.unassign()