### Imports

In [None]:
import os

# The Hugging Face libraries resolve their cache directory from HF_HOME when they
# are first imported, so the variable must be set BEFORE `interpolations` /
# `helper` are imported (they presumably pull in diffusers/transformers — verify).
# Re-running this cell in a kernel that already imported them will not move the
# cache; restart the kernel instead.
HF_HOME = '../.hf_home'
os.environ['HF_HOME'] = HF_HOME

from interpolations import *
from helper import *
from pathlib import Path
from IPython import display
import json

# Imported explicitly rather than relying on the star imports above re-exporting them.
import numpy as np
import torch

RANDOM_SEED = 472
# Number of steps passed to every interpolation method in the whole-data run
# (used for slerp/cog/noisediffusion too, not only linear interpolation).
LINEAR_INTERPOLATION_STEPS = 100
SAVE_PREFIX = Path("./results")

# Seed the global RNGs; generation calls below additionally pass seed=RANDOM_SEED.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
# Prefer Apple MPS, then CUDA, then fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

MODEL_ID = "stabilityai/stable-diffusion-2-1"
diff = CustomDiffusion(MODEL_ID, device=device, width=768, height=768)

### Helper functions

In [None]:
def run_inference(embeddings, *, negative_prompt=None, seed=None):
    """Generate one image per prompt embedding with identical sampling settings.

    Args:
        embeddings: Iterable of prompt embeddings (e.g. an interpolation path);
            each item is passed to ``diff.generate`` as the prompt.
        negative_prompt: Negative prompt used for every generation. ``None``
            (default) means the notebook-level ``NEGATIVE_PROMPT``, looked up
            at call time, so the experiments file must have been loaded first.
        seed: Seed passed to every ``diff.generate`` call. ``None`` (default)
            means ``RANDOM_SEED``, so all frames are sampled with the same seed.

    Returns:
        List of the objects returned by ``diff.generate``, in input order.
    """
    if negative_prompt is None:
        negative_prompt = NEGATIVE_PROMPT
    if seed is None:
        seed = RANDOM_SEED
    return [
        diff.generate(
            emb,
            negative_prompt,
            print_steps=False,
            # Only the final image is needed per frame; skip per-step decoding.
            decode_every_step=False,
            seed=seed,
        )
        for emb in tqdm(embeddings)
    ]

### Load test data

In [None]:
# Load the experiment definitions. Each experiment is read below via its
# 'name', 'prompt_a' -> 'positive' and 'prompt_b' -> 'positive' keys; one
# negative prompt ('common_negative_base') is shared by every generation.
# Explicit UTF-8 so non-ASCII prompts decode the same regardless of locale.
with open("./experiments.json", encoding="utf-8") as f:
    data = json.load(f)
experiments = data['experiments']
NEGATIVE_PROMPT = data['common_negative_base']

### Run inference on every experiment

In [None]:
# Only run the (long) batch job when no results folder exists yet.
# NOTE: if a previous run was interrupted, ./results already exists and nothing
# is regenerated — delete the folder to rerun.
EXEC_INFERENCE = not SAVE_PREFIX.exists()


def _render_prompt(prompt, save_dir, label):
    """Generate ``prompt`` on its own and save its final image and progress page.

    Writes ``<label>_image.png`` and ``<label>_diffusion_progress.html`` into
    ``save_dir``.
    """
    result = diff.generate(
        prompt,
        NEGATIVE_PROMPT,
        print_steps=False,
        decode_every_step=True,  # per-step decodes feed the progress page
        seed=RANDOM_SEED,
    )
    result.image.save(save_dir / f"{label}_image.png")
    visualize_diffusion_progress(result, save_path=save_dir / f"{label}_diffusion_progress.html")
    display.clear_output(wait=False)


def _render_interpolations(start_embedding, end_embedding, save_dir):
    """Run each interpolation method between two embeddings; save one HTML page per method."""
    # (interpolation function, extra keyword arguments, output file name),
    # processed in this order.
    methods = (
        (interpolate, {}, "linear_interpolation.html"),
        (interpolate_slerp, {}, "slerp_interpolation.html"),
        (interpolate_cog, {}, "cog_interpolation.html"),
        (interpolate_noisediffusion, {"noise_level": 0.1}, "noisediffusion_interpolation.html"),
    )
    for interpolate_fn, extra_kwargs, file_name in methods:
        embeddings = interpolate_fn(start_embedding, end_embedding, LINEAR_INTERPOLATION_STEPS,
                                    **extra_kwargs)
        # Kept as a local (rebound per method) so earlier methods' images can
        # be released instead of lingering as notebook globals.
        results = run_inference(embeddings)
        visualize_interpolation(results, save_path=save_dir / file_name, show=False)
        display.clear_output(wait=False)


