# Run ConsiStory: Batch Generation

In [None]:
# Copyright (C) 2024 NVIDIA Corporation.  All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.

import torch
import gc
from consistory_run import load_pipeline, run_batch_generation

gpu = 0

# Release a pipeline left over from an earlier run of this notebook before
# loading a new one. Plain rebinding would keep the old pipeline referenced
# (presumably still on the GPU) until load_pipeline() returns.
if 'story_pipeline' in globals():
    del story_pipeline
    gc.collect()
    torch.cuda.empty_cache()

story_pipeline = load_pipeline(gpu)

# Each prompt is built as "<style><subject> <setting>" (see `prompts` below).
style = "A photo of "
subject = "a cute dog"
concept_token = ['dog']
settings = ["sitting in the beach",
            "standing in the snow",
            "playing in the park"]

seed = 40
mask_dropout = 0.5
same_latent = False
# Presumably the number of anchor images. The spelling "n_achors" matches the
# keyword argument that run_batch_generation accepts.
n_achors = 2

prompts = [f'{style}{subject} {setting}' for setting in settings]

# Reset the peak-memory counters so the figure reported below covers only
# this generation run. reset_max_memory_allocated is a deprecated alias of
# reset_peak_memory_stats.
torch.cuda.reset_peak_memory_stats(gpu)

# Pass the configured anchor count instead of a hard-coded literal, so that
# editing `n_achors` above actually takes effect.
images, image_all = run_batch_generation(
    story_pipeline, prompts, concept_token, seed,
    mask_dropout=mask_dropout, same_latent=same_latent, n_achors=n_achors,
)
display(image_all)

# Report maximum GPU memory usage (bytes -> GiB, i.e. 1024**3 bytes).
max_memory_used = torch.cuda.max_memory_allocated(gpu) / (1024**3)
print(f"Maximum GPU memory used: {max_memory_used:.2f} GB")

# Run ConsiStory with Cached Anchors

In [None]:
import torch
import gc
from consistory_run import load_pipeline, run_anchor_generation, run_extra_generation

gpu = 0

# Release a pipeline left over from an earlier cell (or an earlier run of this
# cell) before loading a new one. Plain rebinding would keep the old pipeline
# referenced (presumably still on the GPU) until load_pipeline() returns.
if 'story_pipeline' in globals():
    del story_pipeline
    gc.collect()
    torch.cuda.empty_cache()

story_pipeline = load_pipeline(gpu)

# Each prompt is built as "<style><subject> <setting>".
style = "A photo of "
subject = "a cute dog"
concept_token = ['dog']
# The anchor prompts are generated first. Their caches are then reused for
# every extra prompt.
anchor_settings = ["sitting in the beach", "standing in the snow"]
extra_settings = ["playing in the park", "surfing in the ocean"]

seed = 40
mask_dropout = 0.5
same_latent = False

anchor_prompts = [f'{style}{subject} {setting}' for setting in anchor_settings]
extra_prompts = [f'{style}{subject} {setting}' for setting in extra_settings]

# Keyword arguments shared by the anchor run and every extra run. Keeping them
# in one place stops the two calls from drifting apart. cache_cpu_offloading
# presumably keeps the anchor caches in host memory; confirm in consistory_run.
generation_kwargs = dict(
    seed=seed,
    mask_dropout=mask_dropout,
    same_latent=same_latent,
    cache_cpu_offloading=True,
)

# Reset the peak-memory counters so the figure reported below covers only
# this cell's generation. reset_max_memory_allocated is a deprecated alias of
# reset_peak_memory_stats.
torch.cuda.reset_peak_memory_stats(gpu)

(anchor_out_images, anchor_image_all,
 anchor_cache_first_stage, anchor_cache_second_stage) = run_anchor_generation(
    story_pipeline, anchor_prompts, concept_token, **generation_kwargs,
)

print('Anchor images:')
display(anchor_image_all)

# Generate each extra prompt on its own, with every call reusing the same
# two anchor caches.
for extra_prompt in extra_prompts:
    extra_out_images, extra_image_all = run_extra_generation(
        story_pipeline, [extra_prompt], concept_token,
        anchor_cache_first_stage, anchor_cache_second_stage,
        **generation_kwargs,
    )

    print(f'Extra prompt: {extra_prompt}')
    display(extra_image_all)

# Report maximum GPU memory usage (bytes -> GiB, i.e. 1024**3 bytes).
max_memory_used = torch.cuda.max_memory_allocated(gpu) / (1024**3)
print(f"Maximum GPU memory used: {max_memory_used:.2f} GB")