In [None]:
"""
!pip uninstall -y numpy
!pip install numpy==1.26.4
import os
os.kill(os.getpid(), 9)

import numpy as np
import torch
print("PyTorch:", torch.__version__)
print("NumPy:", np.__version__)
"""



In [None]:
!pip install -q numpy==1.26.4
!pip install -q torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu121
!pip install -q diffusers==0.31.0 transformers==4.30.0 peft==0.5.0 einops safetensors>=0.3.1 albumentations matplotlib kornia

In [None]:
# Mount Google Drive at /content/drive; later cells read checkpoints and the
# dataset from paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Create the models-checkpoints directory if it doesn't exist.
os.makedirs("models-checkpoints", exist_ok=True)

# Install (or upgrade) gdown, which is used below to fetch a file from Google Drive by id.
!pip install --upgrade gdown

# Download the U-Net Latent Edge Predictor checkpoint.
# The file is written to models-checkpoints/unet_latent_edge_predictor_checkpoint.pt.
!gdown https://drive.google.com/uc?id=1w7eimErXnnRrjY6TY8yXrmno-hvcZ94_ -O models-checkpoints/unet_latent_edge_predictor_checkpoint.pt

# Copy the Sketch Simplification Network checkpoint (model_gan.pth) from the
# mounted Google Drive; this requires the drive-mount cell to have run first.
!cp "/content/drive/MyDrive/models-checkpoints/model_gan.pth" "./models-checkpoints/model_gan.pth"

print("Environment setup complete. Checkpoints downloaded to 'models-checkpoints'.")

In [None]:
# Standard imports
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import random_split
import numpy
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler
from peft import get_peft_model, LoraConfig, PeftModel

# Custom modules

from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm



In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load Sketch Simplifier

In [None]:
# Configure and load the sketch simplification network.
sketch_simplifier = SketchSimplificationNetwork().to(device)

# map_location keeps the load working when the checkpoint was saved from a
# different device than the one selected for this session (e.g. a CUDA-saved
# checkpoint on a CPU-only runtime).
sketch_simplifier.load_state_dict(
    torch.load("models-checkpoints/model_gan.pth", map_location=device)
)

# The simplifier is only used for inference: switch to eval mode and freeze it.
sketch_simplifier.eval()
sketch_simplifier.requires_grad_(False)

# Load Stable Diffusion Model and Scheduler for Inference

In [None]:
from dotenv import load_dotenv
import os

#!huggingface-cli login
# Load Stable Diffusion Pipeline
# The Hugging Face access token is read from the HF_TOKEN environment variable
# instead of being pasted into the notebook, so it cannot leak through a saved
# or shared copy. An unset/empty variable becomes None (anonymous access).
token = os.environ.get("HF_TOKEN") or None
# `token=` is the keyword diffusers uses for authentication; it supersedes the
# older `use_auth_token=` spelling.
pipeline = StableDiffusionPipeline.from_pretrained("benjamin-paine/stable-diffusion-v1-5", token=token)

In [None]:
# Load the Stable Diffusion 1.5 pipeline in half precision and pull out the
# sub-modules that the rest of the notebook uses directly.
stable_diffusion_1_5 = "benjamin-paine/stable-diffusion-v1-5"
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=torch.float16,
    safety_checker=None,  # Skip the safety checker if it's not required
)

tokenizer = stable_diffusion.tokenizer
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
text_encoder = stable_diffusion.text_encoder.to(device)

# Put all three networks into eval mode.
for sd_module in (vae, unet, text_encoder):
    sd_module.eval()

# Only the text encoder is explicitly frozen here.
text_encoder.requires_grad_(False)

# Unet Pipeline and model

In [None]:
# Load the U-Net latent edge predictor: read the checkpoint onto the CPU first,
# then copy its weights into a model that already lives on `device`.
checkpoint = torch.load(
    "models-checkpoints/unet_latent_edge_predictor_checkpoint.pt",
    map_location=torch.device('cpu'),
)
LEP_UNET = UNETLatentEdgePredictor(9320, 4, 9).to(device)
LEP_UNET.load_state_dict(checkpoint["model_state_dict"])

# Inference only: eval mode with gradients disabled.
LEP_UNET.eval().requires_grad_(False)

# Apply LoRA Fine-tuning

In [None]:
# Apply LoRA to the U-Net: only the modules listed in `target_modules` receive
# low-rank adapters; the adapted model replaces `unet`.
# NOTE(review): "mid_block.attentions.0.to_q" / "...to_v" may not correspond to
# any module name in the Stable Diffusion 1.5 U-Net (its attention projections
# appear to sit deeper, under transformer_blocks.*.attn*), in which case they
# would be ignored silently. Confirm against the "LoRA is fine-tuning" printout
# produced by the next cell.
lora_config = LoraConfig(
    r=8,  # Rank of LoRA matrix
    lora_alpha= 16,  # Scaling factor
    target_modules=[
        "conv_out"  # Fine-tune final UNet output
       # ,"up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q",
       #"up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k",
       # "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v"
           ,"mid_block.resnets.0.conv1",
            "mid_block.resnets.0.conv2",
            "mid_block.attentions.0.to_q",
            "mid_block.attentions.0.to_v",
    ],
    lora_dropout=0.1  # Dropout applied inside the LoRA layers
)

unet = get_peft_model(unet, lora_config)

In [None]:

# List every parameter that will actually receive gradient updates.
for name, param in unet.named_parameters():
    if not param.requires_grad:
        continue
    print(f"LoRA is fine-tuning: {name} | Shape: {param.shape}")



In [None]:
# Freeze everything whose name does not contain "lora" so that only the LoRA
# layers can be updated.
for name, param in unet.named_parameters():
    if "lora" in name:
        continue
    param.requires_grad = False

# Setting up dataset


In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import random_split, DataLoader

