# Create Synthetic Dataset

In [None]:
import numpy as np

import os
import matplotlib.pyplot as plt

import pickle
import torch

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from probe_src.vis_partially_denoised_latents import generate_image, _init_models
from tqdm.auto import tqdm

## 1. Load in the Stable Diffusion Model

In [None]:
# Pretrained checkpoints for each Stable Diffusion component (presumably
# Hugging Face Hub repo ids -- resolved inside _init_models).
vae_pretrained = "CompVis/stable-diffusion-v1-4"
denoise_unet_pretrained = "CompVis/stable-diffusion-v1-4"
CLIPtokenizer_pretrained = "openai/clip-vit-large-patch14"
CLIPtext_encoder_pretrained = "openai/clip-vit-large-patch14"

# Gather the sources into keyword arguments and build all five components at once.
model_sources = dict(
    vae_pretrained=vae_pretrained,
    CLIPtokenizer_pretrained=CLIPtokenizer_pretrained,
    CLIPtext_encoder_pretrained=CLIPtext_encoder_pretrained,
    denoise_unet_pretrained=denoise_unet_pretrained,
)
vae, tokenizer, text_encoder, unet, scheduler = _init_models(**model_sources)

## 2. Load the prompts and random seeds used for synthesizing images

The prompts are sampled from a partition of the LAION 2B 5+ dataset. The seeds used for synthesizing the images are randomly sampled between 0 and 1e8. 

The LAION 2B 5+ dataset contains:
1. The URL to the captioned image
2. The caption of the image
3. The original spatial dimension of the image
4. The image's aesthetic score 

For reproducibility, we provide the prompts and random seeds used to generate this dataset in "train_split_prompts_seeds.csv" and "test_split_prompts_seeds.csv".

In [None]:
# Read the train and test prompt/seed tables (in that order), previewing each
# right after it is loaded.
split_frames = []
for csv_path in ("train_split_prompts_seeds.csv", "test_split_prompts_seeds.csv"):
    frame = pd.read_csv(csv_path)
    display(frame.head())
    split_frames.append(frame)

train_split_prompts_seeds, test_split_prompts_seeds = split_frames

## 3. Synthesize Images

In [None]:
data_path = "datasets"
dataset = "images"

dataset_path = os.path.join(data_path, dataset)
# Create the output folder if needed (no-op when it already exists).
os.makedirs(dataset_path, exist_ok=True)

# Train rows first, then test rows. The concatenated index may repeat across
# the two splits; rows are only accessed positionally below.
combo_df = pd.concat([train_split_prompts_seeds, test_split_prompts_seeds])

# Columns are read by position: 0 = prompt, 1 = seed, 2 = prompt index.
for row in tqdm(combo_df.itertuples(index=False), total=len(combo_df)):
    prompt, seed_num, prompt_ind = row[0], row[1], row[2]

    image = generate_image(prompt, seed_num,
                           net=unet, tokenizer=tokenizer, text_encoder=text_encoder,
                           scheduler=scheduler, vae=vae,
                           num_inference_steps=15,
                           guidance_scale=7.5,
                           height=512, width=512)

    # Map the decoded batch to 8-bit HWC pixels and keep the first image.
    # NOTE(review): assumes generate_image returns a (batch, channels, H, W)
    # tensor with values in [-1, 1] -- confirm against probe_src.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")[0]

    plt.imsave(os.path.join(dataset_path, f"prompt_{prompt_ind}_seed_{seed_num}.png"), image)

## 4. Install TRACER to synthesize salient object labels

### Create synthetic dataset for probing the object detection

Before running the code below, first install TRACER from [https://github.com/Karel911/TRACER](https://github.com/Karel911/TRACER).

We assume your directory structure looks like this (the command below calls `../TRACER/main.py`):

.---- ldm_depth/create_the_synthetic_dataset.ipynb  
|  
|  
|  
.---- TRACER/main.py   

Parameters used in synthesis:  
1. data_path: path prefix prepended to the dataset folder name (here `datasets/`)
2. dataset: folder name of your image dataset (here `images/`)
3. arch: EfficientNet backbone variant; see [https://github.com/Karel911/TRACER](https://github.com/Karel911/TRACER) for more details
4. img_size: the size of the input image
5. save_map: whether to save the output object map

The resulting salient object labels can be found in the directory
[ldm_depth/mask/images/](ldm_depth/mask/images/)

In [None]:
# Run TRACER salient-object inference on datasets/images/ (IPython shell escape;
# assumes TRACER is cloned at ../TRACER relative to this notebook).
# --arch 5 picks the EfficientNet backbone variant, --img_size 512 matches the
# synthesized image resolution, --save_map True writes the predicted object maps.
!python ../TRACER/main.py inference --data_path datasets/ --dataset images/ --arch 5 --img_size 512 --save_map True

## 5. Load in the MiDaS Model for depth estimation

GitHub repository of the MiDaS model: [https://github.com/isl-org/MiDaS](https://github.com/isl-org/MiDaS)

In [None]:
# MiDaS v3 - Large: highest accuracy, slowest inference speed.
model_type = "DPT_Large"

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the model from torch.hub, move it to the device and put it in eval mode
# (nn.Module.to / .eval both return the module itself).
midas = torch.hub.load("intel-isl/MiDaS", model_type)
midas = midas.to(device).eval()

# Input pre-processing: DPT variants share one transform, others use the small one.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = (
    midas_transforms.dpt_transform
    if model_type in ("DPT_Large", "DPT_Hybrid")
    else midas_transforms.small_transform
)

## 6. Create the Synthetic Labels for Depth Estimation

In [None]:
data_path = "datasets"
dataset = "images"

# Depth labels go to datasets/depth_gt/, one pickle per synthesized image.
depth_label_path = os.path.join(data_path, "depth_gt")
os.makedirs(depth_label_path, exist_ok=True)

dataset_path = os.path.join(data_path, dataset)

# Only PNGs are processed; sorting makes the processing order deterministic.
image_filenames = sorted(
    filename for filename in os.listdir(dataset_path) if filename.endswith(".png")
)

for filename in tqdm(image_filenames):
    # Keep only the first three (RGB) channels; PNGs may carry an alpha channel.
    img = plt.imread(os.path.join(dataset_path, filename))[..., :3]

    # plt.imread returns PNG pixels as floats in [0, 1]; convert to uint8.
    # Round before casting: float32 k/255 * 255 may land a hair below k, and a
    # bare astype() would truncate it to k - 1. The dtype guard keeps genuine
    # uint8 inputs (even near-black ones with max <= 1) from being rescaled.
    if np.issubdtype(img.dtype, np.floating) and img.max() <= 1:
        img = (img * 255).round().astype("uint8")

    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)

        # Resize the prediction to the source image's height and width.
        # NOTE(review): unsqueeze(1) assumes midas returns (batch, H', W').
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    # Save the predicted depth map under the image's stem, e.g. foo.png -> foo.pkl.
    stem, _ = os.path.splitext(filename)
    with open(os.path.join(depth_label_path, stem + ".pkl"), "wb") as outfile:
        pickle.dump(prediction.cpu().numpy(), outfile)