In [None]:
# Standard imports
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image

# Hugging Face Hub import

# Diffusers-specific imports
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Custom modules
from models import UNETLatentEdgePredictor, SketchSimplificationNetwork
from pipeline import SketchGuidedText2Image



In [1]:
import os
from collections import defaultdict

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Loading data 

In [8]:

# Root folder of the photo images; each object category has its own sub-folder
photo_dir = "./256x256/photo/tx_000000000000"

# Maps a category name to the list of photo identifiers found under it
photo_dataset = defaultdict(list)

for current_dir, _, filenames in os.walk(photo_dir):
    for filename in filenames:
        # Only JPEG files count as photos
        if not filename.endswith(".jpg"):
            continue

        # Components of the file's path relative to photo_dir
        rel_parts = os.path.relpath(os.path.join(current_dir, filename), photo_dir).split(os.sep)

        # A file sitting directly in photo_dir has no category folder: skip it
        if len(rel_parts) < 2:
            continue

        category = rel_parts[0]  # first path component, e.g. 'airplane'
        stem = os.path.splitext(filename)[0]  # file name without its extension

        # Identifier is '<category><os separator><stem>'
        photo_dataset[category].append(os.path.join(category, stem))



In [9]:
# Echo the category -> photo id mapping so the directory scan can be checked by eye
for category, photo_ids in photo_dataset.items():
    report = [f"  Object: {category}"]
    report += [f"    Image: {photo_id}" for photo_id in photo_ids]
    print("\n".join(report))


  Object: airplane
    Image: airplane\n02691156_10151
    Image: airplane\n02691156_10153
    Image: airplane\n02691156_10168
    Image: airplane\n02691156_10381
    Image: airplane\n02691156_10391
    Image: airplane\n02691156_10433
    Image: airplane\n02691156_10504
    Image: airplane\n02691156_10535
    Image: airplane\n02691156_10578
    Image: airplane\n02691156_10603
    Image: airplane\n02691156_10622
    Image: airplane\n02691156_10718
    Image: airplane\n02691156_1074
    Image: airplane\n02691156_10932
    Image: airplane\n02691156_11016
    Image: airplane\n02691156_11257
    Image: airplane\n02691156_11286
    Image: airplane\n02691156_11379
    Image: airplane\n02691156_1142
    Image: airplane\n02691156_11495
    Image: airplane\n02691156_1182
    Image: airplane\n02691156_12022
    Image: airplane\n02691156_12181
    Image: airplane\n02691156_12619
    Image: airplane\n02691156_1270
    Image: airplane\n02691156_14875
    Image: airplane\n02691156_14912
    Image: ai

In [10]:
# Root folder of the sketch images; each object category has its own sub-folder
sketch_dir = "./256x256/sketch/tx_000000000000"

# Maps a category name to the list of sketch identifiers found under it
sketch_dataset = defaultdict(list)

for current_dir, _, filenames in os.walk(sketch_dir):
    for filename in filenames:
        # Only PNG files count as sketches
        if not filename.endswith(".png"):
            continue

        # Components of the file's path relative to sketch_dir
        rel_parts = os.path.relpath(os.path.join(current_dir, filename), sketch_dir).split(os.sep)

        # A file sitting directly in sketch_dir has no category folder: skip it
        if len(rel_parts) < 2:
            continue

        category = rel_parts[0]  # first path component, e.g. 'airplane'

        # The stem is used unchanged: nothing is appended here, so any '-N'
        # suffix in an identifier comes straight from the file name.
        stem = os.path.splitext(filename)[0]

        sketch_dataset[category].append(os.path.join(category, stem))