class SketchImageDataset(Dataset):
    """Photo / sketch pairs with a generated text prompt for each photo.

    Expected layout under ``root_dir``::

        photos/<name>.jpg | <name>.png
        sketch/<name>/<files whose name contains "sketchs32strokes"
                       or "sketchs20strokes">

    Args:
        root_dir: Directory containing the ``photos`` and ``sketch`` folders.
        mapping_file: Optional path to a text file with one ``key: value``
            pair per line, mapping a photo's base name to the label used in
            its prompt. ``None`` (the default) or a missing file yields an
            empty mapping, in which case the base name itself is used.
        image_transform: Optional callable applied to the photo only.
    """

    def __init__(self, root_dir, mapping_file=None, image_transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "photos")
        self.sketch_dir = os.path.join(root_dir, "sketch")
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable from one session to the next.
        self.image_filenames = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith((".jpg", ".png"))
        )
        self.image_transform = image_transform  # transform for photos only

        # Only touch the filesystem when a mapping file was actually given:
        # os.path.exists(None) raises TypeError.
        self.mapping = {}
        if mapping_file is not None:
            if os.path.exists(mapping_file):
                self.mapping = self._load_mapping(mapping_file)
            else:
                print(f"Warning: mapping file not found: {mapping_file}")

    @staticmethod
    def _load_mapping(mapping_file):
        """Parse ``key: value`` lines; blank lines and lines without ':' are skipped."""
        mapping = {}
        with open(mapping_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or ":" not in line:
                    continue
                # Split on the first ':' only so values may contain colons.
                key, value = line.split(":", 1)
                mapping[key.strip()] = value.strip()
        return mapping

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def clean_filename(self, filename):
        """Return ``filename`` without its final extension."""
        return os.path.splitext(filename)[0]

    def __len__(self):
        """Number of photos found in the ``photos`` directory."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``image`` (photo, transformed if a transform was
            given), ``sketches`` (list of RGB PIL images), ``text_prompt``
            (str) and ``filename`` (photo file name including extension).

        Raises:
            FileNotFoundError: the photo has no sketch folder.
            ValueError: the sketch folder holds no matching sketch file.
        """
        # Load the photo.
        image_name = self.image_filenames[idx]
        image = self._load_rgb(os.path.join(self.image_dir, image_name))
        if self.image_transform:
            image = self.image_transform(image)

        # Each photo has a sketch folder named after the photo WITHOUT its extension.
        base_name = self.clean_filename(image_name)
        sketch_folder = os.path.join(self.sketch_dir, base_name)
        if not os.path.exists(sketch_folder):
            raise FileNotFoundError(f"Sketch folder not found: {sketch_folder}")

        # Keep every 32-stroke and 20-stroke sketch; sorted so that the
        # position of a sketch in the list is stable across runs.
        sketch_files = sorted(
            f for f in os.listdir(sketch_folder)
            if f.endswith((".jpg", ".png"))
            and ("sketchs32strokes" in f or "sketchs20strokes" in f)
        )
        if not sketch_files:
            raise ValueError(
                f"No sketches with 'sketchs32strokes' or 'sketchs20strokes' found for {image_name}"
            )
        sketches = [self._load_rgb(os.path.join(sketch_folder, f)) for f in sketch_files]

        # Construct the text prompt (label from the mapping, else the base name).
        text_prompt = "Lego " + self.mapping.get(base_name, base_name) +" with a empty background/surrounding"

        return {"image": image, "sketches": sketches, "text_prompt": text_prompt, "filename": image_name}




In [None]:

from torch.utils.data import random_split, DataLoader
# Photo preprocessing: resize to a square of side `image`, convert to a tensor,
# then normalize with mean 0.5 and std 0.5.
image = 512
photo_transform = transforms.Compose(
    [
        transforms.Resize((image, image)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def custom_collate_fn(batch):
    """Collate dataset samples into one batch dictionary.

    Photos are stacked into a single tensor. Prompts, filenames and sketches
    are gathered into plain Python lists; sketches stay a list (one entry per
    sample) of lists of PIL images because their count can differ per photo.
    Note the output key is "sketch" while each sample's key is "sketches".
    """
    collated = {"image": [], "sketch": [], "text_prompt": [], "filename": []}
    for sample in batch:
        collated["image"].append(sample["image"])
        collated["sketch"].append(sample["sketches"])
        collated["text_prompt"].append(sample["text_prompt"])
        collated["filename"].append(sample["filename"])
    collated["image"] = torch.stack(collated["image"])  # Assuming images are tensors.
    return collated


# Initialize the dataset.
# "mappings.txt" is resolved relative to the current working directory.
dataset = SketchImageDataset(
    root_dir="/content/drive/MyDrive/Lego_256x256/combined",
    mapping_file="mappings.txt",
    image_transform=photo_transform
)


dataset_size = len(dataset)

# Split ratios are 100% train / 0% val; test takes whatever remains.
# With these values both val_dataset and test_dataset are EMPTY, so the
# validation and test loops further down have nothing to iterate over.
train_size = int(1 * dataset_size)
val_size = int(0 * dataset_size)
test_size = dataset_size - train_size - val_size

# No generator/seed is passed, so the assignment of samples to splits uses
# the global torch RNG state.
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Noise/ Optimizer


In [None]:
# Re-assert that only LoRA parameters are trainable before the optimizer is built.
for name, param in unet.named_parameters():
    is_lora_param = "lora" in name
    if not is_lora_param:
        param.requires_grad = False

In [None]:
# Sanity check: show the PyTorch and NumPy versions active in this kernel.
import numpy as np
import torch
print(torch.__version__)
print(np.__version__)

In [None]:
# Noise scheduler (DDIM, scaled-linear betas over 1000 training timesteps).
noise_scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    clip_sample=False,
)

# Optimizer: two parameter groups with their own learning rates.
#   group 0 -> trainable parameters whose name contains "conv_out" (lr 1e-3)
#   group 1 -> every other trainable parameter                     (lr 5e-4)
conv_out_params = []
other_lora_params = []
for name, param in unet.named_parameters():
    if not param.requires_grad:
        continue
    bucket = conv_out_params if "conv_out" in name else other_lora_params
    bucket.append(param)

param_groups = [
    {'params': conv_out_params, 'lr': 1e-3},
    {'params': other_lora_params, 'lr': 5e-4},
]

# AdamW over both groups.
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01, eps=1e-8)

# Pipeline Setup

In [None]:
import importlib
import pipeline
import TrainingPipelineTest
# Reload the training module so edits made to TrainingPipelineTest on disk are
# picked up without restarting the kernel.
importlib.reload(TrainingPipelineTest)
from TrainingPipelineTest import SketchGuidedText2ImageTrainer  # Import again


# Initialize trainer
# Wires together the LoRA-wrapped U-Net, the Stable Diffusion components, the
# latent edge predictor and the sketch simplifier built in the cells above.
trainer = SketchGuidedText2ImageTrainer(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    lep_unet=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device
)


#  Training

In [None]:
import csv
# Run configuration.
accumulation_steps = 8
optimizer.zero_grad()
num_inference_timesteps=25
num_epochs = 15
seed=1000
# The run name doubles as the output directory for logs, images and checkpoints.
runtime=f"non_overlay{accumulation_steps}epoch{num_epochs}With8loramidblocks"

# CSV Files
training_log_csv = os.path.join(runtime, "training_log.csv")
image_loss_csv = os.path.join(runtime, "image_loss.csv")
epoch_log_csv = os.path.join(runtime, "epoch_avg_loss.csv")
gpu_usage_csv = os.path.join(runtime, "gpu_usage.csv")
detail_loss_csv = f"{runtime}/detail_loss.csv"
weightscsv = os.path.join(runtime, "weightchange.csv")

# Create directories
os.makedirs(f"{runtime}", exist_ok=True)

# Start every log file fresh with just its header row.
csv_headers = [
    (training_log_csv, ["Epoch", "Iteration", "Accumulated Loss", "Average Loss", "Avg VGG Loss", "Avg SSIM", "Avg Edge"]),
    (epoch_log_csv, ["Epoch", "Avg MSE", "Avg VGG", "Avg SSIM", "Avg Edge"]),
    (gpu_usage_csv, ["Epoch", "Iteration", "Avg GPU Usage (MB)"]),
    (image_loss_csv, ["Epoch", "Image Filename", "MSE", "Scaled Loss", "VGG", "White", "SSIM", "Edge"]),
    (detail_loss_csv, ["Epoch", "MSE Loss", "VGG Loss", "White Loss", "Total Loss", "SSIM", "Edge"]),
    (weightscsv, ["Epoch","VGG weight"]),
]
for log_path, header_row in csv_headers:
    with open(log_path, "w", newline="") as f:
        csv.writer(f).writerow(header_row)


In [None]:
# 2) Path for the unified split file
split_csv = os.path.join(runtime, "split.csv")

# 3) Dump filenames + split label.
# Filenames are looked up through each Subset's `indices` rather than by
# iterating the subsets: iterating calls __getitem__, which would open and
# transform every photo and all of its sketches just to read one string.
with open(split_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Filename", "Split"])
    for split_name, subset in (
        ("train", train_dataset),
        ("val", val_dataset),
        ("test", test_dataset),
    ):
        for sample_idx in subset.indices:
            writer.writerow([subset.dataset.image_filenames[sample_idx], split_name])

# 4) Build your dataloaders (batch size 1; only the training loader shuffles).
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True,  collate_fn=custom_collate_fn)
val_dataloader   = DataLoader(val_dataset,   batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
test_dataloader  = DataLoader(test_dataset,  batch_size=1, shuffle=False, collate_fn=custom_collate_fn)


In [None]:
# Best-checkpoint bookkeeping; model selection below uses `best_val_loss`.
best_loss = float('inf')
best_val_loss = float('inf')
best_runtime = None
# NOTE(review): this scheduler is created but never stepped in this cell, so
# both parameter-group learning rates stay constant for the whole run.
# (CosineAnnealingLR.step() takes no metric argument.)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# Loss weights. Only `msew` and `vggw` enter the loss; they stay constant for
# the whole run. SSIM and edge losses are logged but not optimized.
ssimw, edgew, msew, vggw = 0.5, 0.5, 1.0, 0.35
prev_ssim_loss, prev_edge_loss, prev_mse_loss, prev_vgg_loss = 0, 0, 0, 0
os.makedirs(f'{runtime}/output_images', exist_ok=True)
for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Training phase
    # ------------------------------------------------------------------
    accumulated_loss, accumulated_vgg_loss, accumulated_ssim_loss, accumulated_edge_loss, total_epoch_loss = 0.0, 0.0, 0.0, 0.0, 0.0
    gpu_usage_accumulated, accumulation_counter = 0.0, 0
    total_epoch_ssim_loss, total_epoch_edge_loss, total_epoch_vgg_loss = 0.0, 0.0, 0.0
    unet.train()

    for i, batch in enumerate(train_dataloader):
        images, sketches, text_prompts, filenames = batch["image"].to(device), batch["sketch"], batch["text_prompt"], batch["filename"]
        loss, guidance_loss, gpu_usage, ssim_loss, edge_loss, white_loss, vgg = trainer.train_step(
            images, sketches, text_prompts, optimizer, noise_scheduler,
            num_inference_timesteps=num_inference_timesteps,
            classifier_guidance_strength=6,
            sketch_guidance_strength=1.6,
            guidance_steps_perc=0.6,
            seed=seed,
        )

        # Optimized objective: weighted MSE + weighted VGG loss, scaled down so
        # that gradients summed over `accumulation_steps` batches average out.
        total_loss = loss * msew + vggw * vgg
        batch_loss = total_loss / accumulation_steps
        batch_loss.backward()

        # Accumulate loss statistics (per accumulation window and per epoch).
        accumulated_loss += loss.item()
        accumulated_vgg_loss += vgg.item()
        accumulated_ssim_loss += ssim_loss.item()
        accumulated_edge_loss += edge_loss.item()

        total_epoch_loss += loss.item()
        total_epoch_vgg_loss += vgg.item()
        total_epoch_ssim_loss += ssim_loss.item()
        total_epoch_edge_loss += edge_loss.item()
        gpu_usage_accumulated += torch.cuda.max_memory_allocated(device) / (1024**2)

        # Log per-image losses
        with open(image_loss_csv, "a", newline="") as f:
            image_loss_writer = csv.writer(f)
            for filename in filenames:
                image_loss_writer.writerow([
                    epoch + 1, filename,
                    loss.item(), batch_loss.item(),
                    vgg.item(), white_loss.item(),
                    ssim_loss.item(), edge_loss.item()
                ])

        accumulation_counter += 1
        # Perform optimizer step after accumulation steps. The counter is
        # reset on every step and every epoch, so this is the same moment as
        # (i + 1) % accumulation_steps == 0.
        if accumulation_counter == accumulation_steps:
            avg_loss = accumulated_loss / accumulation_steps
            avg_vgg_loss = accumulated_vgg_loss / accumulation_steps
            avg_ssim_loss = accumulated_ssim_loss / accumulation_steps
            avg_edge_loss = accumulated_edge_loss / accumulation_steps
            avg_gpu_usage = gpu_usage_accumulated / accumulation_steps

            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)

            # Write accumulated losses to training_log.csv
            with open(training_log_csv, "a", newline="") as f:
                csv.writer(f).writerow([
                    epoch + 1, i + 1,
                    accumulated_loss, avg_loss,
                    avg_vgg_loss, avg_ssim_loss, avg_edge_loss
                ])

            # Log average GPU usage
            with open(gpu_usage_csv, "a", newline="") as f:
                csv.writer(f).writerow([epoch + 1, i + 1, avg_gpu_usage])

            # Reset accumulated values
            accumulated_loss = accumulated_vgg_loss = accumulated_ssim_loss = accumulated_edge_loss = 0.0
            gpu_usage_accumulated = 0.0
            accumulation_counter = 0

            print(f"Epoch {epoch+1} Iteration {i+1} | Avg Loss: {avg_loss:.4f} | Avg VGG: {avg_vgg_loss} |SSIM: {avg_ssim_loss:.4f} | Edge Loss: {avg_edge_loss:.4f} | GPU Usage: {avg_gpu_usage:.2f} MB")

            optimizer.step()
            optimizer.zero_grad()

    # Flush gradients left over from an incomplete accumulation window at the
    # end of the epoch, clipped the same way as a full window.
    if accumulation_counter > 0:
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Compute final epoch averages (guarded against an empty training loader).
    num_train_batches = max(len(train_dataloader), 1)
    epoch_avg_train_loss = total_epoch_loss / num_train_batches
    epoch_avg_ssim_loss = total_epoch_ssim_loss / num_train_batches
    epoch_avg_edge_loss = total_epoch_edge_loss / num_train_batches
    epoch_avg_vgg_loss = total_epoch_vgg_loss / num_train_batches
    with open(epoch_log_csv, "a", newline="") as f:
        csv.writer(f).writerow([epoch + 1, epoch_avg_train_loss, epoch_avg_vgg_loss, epoch_avg_ssim_loss, epoch_avg_edge_loss])

    print(f"Epoch {epoch+1} | Training Avg Loss: {epoch_avg_train_loss:.4f} | vgg {epoch_avg_vgg_loss}  |avr ssim {epoch_avg_ssim_loss:.4f} | avrg edge {epoch_avg_edge_loss:.4f}")

    # ------------------------------------------------------------------
    # Validation phase
    # ------------------------------------------------------------------
    total_val_mse = 0.0
    total_val_vgg = 0.0
    total_val_white = 0.0
    total_val_ssim = 0.0
    total_val_edge = 0.0
    total_samples = 0
    unet.eval()

    with torch.no_grad():  # No gradients computed for validation
        for batch in val_dataloader:
            images, sketch_lists, text_prompts, filenames = (
                batch["image"].to(device),
                batch["sketch"],
                batch["text_prompt"],
                batch["filename"]
            )
            # Iterate over each image, its sketches, and text prompt
            for img, sketches, prompt, filename in zip(images, sketch_lists, text_prompts, filenames):
                img = img.unsqueeze(0)  # Add batch dimension for single image
                mse_list, vgg_list, white_list, ssim_list, edge_list = [], [], [], [], []

                # Ensure the per-epoch folder exists (created only when there
                # is something to validate).
                epoch_dir = f"{runtime}/output_images/{epoch+1}"
                os.makedirs(epoch_dir, exist_ok=True)

                # Evaluate all sketches individually
                for iteration, sketch in enumerate(sketches):
                    val_loss, guidance_loss, ssim_loss, edge_loss, whiteloss, vgg, generated_image, edgemap = trainer.eval(
                        images=img,
                        sketches=[sketch],  # Pass single sketch as a list
                        text_prompts=[prompt],
                        optimizer=None,  # No optimizer step during eval
                        noise_scheduler=noise_scheduler,
                        num_inference_timesteps=50,
                        classifier_guidance_strength=8,
                        sketch_guidance_strength=1.6,
                        guidance_steps_perc=0.5,
                        seed=None
                    )
                    mse_list.append(val_loss.item())
                    vgg_list.append(vgg.item())
                    white_list.append(whiteloss.item())
                    ssim_list.append(ssim_loss.item())
                    edge_list.append(edge_loss.item())

                    # Saved as "<photo filename incl. extension>_<sketch index>.png".
                    generated_image.save(f"{epoch_dir}/{filename}_{iteration}.png")

                # Average losses per image (across sketches), then accumulate.
                total_val_mse += sum(mse_list) / len(mse_list)
                total_val_vgg += sum(vgg_list) / len(vgg_list)
                total_val_white += sum(white_list) / len(white_list)
                total_val_ssim += sum(ssim_list) / len(ssim_list)
                total_val_edge += sum(edge_list) / len(edge_list)

                total_samples += 1

    if total_samples > 0:
        # Epoch averages over validated images.
        avg_val_mse = total_val_mse / total_samples
        avg_val_vgg = total_val_vgg / total_samples
        avg_val_white = total_val_white / total_samples
        avg_val_ssim = total_val_ssim / total_samples
        avg_val_edge = total_val_edge / total_samples
        # Total validation loss, weighted like the training objective.
        validation_total_loss = (
            avg_val_mse * msew +
            vggw * avg_val_vgg
        )

        print(f"Epoch {epoch+1} | Total Validation Loss: {validation_total_loss:.4f} | "
          f"MSE: {avg_val_mse:.4f} |VGG Loss: {avg_val_vgg:.4f} | SSIM: {avg_val_ssim:.4f} | "
          f"Edge Loss: {avg_val_edge:.4f} ")

        # Log validation detail loss
        with open(detail_loss_csv, "a", newline="") as f:
            csv.writer(f).writerow([
                epoch + 1,
                avg_val_mse, avg_val_vgg, avg_val_white,
                validation_total_loss, avg_val_ssim, avg_val_edge
            ])
    else:
        # The validation loader yielded nothing (e.g. a 0% validation split).
        # Dividing by total_samples would raise ZeroDivisionError, so fall
        # back to the weighted training loss as the selection criterion.
        validation_total_loss = epoch_avg_train_loss * msew + vggw * epoch_avg_vgg_loss
        print(f"Epoch {epoch+1} | No validation samples; using weighted training loss "
              f"({validation_total_loss:.4f}) for model selection")

    # Save the adapter every epoch; mark it "best" when the criterion improves.
    if validation_total_loss < best_val_loss:
        best_val_loss = validation_total_loss
        best_runtime = epoch+1
        unet.save_pretrained(f"{runtime}/best_unet_epoch{epoch+1}")
        print(f"✅ New Best Model Saved at Epoch {epoch+1} (Total Loss: {validation_total_loss:.4f})")
    else:
        unet.save_pretrained(f"{runtime}/unet_epoch{epoch+1}")


1: Steps=20, Cls=6, Sketch=1.6, Perc=0.5, Image Cosine=0.909, Edge Cosine=0.909, Avg Cosine=0.909

2: Steps=20, Cls=6, Sketch=1.6, Perc=0.6, Image Cosine=0.898, Edge Cosine=0.898, Avg Cosine=0.898

3: Steps=20, Cls=6, Sketch=1.8, Perc=0.5, Image Cosine=0.895, Edge Cosine=0.895, Avg Cosine=0.895

4: Steps=20, Cls=4, Sketch=1.8, Perc=0.5, Image Cosine=0.892, Edge Cosine=0.892, Avg Cosine=0.892

5: Steps=20, Cls=4, Sketch=1.0, Perc=0.6, Image Cosine=0.888, Edge Cosine=0.888, Avg Cosine=0.888

6: Steps=20, Cls=4, Sketch=1.8, Perc=0.4, Image Cosine=0.888, Edge Cosine=0.888, Avg Cosine=0.888

7: Steps=20, Cls=4, Sketch=1.0, Perc=0.7, Image Cosine=0.885, Edge Cosine=0.885, Avg Cosine=0.885

8: Steps=20, Cls=6, Sketch=1.0, Perc=0.7, Image Cosine=0.883, Edge Cosine=0.883, Avg Cosine=0.883

9: Steps=20, Cls=6, Sketch=1.6, Perc=0.7, Image Cosine=0.880, Edge Cosine=0.880, Avg Cosine=0.880

10: Steps=20, Cls=6, Sketch=1.0, Perc=0.6, Image Cosine=0.879, Edge Cosine=0.879, Avg Cosine=0.879


In [None]:
"""
import shutil
import os
dir_path = 'Test2epoch18andLoraMidblocks2'
if os.path.isdir(dir_path):
    shutil.rmtree(dir_path)
    print(f"Removed directory {dir_path}")
else:
    print("Directory does not exist.")
"""


# CLIPasso — CLIP Similarity and LPIPS Metrics


In [None]:
# Install the metric dependencies: OpenAI CLIP (from GitHub), lpips and pytorch-fid.
!pip install -q git+https://github.com/openai/CLIP.git lpips pytorch-fid

In [None]:
import os
import csv
import torch
import clip
import lpips
from PIL import Image
from torchvision import transforms

# —– Setup device and models —–
# NOTE: this rebinds `device` to a plain string ("cuda"/"cpu"); earlier cells
# defined it as a torch.device object.
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
lpips_fn = lpips.LPIPS(net="vgg").to(device)

# —– Paths —–
# Main experiment folders
base_dirs = [
    "non_overlay8epoch15With8loramidblocks",
]
# NOTE(review): reference photos are read from ".../overlay/photos" while the
# training dataset was built from ".../combined" — confirm this is intended.
photos_dir = "/content/drive/MyDrive/Lego_256x256/overlay/photos"

# Helpers
def load_image(path):
    """Open the image stored at ``path`` and return it converted to RGB."""
    rgb_image = Image.open(path).convert("RGB")
    return rgb_image

def prep_for_lpips(img, size=256):
    """Turn a PIL image into a single-image batch for the LPIPS model.

    The image is resized with ``Resize(size)``, centre-cropped to 224,
    converted to a tensor, given a leading batch axis, mapped with
    ``x * 2 - 1`` and moved to ``device``.
    """
    to_cropped_tensor = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    lpips_batch = to_cropped_tensor(img).unsqueeze(0) * 2 - 1
    return lpips_batch.to(device)

def prep_for_clip(img):
    """Apply CLIP's own preprocessing, add a batch axis and move to ``device``."""
    clip_batch = clip_preprocess(img).unsqueeze(0)
    return clip_batch.to(device)

# Process each experiment: for every epoch folder of generated images, compare
# each generated image with its source photo (CLIP image-embedding cosine
# similarity and LPIPS distance) and write the per-epoch averages to a CSV.
for base in base_dirs:
    # Prepare output CSV
    csv_path = os.path.join(base, f"epoch_metrics_{os.path.basename(base)}.csv")
    with open(csv_path, 'w', newline='') as out_f:
        writer = csv.writer(out_f)
        writer.writerow(["epoch", "avg_clip_similarity", "avg_lpips_distance"])

        # Read source list from split.csv (first column = photo filename).
        split_csv = os.path.join(base, 'split.csv')
        with open(split_csv, newline='') as in_f:
            reader = csv.reader(in_f)
            next(reader)  # skip header
            sources = [row[0] for row in reader if row]

        # Iterate through epochs
        for epoch in range(1, 16):
            clip_sims = []
            lpips_dists = []
            epoch_dir = os.path.join(base, 'output_images', f'{epoch}')
            if not os.path.isdir(epoch_dir):
                continue

            for src in sources:
                # original photo path
                real_path = os.path.join(photos_dir, src)
                if not os.path.isfile(real_path):
                    continue

                # Generated variants are saved by the training cell as
                # "<photo filename incl. extension>_<idx>.png". Building the
                # name from `src` keeps .png photos matching too, rather than
                # assuming every photo is a .jpg.
                gen_paths = [os.path.join(epoch_dir, f"{src}_{idx}.png") for idx in (0, 1)]
                gen_paths = [p for p in gen_paths if os.path.isfile(p)]
                if not gen_paths:
                    # Nothing was generated for this photo in this epoch, so
                    # skip loading and encoding the photo.
                    continue

                real_img = load_image(real_path)
                with torch.no_grad():
                    # The reference embedding / tensor is shared by all variants.
                    r_emb = clip_model.encode_image(prep_for_clip(real_img))
                    real_lpips_input = prep_for_lpips(real_img)

                    for gen_path in gen_paths:
                        gen_img = load_image(gen_path)

                        # CLIP similarity
                        g_emb = clip_model.encode_image(prep_for_clip(gen_img))
                        sim = torch.cosine_similarity(r_emb, g_emb).item()

                        # LPIPS distance
                        lp = lpips_fn(real_lpips_input, prep_for_lpips(gen_img)).item()

                        clip_sims.append(sim)
                        lpips_dists.append(lp)

            # Compute and write averages (0.0 when nothing was compared).
            avg_sim = sum(clip_sims) / len(clip_sims) if clip_sims else 0.0
            avg_lp = sum(lpips_dists) / len(lpips_dists) if lpips_dists else 0.0
            print(f"Epoch {epoch} | Avg CLIP Similarity: {avg_sim:.6f} | Avg LPIPS Distance: {avg_lp:.6f}")
            writer.writerow([epoch, f"{avg_sim:.6f}", f"{avg_lp:.6f}"])

    print(f"Wrote epoch metrics CSV for {os.path.basename(base)}: {csv_path}")


# Download Folder

In [None]:
# Archive the whole run directory (logs, generated images, saved adapters).
!zip -r non_overlay8epoch15With8loramidblocks.zip non_overlay8epoch15With8loramidblocks

In [None]:
from google.colab import files

# Send the archive to the browser, then copy it to Google Drive as a backup.
# NOTE(review): the copy presumably needs /content/drive/MyDrive/backups to
# exist already — confirm before relying on it.
files.download("non_overlay8epoch15With8loramidblocks.zip")
!cp non_overlay8epoch15With8loramidblocks.zip "/content/drive/MyDrive/backups/non_overlay8epoch15With8loramidblocks.zip"


In [None]:
# Rebuild a clean base pipeline and attach the best LoRA adapter saved during
# training (written by unet.save_pretrained(f"{runtime}/best_unet_epoch{epoch+1}")).
from peft import PeftModel

stable_diffusion=StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5,
    torch_dtype=torch.float16,
    safety_checker=None
)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

# Load LoRA adapters
# `best_runtime` is set by the training cell; fail early with a clear message
# instead of trying to open ".../best_unet_epochNone".
if best_runtime is None:
    raise RuntimeError("No best checkpoint recorded: run the training cell before loading adapters.")

# The adapter lives in the run DIRECTORY. A path that goes through the .zip
# archive is not a directory on disk and cannot be loaded.
best_adapter_dir = os.path.join(
    "non_overlay8epoch15With8loramidblocks", f"best_unet_epoch{best_runtime}"
)
unet = PeftModel.from_pretrained(unet, best_adapter_dir)


# Evaluate  Model


In [None]:
import importlib
import pipeline
import TrainingPipelineTest
# Reload the training module so on-disk edits are picked up without a restart.
importlib.reload(TrainingPipelineTest)
from TrainingPipelineTest import SketchGuidedText2ImageTrainer  # Import again


# Initialize trainer
# Built around the U-Net with the reloaded LoRA adapter.
# NOTE(review): binding this object to the name `eval` shadows Python's
# built-in eval() for the rest of the session.
eval = SketchGuidedText2ImageTrainer(
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    lep_unet=LEP_UNET,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    sketch_simplifier=sketch_simplifier,
    device=device
)

In [None]:
import torch
import csv

# Evaluation of the test set with CSV logging.
#
# Expected from earlier cells: `runtime`, `test_dataloader`, `device`, `unet`,
# `noise_scheduler`, `num_inference_timesteps`, the loss weights `msew` / `vggw`,
# and `eval` (the SketchGuidedText2ImageTrainer built in the preceding cell).
os.makedirs(f"/content/{runtime}", exist_ok=True)
# Generated images are saved here; create the folder up front so save() cannot
# fail on a missing directory.
os.makedirs("output_images", exist_ok=True)

# Running sums of the per-image averages (each image averages over its sketches).
total_val_mse = 0.0
total_val_vgg = 0.0
total_val_white = 0.0
total_val_ssim = 0.0
total_val_edge = 0.0
total_samples = 0
unet.eval()

with torch.no_grad():  # No gradients computed for validation
    for batch in test_dataloader:
        images = batch["image"].to(device)
        # One tuple per image: photo tensor, its list of sketches, prompt, file name.
        for img, image_sketches, prompt, filename in zip(
            images, batch["sketch"], batch["text_prompt"], batch["filename"]
        ):
            img = img.unsqueeze(0)  # Add batch dimension for single image
            mse_list, vgg_list, white_list, ssim_list, edge_list = [], [], [], [], []

            # Evaluate all sketches individually; `iteration` numbers the sketches
            # of the current image so every output file gets its own name.
            for iteration, sketch in enumerate(image_sketches):
                (val_loss, guidance_loss, ssim_loss, edge_loss,
                 whiteloss, vgg, generated_image, edgemap) = eval.eval(
                    images=img,
                    sketches=[sketch],  # Pass single sketch as a list
                    text_prompts=[prompt],
                    optimizer=None,  # No optimizer step during eval
                    noise_scheduler=noise_scheduler,
                    num_inference_timesteps=num_inference_timesteps,
                    classifier_guidance_strength=8,
                    sketch_guidance_strength=1.6,
                    guidance_steps_perc=0.5,
                    seed=None
                )
                mse_list.append(val_loss.item())
                vgg_list.append(vgg.item())
                white_list.append(whiteloss.item())
                ssim_list.append(ssim_loss.item())
                edge_list.append(edge_loss.item())
                generated_image.save(f'output_images/sketch: {iteration} name: {filename}.png')

            # Average losses per image (across sketches), then accumulate.
            total_val_mse += sum(mse_list) / len(mse_list)
            total_val_vgg += sum(vgg_list) / len(vgg_list)
            total_val_white += sum(white_list) / len(white_list)
            total_val_ssim += sum(ssim_list) / len(ssim_list)
            total_val_edge += sum(edge_list) / len(edge_list)
            total_samples += 1

# Fail with a clear message instead of a NameError/ZeroDivisionError further down.
if total_samples == 0:
    raise RuntimeError("test_dataloader produced no samples; nothing to evaluate.")

# Dataset-level averages, computed once after every batch has been processed.
avg_val_mse = total_val_mse / total_samples
avg_val_vgg = total_val_vgg / total_samples
avg_val_white = total_val_white / total_samples
avg_val_ssim = total_val_ssim / total_samples
avg_val_edge = total_val_edge / total_samples

# Weighted total: MSE and VGG use the configured weights, white loss a fixed 0.1.
validation_total_loss = (
    avg_val_mse * msew +
    vggw * avg_val_vgg + 0.1 * avg_val_white
)

print(f"Test set | Total Validation Loss: {validation_total_loss:.4f} | "
      f"MSE: {avg_val_mse:.4f} |VGG Loss: {avg_val_vgg:.4f} | SSIM: {avg_val_ssim:.4f} | "
      f"Edge Loss: {avg_val_edge:.4f} ")

# Log the summary as a single-row CSV (the "Epoch" column is fixed to 1).
with open(f"/content/{runtime}/epoch_eval_summary.csv", "w", newline='') as epoch_summary_file:
    summary_writer = csv.writer(epoch_summary_file)
    summary_writer.writerow(["Epoch", "MSE Loss", "VGG Loss", "White Loss", "Total Loss", "SSIM", "Edge"])
    summary_writer.writerow([
        1, avg_val_mse, avg_val_vgg, avg_val_white,
        validation_total_loss, avg_val_ssim, avg_val_edge
    ])


print(f"🎯 Testing completed ")



In [None]:
!zip -r overlay10epoch15With8loramidblocks_2.zip overlay10epoch15With8loramidblocks

In [None]:
from google.colab import files
# Hand the run archive to Colab's download helper; the file must exist in the working directory.
files.download("overlay10epoch15With8loramidblocks.zip")

In [None]:
# Copy the run archive to the mounted Drive as a backup.
# NOTE(review): assumes /content/drive/MyDrive/backups already exists — confirm.
!cp overlay10epoch15With8loramidblocks.zip "/content/drive/MyDrive/backups/overlay10epoch15With8loramidblocks.zip"

# Evaluating the "Best Model"


In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset

class SketchImageDataset(Dataset):
    """Dataset that pairs each photo with every sketch stored for it.

    Directory layout read by this class::

        <root_dir>/photos/<name>.jpg|.png
        <root_dir>/sketch/<name>/<any>.jpg|.png

    Args:
        root_dir: Folder containing the ``photos`` and ``sketch`` sub-folders.
        mapping_file: Optional text file with ``key: value`` lines that maps a
            photo's base name to the word used in its text prompt. Lines that
            are empty or contain no ``:`` are ignored. A missing file is
            treated the same as ``None``.
        image_transform: Optional callable applied to the RGB photo.

    Each item is a dict with the keys ``"image"``, ``"sketches"`` (list of RGB
    PIL images), ``"text_prompt"`` and ``"filename"``.
    """

    def __init__(self, root_dir, mapping_file=None, image_transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "photos")
        self.sketch_dir = os.path.join(root_dir, "sketch")
        # os.listdir returns entries in arbitrary order; sort so that dataset
        # indices are reproducible across runs and machines.
        self.image_filenames = sorted(
            f for f in os.listdir(self.image_dir)
            if f.lower().endswith((".jpg", ".png"))
        )
        self.image_transform = image_transform

        # load mapping file if exists
        self.mapping = {}
        if mapping_file is not None and os.path.exists(mapping_file):
            with open(mapping_file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line or ":" not in line:
                        continue
                    # Split on the first colon only, so values may contain ':'.
                    key, value = line.split(":", 1)
                    self.mapping[key.strip()] = value.strip()

    @staticmethod
    def _load_rgb(path):
        """Return an RGB copy of the image at ``path`` and close the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def clean_filename(self, filename):
        """Return ``filename`` without its extension."""
        return os.path.splitext(filename)[0]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load photo ``idx`` together with all of its sketches.

        Raises:
            FileNotFoundError: The photo has no folder under ``sketch/``.
            ValueError: The sketch folder holds no ``.jpg``/``.png`` file.
        """
        image_name = self.image_filenames[idx]
        # load photo
        image = self._load_rgb(os.path.join(self.image_dir, image_name))
        if self.image_transform:
            image = self.image_transform(image)

        base_name = self.clean_filename(image_name)
        sketch_folder = os.path.join(self.sketch_dir, base_name)
        if not os.path.exists(sketch_folder):
            raise FileNotFoundError(f"Sketch folder not found: {sketch_folder}")

        # Every image file in the folder counts as a sketch; sorted so the
        # sketch order (and therefore per-sketch output numbering) is stable.
        sketch_files = sorted(
            f for f in os.listdir(sketch_folder)
            if f.lower().endswith((".jpg", ".png"))
        )
        if not sketch_files:
            raise ValueError(
                f"No sketch images (.jpg/.png) found for {image_name} in {sketch_folder}"
            )

        sketches = [
            self._load_rgb(os.path.join(sketch_folder, f))
            for f in sketch_files
        ]

        # Fall back to the file's base name when the mapping has no entry.
        prompt_word = self.mapping.get(base_name, base_name)
        text_prompt = f"Lego {prompt_word} with an empty background/surrounding"

        return {
            "image": image,
            "sketches": sketches,
            "text_prompt": text_prompt,
            "filename": image_name
        }


def custom_collate_fn(batch):
    """Collate dataset items into one batch dict.

    Photos are stacked into a single tensor along a new leading dimension.
    Sketch lists, prompts and file names stay plain Python lists with one
    entry per item, because the number of sketches may differ between items.
    """
    photo_tensors, sketch_groups, prompts, names = [], [], [], []
    for sample in batch:
        photo_tensors.append(sample["image"])
        sketch_groups.append(sample["sketches"])
        prompts.append(sample["text_prompt"])
        names.append(sample["filename"])

    collated = {}
    collated["image"] = torch.stack(photo_tensors)
    collated["sketch"] = sketch_groups
    collated["text_prompt"] = prompts
    collated["filename"] = names
    return collated


# --- User parameters ---
sketchorgigin = "non_overlay"   # sketch variant; also the sub-folder name on Drive
root_dir       = f"/content/drive/MyDrive/Lego_256x256/{sketchorgigin}"
mapping_file   = "mappings.txt"

# Set the test set
chosen_folder  = "clipasso"   # where split.csv lives
split_csv_path = os.path.join(chosen_folder, "split.csv")
image_size     = 512

# Photo preprocessing: resize to a square, convert to tensor, normalise with mean/std 0.5.
photo_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# init full dataset
dataset = SketchImageDataset(
    root_dir=root_dir,
    mapping_file=mapping_file,
    image_transform=photo_transform
)

# Read split assignments (expects columns: filename, split; header case is ignored).
df_split = pd.read_csv(split_csv_path)
df_split.columns = df_split.columns.str.lower()
# The CSV may hold full paths; keep only the file name so it can be compared
# with dataset.image_filenames.
df_split['filename'] = df_split['filename'].apply(os.path.basename)

# One pass over the CSV instead of filtering the whole DataFrame per image.
# setdefault keeps the first row for a duplicated filename.
split_by_filename = {}
for csv_name, csv_split in zip(df_split['filename'], df_split['split']):
    split_by_filename.setdefault(csv_name, csv_split)

# Build index lists. Files missing from the CSV, or assigned to a split other
# than train/val/test, are left out of every subset.
split_to_indices = {'train': [], 'val': [], 'test': []}
for idx, fname in enumerate(dataset.image_filenames):
    split_name = split_by_filename.get(fname)
    if split_name in split_to_indices:
        split_to_indices[split_name].append(idx)

# create Subsets
train_dataset = Subset(dataset, split_to_indices['train'])
val_dataset   = Subset(dataset, split_to_indices['val'])
test_dataset  = Subset(dataset, split_to_indices['test'])


# One image per batch, in dataset order.
test_loader = DataLoader(
    test_dataset, batch_size=1,
    shuffle=False, collate_fn=custom_collate_fn
)

print(f"Sizes → train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")


In [None]:
from peft import PeftModel

# Hugging Face repo id of the Stable Diffusion 1.5 weights used for this run.
stable_diffusion_1_5 = "benjamin-paine/stable-diffusion-v1-5"

# Half-precision pipeline with the safety checker disabled.
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    stable_diffusion_1_5, torch_dtype=torch.float16, safety_checker=None
)

# Pull out the components the trainer takes; the three modules are moved to `device`.
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)

