In [None]:
import os

# Hugging Face libraries resolve their cache directory from HF_HOME when they
# are first imported, so the variable must be set *before* `helper` is imported
# (assumes `helper` imports diffusers/transformers at module level — TODO
# confirm; setting it early is harmless either way).
HF_HOME = '../.hf_home'
os.environ['HF_HOME'] = HF_HOME

import time

import numpy as np
import plotly.express as px
import torch
from IPython import display
from tqdm import tqdm

from helper import CustomDiffusion, DiffusionResult

# --- Configuration -----------------------------------------------------------
RANDOM_SEED = 42
SLIDER_SPEED = 50                 # per-frame duration (ms) for the Plotly animations
LINEAR_INTERPOLATION_STEPS = 50   # number of blended embeddings, endpoints included
FIRST_PROMPT = "A dog"
SECOND_PROMPT = "A cat"

# Seed the global torch and NumPy RNGs so reruns are reproducible.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Device preference: Apple-silicon GPU (MPS), then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

diff = CustomDiffusion(device=device)

In [None]:
def visualize_diffusion(current_step, total_steps, image):
    """Live-preview callback passed to ``diff.generate`` as ``callback_fn``.

    Replaces the cell's output with a ``Step k/N`` line followed by the
    current intermediate image.

    Args:
        current_step: Index of the step just produced; printed as
            ``current_step + 1`` (a 1-based counter, assuming the caller
            passes 0-based indices — confirm in ``helper``).
        total_steps: Total number of steps, printed as the denominator.
        image: Anything ``IPython.display.display`` can render (presumably
            the decoded intermediate image — confirm in ``helper``).
    """
    # wait=True defers clearing until new output arrives, avoiding flicker.
    display.clear_output(wait=True)
    shown_step = current_step + 1
    print(f"Step {shown_step}/{total_steps}")
    display.display(image)

In [None]:
# Generate FIRST_PROMPT with a fixed seed, streaming each step into the cell
# through `visualize_diffusion`. `callback_args` names the values the helper
# forwards to the callback (assumed from the callback's signature — confirm in
# `helper.CustomDiffusion.generate`).
first_result = diff.generate(
    FIRST_PROMPT,
    seed=RANDOM_SEED,
    decode_every_step=True,  # presumably fills `image_list`, animated in the next cell
    print_steps=False,
    callback_fn=visualize_diffusion,
    callback_args=["current_step", "total_steps", "image"],
)

In [None]:
# Animate the intermediate frames of the first generation. Stacking the frames
# makes axis 0 the step index, which `animation_frame=0` turns into animation
# frames (assumes all frames share one shape — confirm in `helper`).
fig = px.imshow(np.array(first_result.image_list), animation_frame=0, width=800, height=600)
fig.update_layout(
    # Include the prompt so this figure is distinguishable from the second run's.
    title=f"Diffusion Process: {FIRST_PROMPT}",
    # Overrides Plotly Express's default animation buttons with a Play/Pause pair.
    updatemenus=[{
        "type": "buttons",
        "buttons": [{
            "label": "Play",
            "method": "animate",
            # SLIDER_SPEED ms per frame, resuming from the frame currently shown.
            "args": [None, {"frame": {"duration": SLIDER_SPEED, "redraw": True}, "fromcurrent": True}],
        }, {
            "label": "Pause",
            "method": "animate",
            # Animating to [None] with zero duration is Plotly's idiom for stopping playback.
            "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}],
        }],
    }],
)
fig.show()

In [None]:
# Same generation as for FIRST_PROMPT, now for SECOND_PROMPT with the same seed,
# so the two runs differ only in their prompt.
second_result = diff.generate(
    SECOND_PROMPT,
    seed=RANDOM_SEED,
    decode_every_step=True,  # presumably fills `image_list`, animated in the next cell
    print_steps=False,
    callback_fn=visualize_diffusion,
    callback_args=["current_step", "total_steps", "image"],
)

In [None]:
# Animate the intermediate frames of the second generation. Stacking the frames
# makes axis 0 the step index, which `animation_frame=0` turns into animation
# frames (assumes all frames share one shape — confirm in `helper`).
fig = px.imshow(np.array(second_result.image_list), animation_frame=0, width=800, height=600)
fig.update_layout(
    # Include the prompt so this figure is distinguishable from the first run's.
    title=f"Diffusion Process: {SECOND_PROMPT}",
    # Overrides Plotly Express's default animation buttons with a Play/Pause pair.
    updatemenus=[{
        "type": "buttons",
        "buttons": [{
            "label": "Play",
            "method": "animate",
            # SLIDER_SPEED ms per frame, resuming from the frame currently shown.
            "args": [None, {"frame": {"duration": SLIDER_SPEED, "redraw": True}, "fromcurrent": True}],
        }, {
            "label": "Pause",
            "method": "animate",
            # Animating to [None] with zero duration is Plotly's idiom for stopping playback.
            "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}],
        }],
    }],
)
fig.show()

In [None]:
# Encode both prompts once (assumes `encode_text` returns a tensor supporting
# scalar arithmetic — confirm in `helper`).
first_embedding = diff.encode_text(FIRST_PROMPT)
second_embedding = diff.encode_text(SECOND_PROMPT)

# Linear blend (1 - alpha) * first + alpha * second for alpha evenly spaced over
# [0, 1], endpoints included: the list starts at FIRST_PROMPT's embedding and
# ends at SECOND_PROMPT's. `float(...)` turns NumPy's float64 scalars into plain
# Python floats, so the embeddings keep their own dtype/device whatever torch's
# promotion rules for NumPy scalars are (MPS, one of the selectable devices,
# has no float64 support).
interpolated_embeddings = [
    (1 - alpha) * first_embedding + alpha * second_embedding
    for alpha in map(float, np.linspace(0, 1, LINEAR_INTERPOLATION_STEPS))
]
assert len(interpolated_embeddings) == LINEAR_INTERPOLATION_STEPS

In [None]:
# One full generation per interpolated embedding. The seed is fixed, so
# (assuming `seed` fully determines the initial noise) only the conditioning
# changes along the sequence.
interpolated_results = []
for embedding in tqdm(interpolated_embeddings, desc="Interpolating"):
    result = diff.generate(
        embedding,
        # Per-step printing across LINEAR_INTERPOLATION_STEPS runs would flood the
        # output and break the tqdm bar; keep it off, as in the single-prompt runs.
        print_steps=False,
        # NOTE(review): this keeps every intermediate frame of every run. It is only
        # needed if later cells animate these results; otherwise set False to save
        # time and memory.
        decode_every_step=True,
        seed=RANDOM_SEED,
    )
    interpolated_results.append(result)