In [11]:
# Echo the category -> sketch id mapping so the directory scan can be checked by eye
for category, sketch_ids in sketch_dataset.items():
    report = [f"  Object: {category}"]
    report += [f"    Sketch: {sketch_id}" for sketch_id in sketch_ids]
    print("\n".join(report))

  Object: airplane
    Sketch: airplane\n02691156_10151-1
    Sketch: airplane\n02691156_10151-2
    Sketch: airplane\n02691156_10151-3
    Sketch: airplane\n02691156_10151-4
    Sketch: airplane\n02691156_10151-5
    Sketch: airplane\n02691156_10151-6
    Sketch: airplane\n02691156_10151-7
    Sketch: airplane\n02691156_10151-8
    Sketch: airplane\n02691156_10153-1
    Sketch: airplane\n02691156_10153-2
    Sketch: airplane\n02691156_10153-3
    Sketch: airplane\n02691156_10153-4
    Sketch: airplane\n02691156_10153-5
    Sketch: airplane\n02691156_10153-6
    Sketch: airplane\n02691156_10153-7
    Sketch: airplane\n02691156_10168-1
    Sketch: airplane\n02691156_10168-2
    Sketch: airplane\n02691156_10168-3
    Sketch: airplane\n02691156_10168-4
    Sketch: airplane\n02691156_10168-5
    Sketch: airplane\n02691156_10381-1
    Sketch: airplane\n02691156_10381-2
    Sketch: airplane\n02691156_10381-3
    Sketch: airplane\n02691156_10381-4
    Sketch: airplane\n02691156_10381-5
    Sk

In [12]:
# One entry per sketch category name, in the order the categories were inserted
textprompts = list(sketch_dataset.keys())

In [13]:
# Show the collected category names (the keys of sketch_dataset) for a quick check
print(textprompts)

['airplane', 'alarm_clock', 'ant', 'ape', 'apple', 'armor', 'axe', 'banana', 'bat', 'bear', 'bee', 'beetle', 'bell', 'bench', 'bicycle', 'blimp', 'bread', 'butterfly', 'cabin', 'camel', 'candle', 'cannon', 'car_(sedan)', 'castle', 'cat', 'chair', 'chicken', 'church', 'couch', 'cow', 'crab', 'crocodilian', 'cup', 'deer', 'dog', 'dolphin', 'door', 'duck', 'elephant', 'eyeglasses', 'fan', 'fish', 'flower', 'frog', 'geyser', 'giraffe', 'guitar', 'hamburger', 'hammer', 'harp', 'hat', 'hedgehog', 'helicopter', 'hermit_crab', 'horse', 'hot-air_balloon', 'hotdog', 'hourglass', 'jack-o-lantern', 'jellyfish', 'kangaroo', 'knife', 'lion', 'lizard', 'lobster', 'motorcycle', 'mouse', 'mushroom', 'owl', 'parrot', 'pear', 'penguin', 'piano', 'pickup_truck', 'pig', 'pineapple', 'pistol', 'pizza', 'pretzel', 'rabbit', 'raccoon', 'racket', 'ray', 'rhinoceros', 'rifle', 'rocket', 'sailboat', 'saw', 'saxophone', 'scissors', 'scorpion', 'seagull', 'seal', 'sea_turtle', 'shark', 'sheep', 'shoe', 'skyscraper

# Preprocessing


# Image to vector


In [None]:
# Identifier of the Stable Diffusion v1.5 checkpoint.
# NOTE(review): presumably a Hugging Face Hub repo id meant for the pipeline
# loader; no visible cell reads this variable yet, so confirm where it is used.
stable_diffusion_1_5 = "benjamin-paine/stable-diffusion-v1-5"


In [None]:

# Load the pretrained Stable Diffusion pipeline from the configured checkpoint id.
# Without this line `stable_diffusion` is undefined and the cell fails with a
# NameError on a fresh kernel (Restart & Run All).
stable_diffusion = StableDiffusionPipeline.from_pretrained(stable_diffusion_1_5)

# Pull out the individual components and move the neural modules to the target device.
# The tokenizer is not an nn.Module, so it is used as-is.
vae = stable_diffusion.vae.to(device)
unet = stable_diffusion.unet.to(device)
tokenizer = stable_diffusion.tokenizer
text_encoder = stable_diffusion.text_encoder.to(device)

# The pretrained modules are used for inference only: switch them to eval mode
# and disable gradients on their parameters. The VAE is frozen as well, matching
# the UNet and text encoder, so passes through it do not track parameter gradients.
for pretrained_module in (vae, unet, text_encoder):
    pretrained_module.eval()
    pretrained_module.requires_grad_(False)

# training 