# LoRA adapters are not loaded in this cell; the disabled call is kept as a toggle.
# (Adapters were saved with unet.save_pretrained(f"{runtime}/best_unet_epoch{epoch+1}").)
#unet = PeftModel.from_pretrained(unet, f"unetoverlay")

In [None]:
# DDIM noise scheduler: 1000 training timesteps on a "scaled_linear" beta
# schedule from 0.00085 to 0.012, with sample clipping turned off.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    clip_sample=False,
)


In [None]:
import importlib
import pipeline
import TrainingPipelineTest

# Re-execute the module, then re-bind the class name to the reloaded definition.
importlib.reload(TrainingPipelineTest)
from TrainingPipelineTest import SketchGuidedText2ImageTrainer

# Trainer wrapper used by the evaluation loop below (called as `eval.eval(...)`).
# NOTE: the name `eval` shadows Python's built-in eval() for the rest of the session.
eval = SketchGuidedText2ImageTrainer(
    # diffusion components
    stable_diffusion_pipeline=stable_diffusion,
    unet=unet,
    vae=vae,
    # text conditioning
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    # sketch guidance
    lep_unet=LEP_UNET,
    sketch_simplifier=sketch_simplifier,
    # sampling / placement
    scheduler=noise_scheduler,
    device=device,
)