if EXEC_INFERENCE:
    SAVE_PREFIX.mkdir(parents=True)

    for experiment in experiments:
        first_prompt = experiment['prompt_a']['positive']
        second_prompt = experiment['prompt_b']['positive']

        # Experiment names come from experiments.json and are used as folder names.
        save_dir = SAVE_PREFIX / experiment['name']
        save_dir.mkdir(parents=True, exist_ok=True)

        _render_prompt(first_prompt, save_dir, "first")
        _render_prompt(second_prompt, save_dir, "second")

        start_embedding = diff.encode_text(first_prompt)
        end_embedding = diff.encode_text(second_prompt)
        _render_interpolations(start_embedding, end_embedding, save_dir)

### Single-experiment inference

#### Diffusion Process Visualization

In [None]:
# Use the first experiment's prompt pair for the single-run cells below; both
# sides share the common negative prompt loaded from experiments.json.
FIRST_PROMPT, SECOND_PROMPT = (
    experiments[0][side]['positive'] for side in ('prompt_a', 'prompt_b')
)
FIRST_NEGATIVE_PROMPT = SECOND_NEGATIVE_PROMPT = NEGATIVE_PROMPT

In [None]:
# Denoise the first prompt, decoding every step and handing each one to
# `visualize_diffusion`. callback_args presumably names the values the callback
# receives — confirm against CustomDiffusion.generate.
first_result = diff.generate(
    FIRST_PROMPT,
    FIRST_NEGATIVE_PROMPT,
    seed=RANDOM_SEED,
    decode_every_step=True,
    print_steps=False,
    callback_fn=visualize_diffusion,
    callback_args=["current_step", "total_steps", "image"],
)

In [None]:
# Render the per-step decodes stored in `first_result` (generated with
# decode_every_step=True). Kept as the cell's last expression so Jupyter shows
# whatever the call returns, if anything.
visualize_diffusion_progress(first_result)

In [None]:
# Same as the first prompt: decode every step and forward each one to
# `visualize_diffusion`. callback_args presumably names the values the callback
# receives — confirm against CustomDiffusion.generate.
second_result = diff.generate(
    SECOND_PROMPT,
    SECOND_NEGATIVE_PROMPT,
    seed=RANDOM_SEED,
    decode_every_step=True,
    print_steps=False,
    callback_fn=visualize_diffusion,
    callback_args=["current_step", "total_steps", "image"],
)

In [None]:
# Render the per-step decodes stored in `second_result` (generated with
# decode_every_step=True). Kept as the cell's last expression so Jupyter shows
# whatever the call returns, if anything.
visualize_diffusion_progress(second_result)

#### Embedding Interpolation using various methods

In [None]:
# Encode both prompts once (first, then second); the interpolation cells below
# interpolate between these two embeddings.
start_embedding, end_embedding = (diff.encode_text(p) for p in (FIRST_PROMPT, SECOND_PROMPT))

In [None]:
# Linear interpolation between the two prompt embeddings, one image per step.
interpolated_embeddings = interpolate(
    start_embedding,
    end_embedding,
    200,  # interpolation steps (the batch run above uses LINEAR_INTERPOLATION_STEPS)
)
interpolation_results = run_inference(interpolated_embeddings)
visualize_interpolation(interpolation_results)

In [None]:
# Spherical (slerp) interpolation between the same two embeddings.
interpolated_embeddings_slerp = interpolate_slerp(
    start_embedding,
    end_embedding,
    200,  # interpolation steps
)
interpolation_results_slerp = run_inference(interpolated_embeddings_slerp)
visualize_interpolation(interpolation_results_slerp)

In [None]:
# Interpolation via `interpolate_cog` between the same two embeddings.
interpolated_embeddings_cog = interpolate_cog(
    start_embedding,
    end_embedding,
    200,  # interpolation steps
)
interpolation_results_cog = run_inference(interpolated_embeddings_cog)
visualize_interpolation(interpolation_results_cog)

In [None]:
# Interpolation via `interpolate_noisediffusion` (noise_level=0.1). Reuses the
# embeddings encoded earlier, like the other interpolation cells, instead of
# running the text encoder on both prompts again.
# NOTE(review): assumes interpolate_noisediffusion does not modify its inputs
# in place — confirm if start/end_embedding are reused after this cell.
interpolated_embeddings_noisediffusion = interpolate_noisediffusion(
    start_embedding,
    end_embedding,
    200,  # interpolation steps
    noise_level=0.1,
)
interpolation_results_noisediffusion = run_inference(interpolated_embeddings_noisediffusion)
visualize_interpolation(interpolation_results_noisediffusion)