# General Setup

# Class

In [1]:
from src.movies import get_movie_script
from src.storyboard_generator import StoryboardGenerator
import torch 
# Select the compute device: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Load the "godfather" entry; get_movie_script returns a pair that is unpacked
# into the script text and the character data (the exact structure of
# `characters` is defined in src.movies -- TODO confirm).
script, characters = get_movie_script("godfather")

In [4]:
# Display the loaded script; the cell output below shows the excerpt.
print(script)


INT DAY: DON'S OFFICE (SUMMER 1945)

				DON CORLEONE
		ACT LIKE A MAN!  By Christ in
		Heaven, is it possible you turned
		out no better than a Hollywood
		finocchio.

	Both HAGEN and JOHNNY cannot refrain from laughing.  The DON
	smiles.  SONNY enters as noiselessly as possible, still
	adjusting his clothes.

				DON CORLEONE
		All right, Hollywood...Now tell me
		about this Hollywood Pezzonovanta
		who won't let you work.

				JOHNNY
		He owns the studio.  Just a month
		ago he bought the movie rights to
		this book, a best seller.  And the
		main character is a guy just like
		me.  I wouldn't even have to act,
		just be myself.

	The DON is silent, stern.

				DON CORLEONE
		You take care of your family?

				JOHNNY
		Sure.

	He glances at SONNY, who makes himself as inconspicuous as
	he can.

				DON CORLEONE
		You look terrible.  I want you to
		eat well, to rest.  And spend time
		with your family.  And then, at the
		end of the month, this big shot
		will give you the part you

In [None]:
# Build the storyboard generator from the script and character data loaded above.
# The loader cell binds the second return value to `characters`; the name
# `characters_dict` is not assigned in any visible cell, so passing it would
# raise NameError.
generator = StoryboardGenerator(script, characters, device=device)

In [None]:
# Run the "unique" generation mode, saving into the "unique" directory (per save_dir).
generator.generate_and_save(
    generation_type="unique",
    save_dir="unique",
)

In [None]:
# Run the "prompt_weights" generation mode, saving into the "prompt_weights"
# directory (per save_dir).
generator.generate_and_save(
    generation_type="prompt_weights",
    save_dir="prompt_weights",
)

In [None]:
# Run the "modified-cfg" generation mode, saving into the "modified-cfg"
# directory (per save_dir).
generator.generate_and_save(
    generation_type="modified-cfg",
    save_dir="modified-cfg",
)

# Some Comments on Modified Classifier-Free Guidance

## Single Unconditional Pass + Multiple Conditional Passes

Let $\hat{\epsilon}_{\text{cond\_combined}}=\frac{1}{\sum_{i=1}^nw_i}\sum_{i=1}^nw_i\hat{\epsilon}_{\text{cond}_i}$,
where $n$ is the number of subprompts, $w_i$ is the weight of subprompt $i$, and $\hat{\epsilon}_{\text{cond}_i}$ comes from one conditional pass per subprompt.
Then classifier-free guidance with scale $g$ is $$\hat{\epsilon}=\hat{\epsilon}_{\text{uncond}}+g(\hat{\epsilon}_{\text{cond\_combined}}-\hat{\epsilon}_{\text{uncond}})$$
where $\hat{\epsilon}_{\text{uncond}}$ comes from a single unconditional pass at each step.

- Total UNet calls per step: $1+n$
- Each subprompt has a relative weight but they all share the same baseline unconditional pass

## Multiple Unconditional Passes (One per Subprompt)

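Here each subprompt $i$ gets its own unconditional pass $\hat{\epsilon}_{\text{uncond}_i}$ in addition to the global one, and the weighted deltas are summed without normalizing by $\sum_i w_i$: $$\hat{\epsilon}=\hat{\epsilon}_{\text{uncond}}+g\sum_{i=1}^nw_i(\hat{\epsilon}_{\text{cond}_i}-\hat{\epsilon}_{\text{uncond}_i})$$
- Total UNet calls per step: $1+2n$ (one global unconditional pass + two passes for each subprompt)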

In [None]:
# class MultiPromptPipelineApproach2(StableDiffusionPipeline):
#     """
#     Multi-Prompt CFG with MULTIPLE unconditional passes:
#       - 1 global unconditional pass per step: e_uncond
#       - For each subprompt i:
#           e_uncond_i (subprompt-specific unconditional)
#           e_cond_i    (subprompt conditional)
#       - Combine: e = e_uncond + g * sum_i[ w_i * ( e_cond_i - e_uncond_i ) ]
#     """

#     @torch.no_grad()
#     def __call__(
#         self,
#         global_uncond_embeds: torch.Tensor,
#         subprompt_pairs: list[tuple[torch.Tensor, torch.Tensor]],
#         subprompt_weights: list[float],
#         guidance_scale: float = 7.5,
#         height: int = 512,
#         width: int = 512,
#         num_inference_steps: int = 50,
#         generator: torch.Generator = None,
#         latents: torch.Tensor = None,
#         output_type: str = "pil",
#         return_dict: bool = True,
#         **kwargs
#     ):
#         """
#         Args:
#             global_uncond_embeds (Tensor): [batch, seq_len, hidden_dim] for the entire prompt's unconditional pass.
#             subprompt_pairs (list of (uncond_i, cond_i)):
#                 Each element is a tuple: (uncond_embeds_i, cond_embeds_i).
#             subprompt_weights (list[float]): Weights w_i for each subprompt i.
#         """
#         device = self._execution_device
#         batch_size = global_uncond_embeds.shape[0]
#         num_subprompts = len(subprompt_pairs)

#         if num_subprompts != len(subprompt_weights):
#             raise ValueError("subprompt_pairs and subprompt_weights must have the same length.")

#         # 1. Validate or fallback to default
#         if not height or not width:
#             height, width = self._default_height_width()

#         # 2. Scheduler timesteps
#         self.scheduler.set_timesteps(num_inference_steps, device=device)
#         timesteps = self.scheduler.timesteps

#         # 3. Prepare latents
#         if latents is None:
#             shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
#             latents = torch.randn(shape, generator=generator, device=device, dtype=global_uncond_embeds.dtype)
#             latents = latents * self.scheduler.init_noise_sigma
#         else:
#             latents = latents.to(device)

#         # 4. Diffusion loop
#         for i, t in enumerate(timesteps):
#             latent_model_input = self.scheduler.scale_model_input(latents, t)

#             # (A) Single global unconditional pass
#             e_uncond_global = self.unet(
#                 latent_model_input, t, encoder_hidden_states=global_uncond_embeds, **kwargs
#             ).sample

#             # (B) For each subprompt: unconditional + conditional
#             sub_deltas = []
#             for (uncond_i, cond_i), w in zip(subprompt_pairs, subprompt_weights):
#                 e_uncond_i = self.unet(latent_model_input, t, encoder_hidden_states=uncond_i, **kwargs).sample
#                 e_cond_i = self.unet(latent_model_input, t, encoder_hidden_states=cond_i, **kwargs).sample

#                 # Delta for subprompt i
#                 delta_i = w * (e_cond_i - e_uncond_i)
#                 sub_deltas.append(delta_i)

#             # (C) Combine sub-deltas
#             sum_deltas = sum(sub_deltas)  # sum_i w_i ( e_cond_i - e_uncond_i )

#             # (D) Final output
#             guided_out = e_uncond_global + guidance_scale * sum_deltas

#             # (E) Scheduler step
#             latents = self.scheduler.step(guided_out, t, latents, **kwargs).prev_sample

#         # 5. Decode
#         if output_type == "latent":
#             if return_dict:
#                 from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
#                 return StableDiffusionPipelineOutput(images=latents, nsfw_content_detected=None)
#             return latents

#         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
#         image = self.image_processor.postprocess(image, output_type=output_type)

#         if return_dict:
#             from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
#             return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
#         return image

# print("Loading Approach 2 pipeline...")
# pipe2 = MultiPromptPipelineApproach2.from_pretrained(
#     "runwayml/stable-diffusion-v1-5",
#     torch_dtype=torch.float16
# ).to("cuda")
# pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
# pipe2.enable_model_cpu_offload()
# pipe2.enable_attention_slicing()

#### EXAMPLE
# # Suppose we want environment and style separately
# global_uncond = encode_subprompt(pipe2, "")  # global unconditional
# env_uncond = encode_subprompt(pipe2, "")     # unconditional for environment
# env_cond   = encode_subprompt(pipe2, "ancient forest, misty atmosphere")
# style_uncond = encode_subprompt(pipe2, "")   # unconditional for style
# style_cond   = encode_subprompt(pipe2, "cinematic style, high contrast")

# # subprompt_pairs = [ (uncond_env, cond_env), (uncond_style, cond_style) ]
# subprompt_pairs_2 = [
#     (env_uncond, env_cond),
#     (style_uncond, style_cond)
# ]

# weights_2 = [1.5, 1.8]
# print("Generating image with Approach 2 (multiple unconditional passes)...")
# output2 = pipe2(
#     global_uncond_embeds=global_uncond,
#     subprompt_pairs=subprompt_pairs_2,
#     subprompt_weights=weights_2,
#     guidance_scale=7.5,
#     num_inference_steps=25
# )
# output2.images[0].save("approach2_result.png")
# print("Saved approach2_result.png")

# def generate_and_save_images_multi_prompt2(scenes, characters_dict, pipe, save_dir, device,
#                                              num_inference_steps=50, guidance_scale=7.5):
#     """
#     Generate images for each scene using Multi-Prompt Approach 2 (multiple unconditional passes)
#     and save each image to the specified directory.
    
#     Args:
#         scenes (list): List of scene objects (each scene is a dict).
#         characters_dict (dict): Dictionary of character descriptions.
#         pipe: The MultiPromptPipelineApproach2 pipeline instance.
#         save_dir (str): Directory where images will be saved.
#         device (str): Device to use (e.g., "cuda" or "cpu").
#         num_inference_steps (int, optional): Number of diffusion steps.
#         guidance_scale (float, optional): Guidance scale for classifier-free guidance.
        
#     Returns:
#         list: List of generated PIL.Image objects.
#     """
#     import os
#     import torch

#     os.makedirs(save_dir, exist_ok=True)
#     generated_images = []

#     for i, scene in enumerate(scenes):
#         # Get subprompt texts and corresponding weights for the scene.
#         subprompt_texts, subprompt_weights = scenes_to_formatted_prompts([scene], characters_dict)[0]

#         # Encode the global unconditional prompt once.
#         global_uncond_embeds = encode_subprompt(pipe, "", device=device)

#         # For each subprompt, encode a pair: (unconditional, conditional)
#         subprompt_pairs = []
#         for sp in subprompt_texts:
#             uncond_i = encode_subprompt(pipe, "", device=device)
#             cond_i = encode_subprompt(pipe, sp, device=device)
#             subprompt_pairs.append((uncond_i, cond_i))

#         print(f"Generating image for scene {i+1} using Approach 2...")
#         with torch.no_grad():
#             output = pipe(
#                 global_uncond_embeds=global_uncond_embeds,
#                 subprompt_pairs=subprompt_pairs,
#                 subprompt_weights=subprompt_weights,
#                 guidance_scale=guidance_scale,
#                 num_inference_steps=num_inference_steps
#             )
#         generated_image = output.images[0]
#         generated_images.append(generated_image)
#         image_path = os.path.join(save_dir, f"scene_{i+1}_approach2.png")
#         generated_image.save(image_path)
#         print(f"Image {i+1} saved to {image_path}")

#     return generated_images

# # Example usage:
# save_directory = "stories/multi_prompt_approach2"
# generated_images = generate_and_save_images_multi_prompt2(scenes, characters_dict, pipe2, save_directory, device)

I stopped pursuing this approach because it is extremely slow (about 20 minutes for one image) and the results are not good either.