In [None]:
# Modified validation loop to print and log all metrics per sketch and per image

import os
import csv
import torch
# Evaluate every test image against each of its sketches, save the generated
# images, print the averaged metrics and append one summary row to a CSV.
#
# Expected from earlier cells: `chosen_folder`, `sketchorgigin`, `test_loader`,
# `device`, `noise_scheduler` and `eval` (the SketchGuidedText2ImageTrainer).
runtime = chosen_folder
detail_loss_csv = os.path.join(runtime, f"{chosen_folder}_eval_base.csv")

# Write the header only once; each run appends a single row below.
os.makedirs(runtime, exist_ok=True)
if not os.path.exists(detail_loss_csv):
    with open(detail_loss_csv, "w", newline="") as f:
        csv.writer(f).writerow([ "Test Set","MSE Loss", "VGG Loss", "White Loss", "Total Loss", "SSIM", "Edge"])

# Loss weights. Only msew and vggw enter the total computed in this cell.
ssimw, edgew, msew, vggw = 0.5, 0.5, 1.0, 0.35

# Output folder for generated images; created once instead of once per sketch.
epoch_dir = f"{runtime}/output_images_base/{sketchorgigin}"
os.makedirs(epoch_dir, exist_ok=True)

# Running sums of the per-image averages (each image averages over its sketches).
total_val_mse = 0.0
total_val_vgg = 0.0
total_val_white = 0.0
total_val_ssim = 0.0
total_val_edge = 0.0
total_val_total_loss = 0.0
total_samples = 0

with torch.no_grad():  # No gradients computed for validation
    for batch in test_loader:
        images = batch["image"].to(device)
        # One tuple per image: photo tensor, its list of sketches, prompt, file name.
        for img, image_sketches, prompt, filename in zip(
            images, batch["sketch"], batch["text_prompt"], batch["filename"]
        ):
            img = img.unsqueeze(0)  # Add batch dimension for single image
            mse_list, vgg_list, white_list, ssim_list, edge_list = [], [], [], [], []

            # `iteration` numbers the sketches of the current image and is
            # appended to the output file name.
            for iteration, sketch in enumerate(image_sketches):
                (val_loss, guidance_loss, ssim_loss, edge_loss,
                 whiteloss, vgg, generated_image, edgemap) = eval.eval(
                    images=img,
                    sketches=[sketch],  # Pass single sketch as a list
                    text_prompts=[prompt],
                    optimizer=None,  # No optimizer step during eval
                    noise_scheduler=noise_scheduler,
                    num_inference_timesteps=50,
                    classifier_guidance_strength=8,
                    sketch_guidance_strength=1.6,
                    guidance_steps_perc=0.5,
                    seed=None
                )
                mse_list.append(val_loss.item())
                vgg_list.append(vgg.item())
                white_list.append(whiteloss.item())
                ssim_list.append(ssim_loss.item())
                edge_list.append(edge_loss.item())
                generated_image.save(f'{epoch_dir}/{filename}{iteration}.png')

            # Average losses per image (across sketches)
            avg_mse_img = sum(mse_list) / len(mse_list)
            avg_vgg_img = sum(vgg_list) / len(vgg_list)
            avg_white_img = sum(white_list) / len(white_list)
            avg_ssim_img = sum(ssim_list) / len(ssim_list)
            avg_edge_img = sum(edge_list) / len(edge_list)
            # NOTE(review): the last term is a constant 0.35, not a weighted
            # white loss, so it only offsets the total. Kept as-is because the
            # intended weight is not visible here — confirm before comparing
            # totals across notebooks.
            avg_total_img = (
                avg_mse_img * msew +
                vggw * avg_vgg_img + 0.35
            )

            total_val_mse += avg_mse_img
            total_val_vgg += avg_vgg_img
            total_val_white += avg_white_img
            total_val_ssim += avg_ssim_img
            total_val_edge += avg_edge_img
            total_val_total_loss += avg_total_img

            total_samples += 1

# Fail with a clear message instead of a ZeroDivisionError when the split is empty.
if total_samples == 0:
    raise RuntimeError("test_loader produced no samples; check the 'test' rows in split.csv.")

# Dataset-level averages
avg_val_mse = total_val_mse / total_samples
avg_val_vgg = total_val_vgg / total_samples
avg_val_white = total_val_white / total_samples
avg_val_ssim = total_val_ssim / total_samples
avg_val_edge = total_val_edge / total_samples
avg_val_total = total_val_total_loss / total_samples

print(f"{sketchorgigin}",f" MSE: {avg_val_mse:.4f} | VGG: {avg_val_vgg:.4f} | "
      f"White: {avg_val_white:.4f} | Edge: {avg_val_edge:.4f} | SSIM: {avg_val_ssim:.4f} | "
      f"Total: {avg_val_total:.4f}")

# Log to CSV (column order matches the header written above).
with open(detail_loss_csv, "a", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([sketchorgigin, avg_val_mse, avg_val_vgg, avg_val_white, avg_val_total, avg_val_ssim, avg_val_edge])


In [None]:
# Recursively archive the non_overlay folder into non_overlay.zip.
!zip -r non_overlay.zip non_overlay

In [None]:
# Recursively archive the overlay folder into overlay.zip.
!zip -r overlay.zip overlay

In [None]:
# Recursively archive the clipasso folder into clipasso.zip.
!zip -r clipasso.zip clipasso

In [None]:
from google.colab import files
# The file name must be passed as a string; unquoted, `non_overlay.zip` is
# parsed as attribute access on an undefined name and raises NameError.
files.download("non_overlay.zip")