# Playground

In [None]:
!gpustat

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from common_utils.notebook_utils import *

In [None]:
from torch.utils.data import DataLoader
from PIL import Image
import pytorch_lightning as pl
import os
import copy
import random
from common_utils.resize_right import resize
from datasets.cropset import CropSet
from diffusion.diffusion import Diffusion
from diffusion.conditional_diffusion import ConditionalDiffusion
from diffusion.sr_diffusion import SRDiffusion
from diffusion.diffusion_utils import save_diffusion_sample
from models.zssr import ZSSRNet
from models.unet import Unet
from common_utils.video import html_vid
from models.nextnet import NextNet
from metrics.sifid_score import get_sifid_scores
from common_utils.ben_image import imread
from models.modules import *
import torchvision.io
import imageio
import matplotlib.pyplot as plt
import cv2

In [None]:
def _grid_shape(count):
    """Return ``(rows, cols)`` for laying out ``count`` images in a near-square grid.

    Rows are ``floor(sqrt(count))`` (at least 1); columns are the ceiling of
    ``count / rows`` so every image gets a cell, even when ``count`` is not a
    perfect square.
    """
    import math
    rows = max(1, math.isqrt(count))
    cols = -(-count // rows)  # ceiling division
    return rows, cols


def show_sample(sample, figsize_mult=5):
    """Display a batch of images stored as a ``[-1, 1]`` tensor with matplotlib.

    Args:
        sample: tensor indexed as (batch, channels, height, width); values are
            clamped to ``[-1, 1]`` and mapped to uint8 ``[0, 255]``.
        figsize_mult: inches per grid cell. For a single image, the default of 5
            means "use matplotlib's default figure size".

    A single image is drawn with ``plt.imshow``. A batch is laid out in a grid of
    ``floor(sqrt(n))`` rows and as many columns as needed. Grid cells left over
    when the batch does not fill the grid are hidden. An empty batch shows nothing.
    """
    s = (sample.clamp(-1, 1) + 1) / 2
    s = (s * 255).type(torch.uint8).moveaxis(1, 3)
    s = s.cpu().numpy()

    if len(s) == 0:
        return

    if len(s) > 1:
        grid_h, grid_w = _grid_shape(len(s))
        # figsize is (width, height). squeeze=False keeps axarr 2-D for every grid
        # shape, including a single row.
        f, axarr = plt.subplots(grid_h, grid_w,
                                figsize=(figsize_mult * grid_w, figsize_mult * grid_h),
                                squeeze=False)
        for idx, ax in enumerate(axarr.flat):
            if idx < len(s):
                ax.imshow(s[idx])
            else:
                ax.axis('off')  # unused cell when len(s) does not fill the grid
    else:
        plt.figure(figsize=None if figsize_mult == 5 else (figsize_mult, figsize_mult))
        plt.imshow(s[0])

    plt.show()
    
def make_gif(samples, output_path):
    """Encode ``samples`` as a 5 fps GIF written to ``output_path``.

    Values are clamped to ``[-1, 1]`` and mapped to uint8 ``[0, 255]``. Each entry
    along dim 0 becomes one frame. Dim 1 is moved to the last position before
    encoding, so it is presumably the channel axis — TODO confirm callers pass
    (frames, channels, H, W).
    """
    frames = samples.clamp(-1, 1).add(1).div(2).mul(255)
    frames = frames.type(torch.uint8).moveaxis(1, -1).cpu().numpy()
    imageio.mimsave(output_path, [frame for frame in frames], fps=5)
    
def show_gif(samples, interval=25):
    """Render ``samples`` inline in the notebook as an HTML5 video.

    Dims 0 and 1 of ``samples`` are swapped and a leading axis of size 1 is
    added before the tensor is passed to ``html_vid``. The resulting animation
    is embedded in a one-cell HTML table. ``interval`` is forwarded unchanged to
    ``html_vid``.
    """
    video = samples.transpose(1, 0)[None]
    html5 = html_vid(video, interval=interval).to_html5_video()
    display(HTML(f"""<table><tr><td>{html5}</td></tr></table>"""))
    
def combine_images(main_img, addition_img, location):
    """Alpha-composite an RGBA overlay onto a copy of ``main_img``.

    Args:
        main_img: 3-D tensor indexed as (channels, height, width).
        addition_img: 3-D tensor with 4 channels. Channels 0-2 are blended in;
            channel 3 is the per-pixel opacity (1 = overlay replaces main_img).
        location: ``(row, col)`` of the overlay's top-left corner inside
            ``main_img``. Assumed non-negative.

    Returns:
        A new tensor; ``main_img`` is not modified. Any part of the overlay that
        would fall past the bottom/right edge of ``main_img`` is cropped off.

    Raises:
        AssertionError: if ``main_img`` is not 3-D or ``addition_img`` has no
            opacity channel.
    """
    assert len(main_img.shape) == 3
    assert addition_img.shape[0] == 4, 'Additional image must have an opacity channel'
    top, left = location
    # Crop the overlay to the part that fits inside main_img. Previously an
    # overlay crossing the edge caused a shape-mismatch error in the blend.
    height = max(0, min(addition_img.shape[-2], main_img.shape[-2] - top))
    width = max(0, min(addition_img.shape[-1], main_img.shape[-1] - left))
    overlay = addition_img[:, :height, :width]
    opacity = overlay[3]

    combined_img = main_img.clone()
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    combined_img[:, rows, cols] = combined_img[:, rows, cols] * (1 - opacity) + overlay[:3] * opacity
    return combined_img

def save_sample(sample, output_path):
    """Save the first image of a ``[-1, 1]`` batch tensor to ``output_path`` via PIL.

    Values are clamped, mapped to uint8 ``[0, 255]``, and dim 1 is moved to the
    end, so the saved array is laid out as (H, W, channels).
    """
    pixels = ((sample.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8)
    first_image = pixels.moveaxis(1, 3).cpu().numpy()[0]
    Image.fromarray(first_image).save(output_path)
    
def noise_img(img, model, t):
    """Return a forward-diffused (noised) version of ``img`` at timestep ``t``.

    Args:
        img: batch tensor whose first dimension is the batch size.
        model: a ``Diffusion`` model, noised via ``model.q_sample(img, t)``, or a
            ``ConditionalDiffusion`` model. For the latter, a continuous
            sqrt(alpha_hat) is drawn per batch element uniformly between
            ``model.sqrt_alphas_hat_prev[t - 1]`` and ``[t]`` and passed to
            ``q_sample``.
        t: timestep. NOTE(review): in the ConditionalDiffusion branch it is used
            as an index, so t == 0 would read ``[-1]`` (the last entry) —
            presumably callers pass t >= 1; confirm.

    Raises:
        TypeError: if ``model`` is neither a ``Diffusion`` nor a
            ``ConditionalDiffusion``. This is a subclass of ``Exception``, which
            the original code raised.
    """
    if isinstance(model, Diffusion):
        return model.q_sample(img, t)

    if isinstance(model, ConditionalDiffusion):
        batch_size = img.shape[0]
        continuous_sqrt_alpha_hat = torch.FloatTensor(
            np.random.uniform(model.sqrt_alphas_hat_prev[t - 1], model.sqrt_alphas_hat_prev[t], size=batch_size)
        ).to(img.device).view(batch_size, -1)
        return model.q_sample(img, continuous_sqrt_alpha_hat.view(-1, 1, 1, 1))

    raise TypeError(f'noise_img does not support models of type {type(model).__name__}')
    
def torchvid2mp4(vid, path, fps=10):
    """Encode a video tensor laid out as (C, T, H, W) to an mp4 file at ``path``.

    The tensor gets a leading batch axis, is converted with
    ``tensor2npimg(..., to_numpy=False)``, and the result is permuted with
    (1, 2, 3, 0). It is then written by ``torchvision.io.write_video`` at ``fps``
    frames per second.
    """
    batched = vid.unsqueeze(0)
    frames = tensor2npimg(batched, to_numpy=False).permute(1, 2, 3, 0)
    torchvision.io.write_video(path, frames, fps=fps)

# Method - Pyramid Generation

## Debugging - Attempt to improve level0

In [None]:
# Load only the coarsest pyramid level (level 0) of run 57 and draw 16 unconditional samples.
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/57_nextnet_all_layers-500-ts-much-faster-sampling/checkpoints'

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                          model=NextNet(), timesteps=500).to(device='cuda:0')

s0 = level0.sample((24, 33), batch_size=16)
#s0 = level0.sample((100, 100), batch_size=1)
show_sample(s0)

### Full sampling attempt

In [None]:
# Load all five levels of pyramid run 57: level 0 is an unconditional Diffusion model; levels 1-4 are
# TheirsSRDiffusion models with 6 input channels, loaded with strict=False.
# NOTE(review): TheirsSRDiffusion is not imported in the visible cells — presumably it comes from a star import; confirm.
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/57_nextnet_all_layers-500-ts-much-faster-sampling/checkpoints'
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
# Per-level scale factor: applying it (levels - 1) times yields coarsest_size_ratio.
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                                model=NextNet(), timesteps=500).to(device='cuda:0')
level1 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=1-step=34999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')
level2 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=2-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')
level3 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=3-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')
level4 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=4-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')

In [None]:
# Sample the whole pyramid: level 0 from noise at the coarsest size, then each SR level conditioned on the
# previous level's output resized to that level's size.
batch = 1
# sizes[-1] is the finest size; each earlier entry is the following one scaled by size_ratios (truncated to int).
sizes = [(186, 248)]
for i in range(1, levels):
    sizes.insert(0, (int(sizes[0][0] * size_ratios), int(sizes[0][1] * size_ratios)))

s0 = level0.sample(image_size=sizes[0], batch_size=batch)
show_sample(s0)
s0_r = resize(s0, out_shape=sizes[1])

s1 = level1.sample(s0_r)
show_sample(s1)
s1_r = resize(s1, out_shape=sizes[2])

s2 = level2.sample(s1_r)
show_sample(s2)
s2_r = resize(s2, out_shape=sizes[3])

s3 = level3.sample(s2_r)
show_sample(s3)
s3_r = resize(s3, out_shape=sizes[4])

s4 = level4.sample(s3_r)
show_sample(s4, 5)

In [None]:
# SIFID of the finest-level sample s4 against the original training image.
from metrics.sifid_score import get_sifid_scores
from common_utils.ben_image import imread

# NOTE(review): `device` is not defined in the visible cells — presumably it comes from a star import; confirm.
orig_image = imread(f'images/{image_name}').to(device=device)
# Samples are mapped from [-1, 1] to [0, 1] and given an extra dim 1 before scoring.
normalized_samples = (s4.clamp(-1, 1).unsqueeze(1) + 1) / 2
scores =  get_sifid_scores(orig_image, normalized_samples)
print('SIFID ', scores.mean(), scores)

## Debugging - DDIM sampling attempt

In [None]:
# Pyramid sizes for balloons plus level 0 of run 60 (trained with a DDIM reconstruction loss, judging by its name).
batch = 1
levels = 5
image_name = 'balloons.png'
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
sizes = [(186, 248)]
for i in range(1, levels):
    sizes.insert(0, (int(sizes[0][0] * size_ratios), int(sizes[0][1] * size_ratios)))
    
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/60_nextnet_all_layers-500-ts-with-1.0-ddim-recon-loss-recontrain/checkpoints'
level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), model=NextNet(), timesteps=500).to(device='cuda:0')

In [None]:
# First interpolation endpoint: a fixed initial noise x_T at the coarsest size, denoised with DDIM.
with torch.no_grad():
    x_T = torch.randn((batch, 3, sizes[0][0], sizes[0][1]), device='cuda:0')
    #s0 = level0.sample_ddim(x_T=x_T, sampling_step_size=100)
    #show_sample(s0)
    #s0 = level0.sample_ddim(x_T=x_T, sampling_step_size=50)
    #show_sample(s0)
    s0 = level0.sample_ddim(x_T=x_T, sampling_step_size=10)
    show_sample(s0)

In [None]:
# Second interpolation endpoint: another initial noise x_T2, denoised with DDIM.
with torch.no_grad():
    x_T2 = torch.randn((batch, 3, sizes[0][0], sizes[0][1]), device='cuda:0')
    #s0 = level0.sample_ddim(x_T=x_T2, sampling_step_size=100)
    #show_sample(s0)
    #s0 = level0.sample_ddim(x_T=x_T2, sampling_step_size=50)
    #show_sample(s0)
    s0 = level0.sample_ddim(x_T=x_T2, sampling_step_size=10)
    show_sample(s0)

In [None]:
# Manual linear interpolation
# Blend x_T -> x_T2 in interp_size steps (both endpoints included), DDIM-sample each blend,
# and stack the CPU copies along a new leading dim.

interp_size = 100
samples = []

with torch.no_grad():
    for i in range(interp_size):
        x_T_interp = x_T * (1 - (i / (interp_size - 1))) + x_T2 * (i / (interp_size - 1)) 
        s0 = level0.sample_ddim(x_T=x_T_interp, sampling_step_size=10)
        samples.append(s0.unsqueeze(0).cpu())

samples = torch.cat(samples, dim=0)

In [None]:
make_gif(samples, '/home/yanivni/data/tmp/ddim_interp_trained_on_recon.gif')  # write the interpolation frames as a GIF

In [None]:
# Repeat the above 100 times, generating 100 GIFs, to see whether what we observed is accidental or a real phenomenon.
# Each run saves the two DDIM endpoints as PNGs plus a GIF of the linear noise interpolation between them.
gif_count = 100
batch = 1
interp_size = 100
out_dir = '/home/yanivni/data/tmp/ddim_interp_trained_on_recon'
# save_sample / make_gif write directly into out_dir, so create it up front instead of failing on the first save.
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for gif_index in range(gif_count):
        x_T = torch.randn((batch, 3, sizes[0][0], sizes[0][1]), device='cuda:0')
        x_T2 = torch.randn((batch, 3, sizes[0][0], sizes[0][1]), device='cuda:0')
        s0 = level0.sample_ddim(x_T=x_T, sampling_step_size=10)
        s0_2 = level0.sample_ddim(x_T=x_T2, sampling_step_size=10)

        save_sample(s0, output_path=fr'{out_dir}/{gif_index}_a1.png')
        save_sample(s0_2, output_path=fr'{out_dir}/{gif_index}_a2.png')

        # Interpolate linearly between the two initial noises (both endpoints included).
        samples = []
        for i in range(interp_size):
            x_T_interp = x_T * (1 - (i / (interp_size - 1))) + x_T2 * (i / (interp_size - 1))
            s_interp = level0.sample_ddim(x_T=x_T_interp, sampling_step_size=10)
            samples.append(s_interp.unsqueeze(0).cpu())
        samples = torch.cat(samples, dim=0)

        make_gif(samples, f'{out_dir}/{gif_index}_gif.gif')

In [None]:
# DDIM spherical linear interpolation
# NOTE(review): sample_interpolate is a model method not visible here; judging by the next cell, samples[i] is one
# frame sequence per batch element — confirm.
samples = level0.sample_interpolate(image_size=sizes[0], batch_size=2, interp_seq_len=100)

In [None]:
# Write one GIF per entry along dim 0 of `samples`.
for i in range(samples.shape[0]):
    make_gif(samples[i], f'/home/yanivni/data/tmp/ddim_interp_trained_on_recon_spherical/{i}.gif')

## Experiment - Deterministic noising and denoising process

In [None]:
# Noise the level-0 balloons image and a copy rolled by 40 pixels along dim 3, using the *same* noise at t=249,
# then DDIM-denoise both and compare. Everything here runs on the CPU (the .to() calls are commented out).
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/58_nextnet_all_layers-500-ts-with-0.8-ddim-recon-loss-0.2-recontrain/checkpoints'
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), model=NextNet(), timesteps=500)#.to(device='cuda:0')

img_level0 = imread(f'./images/balloons_level0.png')#.to(device=device)
img_level0_normalized = (img_level0 * 2) - 1  # map to [-1, 1]
img_level0_rotated = img_level0_normalized.roll(40, dims=[3])  # circular shift by 40 along dim 3
show_sample(img_level0_normalized)
show_sample(img_level0_rotated)

t = torch.full((1, ), 249, dtype=torch.int64)#, device=device)
noise = torch.randn_like(img_level0)  # shared noise so the two noised images differ only by the shift
x_noisy_1 = level0.q_sample(x_start=img_level0_normalized, t=t, noise=noise)
x_noisy_2 = level0.q_sample(x_start=img_level0_rotated, t=t, noise=noise)

with torch.no_grad():
    s0_1 = level0.sample_ddim(x_T=x_noisy_1, sampling_step_size=10)
    s0_2 = level0.sample_ddim(x_T=x_noisy_2, sampling_step_size=10)
    show_sample(x_noisy_1)
    show_sample(s0_1)
    show_sample(s0_2)

In [None]:
# Interpolate between the two noised images from the previous cell and save the sequence as a GIF.
# `sizes` comes from an earlier cell.
samples = level0.sample_interpolate(image_size=sizes[0], batch_size=1, interp_seq_len=100, x_T1=x_noisy_1, x_T2=x_noisy_2)
make_gif(samples[0], f'/home/yanivni/data/tmp/ddim_interp_trained_on_recon_spherical/move_balloons.gif')

## Experiment - try denoising a noised version of image A from model trained on B

In [None]:
# Model trained on balloons, sampled on lightning
# Noise the lightning image to timestep t and let the balloons model denoise it from there (custom_initial_img/custom_timesteps).

path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/60_nextnet_all_layers-500-ts-with-1.0-ddim-recon-loss-recontrain/checkpoints'
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
sizes = [(186, 248)]
for i in range(1, levels):
    sizes.insert(0, (int(sizes[0][0] * size_ratios), int(sizes[0][1] * size_ratios)))

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), model=NextNet(), timesteps=500)#.to(device='cuda:0')

img_level0 = imread(f'./images/lightning1_level0.png')
img_level0_normalized = (img_level0 * 2) - 1
show_sample(img_level0_normalized)
t = 150
t_tensor = torch.full((1, ), t, dtype=torch.int64)
x_noisy_1 = level0.q_sample(x_start=img_level0_normalized, t=t_tensor)
with torch.no_grad():
    s0_1 = level0.sample(image_size=sizes[0], batch_size=1, custom_initial_img=x_noisy_1, custom_timesteps=t)
    show_sample(x_noisy_1)
    show_sample(s0_1)

In [None]:
# Model trained on starry night, sampled on balloons (this cell only shows some starry night samples; the next cell is the style transfer attempt)

path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/starry_night.png/version_1/checkpoints'
image_name = 'starry_night.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
sizes = [(201, 256)]
for i in range(1, levels):
    sizes.insert(0, (int(sizes[0][0] * size_ratios), int(sizes[0][1] * size_ratios)))

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), model=NextNet(), timesteps=500)
s0 = level0.sample(image_size=sizes[0], batch_size=4)
show_sample(s0)

In [None]:
# Noise the level-0 balloons image lightly (t=40) and denoise it with the starry-night level-0 model from the previous cell.
img_level0 = imread(f'./images/balloons_level0.png')
img_level0_normalized = (img_level0 * 2) - 1
show_sample(img_level0_normalized)
t = 40
t_tensor = torch.full((1, ), t, dtype=torch.int64)
x_noisy_1 = level0.q_sample(x_start=img_level0_normalized, t=t_tensor)
with torch.no_grad():
    s0_1 = level0.sample(image_size=sizes[0], batch_size=1, custom_initial_img=x_noisy_1, custom_timesteps=t)
    show_sample(x_noisy_1)
    show_sample(s0_1)

## Experiment - try denoising with a linear combination of weights from two models

In [None]:
# Load the level-0 models trained on starry night and on balloons; their weights are mixed below.
starry_night = Diffusion.load_from_checkpoint(r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/starry_night.png/version_1/checkpoints/level=0-step=29999.ckpt', model=NextNet(), timesteps=500)
balloons = Diffusion.load_from_checkpoint(r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/58_nextnet_all_layers-500-ts-with-0.8-ddim-recon-loss-0.2-recontrain/checkpoints/level=0-step=29999.ckpt', model=NextNet(), timesteps=500)

coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))  # `levels` comes from an earlier cell
sizes = [(200, 256)]
for i in range(1, levels):
    sizes.insert(0, (int(sizes[0][0] * size_ratios), int(sizes[0][1] * size_ratios)))
    
def interpolate_model(model_a, model_b, weight=0.5):
    """Return a deep copy of ``model_a`` whose parameters blend ``model_a`` and ``model_b``.

    Each parameter becomes ``(1 - weight) * a + weight * b``. The default
    ``weight=0.5`` is the plain average used originally. Only entries reported by
    ``named_parameters()`` are blended; buffers and other state-dict entries keep
    ``model_a``'s values. Both models must share parameter names and shapes.
    Neither input model is modified.
    """
    interpolated_model = copy.deepcopy(model_a)
    interpolated_state_dict = interpolated_model.state_dict()
    # Build each source state dict once instead of once per parameter.
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    for name, _ in interpolated_model.named_parameters():
        interpolated_state_dict[name] = state_a[name] * (1 - weight) + state_b[name] * weight

    interpolated_model.load_state_dict(interpolated_state_dict)
    return interpolated_model

def half_layers(model_a, model_b, layer_indices=range(0, 1)):
    """Return a deep copy of ``model_a`` with some ``model.layers.<i>`` parameters taken from ``model_b``.

    Despite the name, by default only layer 0 is swapped in from ``model_b``,
    matching the original behavior. Pass ``layer_indices`` to choose other layers.
    Parameters are matched by the prefix ``model.layers.<i>.``; the trailing dot
    keeps e.g. index 1 from also matching ``model.layers.10``. Buffers are not
    swapped. Neither input model is modified.
    """
    new_model = copy.deepcopy(model_a)
    new_state_dict = new_model.state_dict()
    state_b = model_b.state_dict()  # built once instead of once per swapped parameter
    prefixes = tuple(f'model.layers.{i}.' for i in layer_indices)

    for name, _ in new_model.named_parameters():
        if name.startswith(prefixes):
            new_state_dict[name] = state_b[name]

    new_model.load_state_dict(new_state_dict)
    return new_model
    
# Build both averaged models and both layer-0-swapped models; only the swapped ones are sampled below.
interpolated_model = interpolate_model(balloons, starry_night)
interpolated_model2 = interpolate_model(starry_night, balloons)
half_model1 = half_layers(balloons, starry_night)
half_model2 = half_layers(starry_night, balloons)
    
s0 = half_model1.sample(image_size=sizes[0], batch_size=4)
show_sample(s0)
s0 = half_model2.sample(image_size=sizes[0], batch_size=4)
show_sample(s0)

## Attempt - improving SR layer

In [None]:
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/59_nextnet_all_layers-500-ts-with-1.0-ddim-recon-loss-full-recontrain/checkpoints'
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
batch = 1
sizes = [(186, 248)]
# Image pyramid in [-1, 1]: start from the full-resolution image and prepend successively downscaled copies.
imgs = [(imread(f'./images/{image_name}').to(device=device) * 2) - 1]

for i in range(1, levels):
    imgs.insert(0, resize(imgs[0], scale_factors=size_ratios))
    sizes.insert(0, imgs[0].shape[-2:])

show_factor = 10 # Just some constant for the show_sample func

# SIFID of sample(s) `s` against pyramid level `i` (both mapped from [-1, 1] to [0, 1]).
# Defined here because this cell uses it; previously it was only defined in a later cell.
sifid = lambda i,s: get_sifid_scores((imgs[i].clamp(-1, 1) + 1) / 2, (s.clamp(-1, 1).unsqueeze(1) + 1) / 2)

# Levels 1-2 are given their reconstruction target (recon_image) and its low-res input (recon_image_lr).
# Level 4 is loaded from run 57's checkpoints (depth 8).
level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                                model=NextNet(), timesteps=500).to(device='cuda:0')
level1 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=1-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=9), recon_loss_factor=1, recon_image=imgs[1], recon_image_lr=resize(imgs[0], out_shape=imgs[1].shape), timesteps=500).to(device='cuda:0')
level2 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=2-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=9), recon_loss_factor=1, recon_image=imgs[2], recon_image_lr=resize(imgs[1], out_shape=imgs[2].shape), timesteps=500).to(device='cuda:0')
level3 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=3-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=9), timesteps=500, strict=False).to(device='cuda:0')
level4 = TheirsSRDiffusion.load_from_checkpoint(os.path.join('/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/57_nextnet_all_layers-500-ts-much-faster-sampling/checkpoints', 'level=4-step=29999.ckpt'), 
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')


with torch.no_grad():
    # Start from the real coarsest image rather than a level-0 sample.
    #s0 = level0.sample(image_size=sizes[0], batch_size=batch)
    s0 = imgs[0]
    s0_r = resize(s0, out_shape=sizes[1])
    show_sample(torch.cat([s0, imgs[0]]), show_factor)
    print(sifid(0, s0))

    # At each level compare DDIM (from the stored recon noise where available) with DDPM sampling.
    s1_ddpm = level1.sample(resize(s0, out_shape=sizes[1]), image_size=sizes[1])
    #s1 = level1.sample_ddim(lr=s0_r, x_T=torch.randn_like(s0_r), sampling_step_size=5)
    s1 = level1.sample_ddim(lr=s0_r, x_T=level1.recon_noise, sampling_step_size=10)
    s1_r = resize(s1, out_shape=sizes[2])
    show_sample(torch.cat([s1, s1_ddpm, imgs[1]]), show_factor)
    print(sifid(1, s1), sifid(1, s1_ddpm))

    s2_ddpm = level2.sample(resize(s1_ddpm, out_shape=sizes[2]), image_size=sizes[2])
    s2 = level2.sample_ddim(lr=s1_r, x_T=level2.recon_noise, sampling_step_size=10)
    s2_r = resize(s2, out_shape=sizes[3])
    show_sample(torch.cat([s2, s2_ddpm, imgs[2]]), show_factor)
    print(sifid(2, s2), sifid(2, s2_ddpm))

    s3_ddpm = level3.sample(resize(s2_ddpm, out_shape=sizes[3]), image_size=sizes[3])
    s3 = level3.sample_ddim(lr=s2_r, x_T=torch.randn_like(s2_r), sampling_step_size=10)
    s3_r = resize(s3, out_shape=sizes[4])
    show_sample(torch.cat([s3, s3_ddpm, imgs[3]]), show_factor)
    print(sifid(3, s3), sifid(3, s3_ddpm))

    s4_ddpm = level4.sample(resize(s3_ddpm, out_shape=sizes[4]), image_size=sizes[4])
    s4 = level4.sample_ddim(lr=s3_r, x_T=torch.randn_like(s3_r), sampling_step_size=10)
    show_sample(torch.cat([s4, s4_ddpm, imgs[4]]), show_factor)
    print(sifid(4, s4), sifid(4, s4_ddpm))

In [None]:
# Sanity check for recon-loss trained model. Using the recon_noise as initial sampling x_T, the results should be fairly similar to original image.
# Each SR level is run twice: DDPM sampling chained from the previous DDPM output, and DDIM from level<i>.recon_noise
# chained from the previous DDIM output. Both are printed next to their SIFID against the matching pyramid level.
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/version_1/checkpoints'

show_factor = 15 # Just some constant for the show_sample func
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
batch = 1

sizes = [(186, 248)]
imgs = [(imread(f'./images/{image_name}').to(device=device) * 2) - 1]

for i in range(1, levels):
    imgs.insert(0, resize(imgs[0], scale_factors=size_ratios))
    sizes.insert(0, imgs[0].shape[-2:])
    
# SIFID of sample(s) `s` against pyramid level `i`; both are mapped from [-1, 1] to [0, 1].
sifid = lambda i,s: get_sifid_scores((imgs[i].clamp(-1, 1) + 1) / 2, (s.clamp(-1, 1).unsqueeze(1) + 1) / 2)

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                                model=NextNet(), timesteps=500).to(device='cuda:0')

level1 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=1-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=9), timesteps=500, recon_loss_factor=1, recon_image=imgs[1], recon_image_lr=resize(imgs[0], out_shape=imgs[1].shape)
                                               ).to(device='cuda:0')

level2 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=2-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=9), timesteps=500, recon_loss_factor=1, recon_image=imgs[2], recon_image_lr=resize(imgs[1], out_shape=imgs[2].shape)
                                               ).to(device='cuda:0')

level3 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=3-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=9), timesteps=500, recon_loss_factor=1, recon_image=imgs[3], recon_image_lr=resize(imgs[2], out_shape=imgs[3].shape)
                                               ).to(device='cuda:0')

level4 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=4-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=9), timesteps=500, recon_loss_factor=1, recon_image=imgs[4], recon_image_lr=resize(imgs[3], out_shape=imgs[4].shape)
                                               ).to(device='cuda:0')

with torch.no_grad():
    # The next line generates the approximation to imgs[0] 
    # s0 = level0.sample_ddim(x_T=level0.recon_noise, sampling_step_size=50)
    # Each of the next two lines generates a random level0 sample instead of approximating imgs[0]
    #s0 = level0.sample(image_size=sizes[0], batch_size=1)
    s0 = level0.sample_ddim(x_T=torch.randn_like(level0.recon_noise), sampling_step_size=50)
    s0_r = resize(s0, out_shape=sizes[1])
    show_sample(torch.cat([s0, imgs[0]]), show_factor)
    print(sifid(0, s0))
    
    s1_ddpm = level1.sample(resize(s0, out_shape=sizes[1]))
    s1 = level1.sample_ddim(lr=s0_r, x_T=level1.recon_noise, sampling_step_size=10)
    # s1 = level1.sample_ddim(lr=s0_r, x_T=torch.randn_like(level1.recon_noise), sampling_step_size=50)
    s1_r = resize(s1, out_shape=sizes[2])
    show_sample(torch.cat([s1, s1_ddpm, imgs[1]]), show_factor)
    print(sifid(1, s1), sifid(1, s1_ddpm))
    
    s2_ddpm = level2.sample(resize(s1_ddpm, out_shape=sizes[2]))
    s2 = level2.sample_ddim(lr=s1_r, x_T=level2.recon_noise, sampling_step_size=10)
    # s2 = level2.sample_ddim(lr=s1_r, x_T=torch.randn_like(level2.recon_noise), sampling_step_size=50)
    s2_r = resize(s2, out_shape=sizes[3])
    show_sample(torch.cat([s2, s2_ddpm, imgs[2]]), show_factor)
    print(sifid(2, s2), sifid(2, s2_ddpm))
    
    s3_ddpm = level3.sample(resize(s2_ddpm, out_shape=sizes[3]))
    s3 = level3.sample_ddim(lr=s2_r, x_T=level3.recon_noise, sampling_step_size=10)
    # s3 = level3.sample_ddim(lr=s2_r, x_T=torch.randn_like(level3.recon_noise), sampling_step_size=50)
    s3_r = resize(s3, out_shape=sizes[4])
    show_sample(torch.cat([s3, s3_ddpm, imgs[3]]), show_factor)
    print(sifid(3, s3), sifid(3, s3_ddpm))
    
    s4_ddpm = level4.sample(resize(s3_ddpm, out_shape=sizes[4]))
    s4 = level4.sample_ddim(lr=s3_r, x_T=level4.recon_noise, sampling_step_size=10)
    # s4 = level4.sample_ddim(lr=s3_r, x_T=torch.randn_like(level4.recon_noise), sampling_step_size=50)
    show_sample(torch.cat([s4, s4_ddpm, imgs[4]]), show_factor)
    print(sifid(4, s4), sifid(4, s4_ddpm))

In [None]:
# Compare the previous cell's results with an older model (run 57) trained without recon loss.
# s0, s1, s1_ddpm, ..., s4_ddpm are the recon-model outputs from the previous cell.
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/57_nextnet_all_layers-500-ts-much-faster-sampling/checkpoints'

show_factor = 15 # Just some constant for the show_sample func
image_name = 'balloons.png'
levels = 5
coarsest_size_ratio = 0.135
size_ratios = coarsest_size_ratio ** (1.0 / (levels - 1))
batch = 1

sizes = [(186, 248)]
imgs = [(imread(f'./images/{image_name}').to(device=device) * 2) - 1]

for i in range(1, levels):
    imgs.insert(0, resize(imgs[0], scale_factors=size_ratios))
    sizes.insert(0, imgs[0].shape[-2:])
    
sifid = lambda i,s: get_sifid_scores((imgs[i].clamp(-1, 1) + 1) / 2, (s.clamp(-1, 1).unsqueeze(1) + 1) / 2)

level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                                model=NextNet(), timesteps=500).to(device='cuda:0')

level1 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=1-step=34999.ckpt'),
                                          model=NextNet(in_channels=6, depth=8), timesteps=500).to(device='cuda:0')

level2 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=2-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=8), timesteps=500).to(device='cuda:0')

level3 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=3-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=8), timesteps=500).to(device='cuda:0')

# Level 4 must also come from run 57. Without this reload, `level4` would still be the recon-loss model from the
# previous cell, and the "normal" level-4 samples would not come from the no-recon model at all.
level4 = TheirsSRDiffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=4-step=29999.ckpt'),
                                          model=NextNet(in_channels=6, depth=8), timesteps=500, strict=False).to(device='cuda:0')

with torch.no_grad():
    # Use same s0 as recon-ddim model
    show_sample(torch.cat([s0, imgs[0]]), show_factor)
    print(sifid(0, s0))
    
    # NOTE(review): level 1 is conditioned on the real imgs[0] here, not on s0 as the recon cell did — confirm this is intended.
    #s1_ddpm_normal = level1.sample(resize(s0, out_shape=sizes[1]))
    s1_ddpm_normal = level1.sample(resize(imgs[0], out_shape=sizes[1]))
    s1_r_normal = resize(s1_ddpm_normal, out_shape=sizes[2])
    show_sample(torch.cat([s1_ddpm_normal, s1, s1_ddpm]), show_factor)
    print(sifid(1, s1_ddpm_normal), sifid(1, s1), sifid(1, s1_ddpm))
    
    s2_ddpm_normal = level2.sample(resize(s1_ddpm_normal, out_shape=sizes[2]))
    s2_r_normal = resize(s2_ddpm_normal, out_shape=sizes[3])
    show_sample(torch.cat([s2_ddpm_normal, s2, s2_ddpm]), show_factor)
    print(sifid(2, s2_ddpm_normal), sifid(2, s2), sifid(2, s2_ddpm))
    
    s3_ddpm_normal = level3.sample(resize(s2_ddpm_normal, out_shape=sizes[3]))
    s3_r_normal = resize(s3_ddpm_normal, out_shape=sizes[4])
    show_sample(torch.cat([s3_ddpm_normal, s3, s3_ddpm]), show_factor)
    print(sifid(3, s3_ddpm_normal), sifid(3, s3), sifid(3, s3_ddpm))
    
    s4_ddpm_normal = level4.sample(resize(s3_ddpm_normal, out_shape=sizes[4]))
    show_sample(torch.cat([imgs[4], s4_ddpm_normal, s4, s4_ddpm]), show_factor)
    print(sifid(4, s4_ddpm_normal), sifid(4, s4), sifid(4, s4_ddpm))

# Method - CCG

## Experiment - Basic NextNet sampling

In [None]:
show_factor = 15
image_name = 'balloons.png'
batch = 4
# imread output is treated as [0, 1] throughout this notebook (other cells map it to [-1, 1] via `* 2 - 1`).
img = imread(f'./images/{image_name}').to(device='cuda:0')

# SIFID of samples against the training image. `img` is already in [0, 1], so it is passed unchanged:
# applying `(x.clamp(-1, 1) + 1) / 2` to it squeezed it into [0.5, 1]. Samples are mapped from [-1, 1] to [0, 1].
get_sifid = lambda s: get_sifid_scores(img, (s.clamp(-1, 1).unsqueeze(1) + 1) / 2)

model = Diffusion.load_from_checkpoint('/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/balloons.png/2-ccg-32-...256...-32-filters-128-crops-nextnet/checkpoints/single-level-step=29999.ckpt', 
                                        model=NextNet(filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')
samples = model.sample(image_size=(64, 64), batch_size=batch)
show_sample(samples)
print(get_sifid(samples))

In [None]:
# DDIM samples from the crop-trained model at 256x256 (10-step stride) and 64x64 (5-step stride).
with torch.no_grad():
    #ddim_samples = model.sample_ddim(image_size=(64, 64), batch_size=batch, sampling_step_size=100)
    #show_sample(ddim_samples)
    #ddim_samples = model.sample_ddim(image_size=(64, 64), batch_size=batch, sampling_step_size=50)
    #show_sample(ddim_samples)
    ddim_samples = model.sample_ddim(image_size=(256, 256), batch_size=batch, sampling_step_size=10)
    show_sample(ddim_samples)
    ddim_samples = model.sample_ddim(image_size=(64, 64), batch_size=batch, sampling_step_size=5)
    show_sample(ddim_samples)

## Attempt - Single crop conditional generation

In [None]:
show_factor = 15
image_name = 'balloons.png'
batch = 16
cs = (128, 128)
# imread output is treated as [0, 1] in this notebook; the crop transform below maps it to [-1, 1].
img = imread(f'./images/{image_name}').to(device='cuda:0')

# SIFID of samples against the training image. `img` is already in [0, 1], so it is passed unchanged:
# applying `(x.clamp(-1, 1) + 1) / 2` to it squeezed it into [0.5, 1]. Samples are mapped from [-1, 1] to [0, 1].
get_sifid = lambda s: get_sifid_scores(img, (s.clamp(-1, 1).unsqueeze(1) + 1) / 2)

model = TheirsSRDiffusion.load_from_checkpoint('/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/balloons.png/6-same-as-before-100k-training_steps/checkpoints/single-level-step=99999.ckpt', 
                                       model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')

from torchvision import transforms
from datasets.transforms import RandomScaleResize
# Random flip, random rescale, crop to cs, then keep the first 3 channels mapped to [-1, 1].
# `[..., :3, :, :]` selects channels for both (C, H, W) and batched (N, C, H, W) input; `img` above is batched,
# where the former `img[:3, ]` sliced the batch dim and never dropped an alpha channel.
transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        RandomScaleResize(),
        transforms.RandomCrop(cs, pad_if_needed=True, padding_mode='constant'),
        transforms.Lambda(lambda img: (img[..., :3, :, :] * 2) - 1)])

def get_half_noisy_crop(img):
    """Build a batch of ``batch`` augmented crops of ``img`` whose bottom part is pure noise.

    Relies on the notebook globals ``transform`` (applied once per crop),
    ``batch`` and ``cs``. The last ``|int(-0.5 * cs[0])|`` rows of every crop are
    overwritten with standard normal noise.
    """
    crop_list = [transform(img) for _ in range(batch)]
    crops = torch.cat(crop_list)
    noise_start = int(-0.5 * cs[0])
    bottom = crops[:, :, noise_start:, :]
    crops[:, :, noise_start:, :] = torch.randn_like(bottom)
    return crops

def combine_sample_and_original(sample, original):
    """Return a copy of ``sample`` whose top rows are taken from ``original``.

    Uses the notebook global ``cs``. All rows except the last
    ``|int(-0.5 * cs[0])|`` come from ``original``; the remaining bottom rows keep
    ``sample``'s values. Both tensors must have the same shape.
    """
    assert sample.shape == original.shape
    split = int(-0.5 * cs[0])
    combined = sample.clone()
    combined[:, :, :split, :] = original[:, :, :split, :]
    return combined

# Condition the model on crops whose bottom half is noise, then paste the known top half back onto the result.
missing_sample = get_half_noisy_crop(img)
show_sample(missing_sample)

with torch.no_grad():
    #samples = model.sample_ddim(lr=missing_sample, sampling_step_size=10)
    #show_sample(samples)
    samples = model.sample(lr=missing_sample)
    show_sample(samples)
    samples = combine_sample_and_original(samples, missing_sample)
    show_sample(samples)

In [None]:
# Same as above, but the condition is pure noise. Combining therefore pastes noise (not image content) into the top rows.
missing_sample = torch.randn_like(missing_sample)
show_sample(missing_sample)
with torch.no_grad():
    samples = model.sample(lr=missing_sample)
    show_sample(samples)
    samples = combine_sample_and_original(samples, missing_sample)
    show_sample(samples)

## Attempt - CCG full sampling

In [None]:
# Full CCG sampling of a 256x256 image using windows of size cs, a stride of 90% of the window, and padding of a quarter window.
model = TheirsSRDiffusion.load_from_checkpoint('/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/balloons.png/6-same-as-before-100k-training_steps/checkpoints/single-level-step=99999.ckpt', 
                                       model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')

cs = (128, 128)

padding_size = (cs[0] // 4, cs[1] // 4)
window_size = cs
stride = (int(cs[0] * 0.9), int(cs[1] * 0.9))
samples = model.sample_ccg(sample_size=(256, 256), batch_size=1, window_size=window_size, stride=stride, padding_size=padding_size, method='normal')
show_sample(samples, 10)

## Experiment - Harmonization with CCG-trained model

In [None]:
model = TheirsSRDiffusion.load_from_checkpoint('/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/starry_night.png/0-ccg-noise-only-in-specific-locations/checkpoints/single-level-step=49999.ckpt', 
                                       model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')

# Overlay (RGBA, opacity in channel 3) and background, both mapped to [-1, 1]. Neither depends on t,
# so they are read once instead of on every loop iteration.
overlay_rgba = (resize(imread(r'./images/transparent-red-balloon.png', mode='RGBA').to(device=device), scale_factors=0.11) * 2) - 1
#overlay_rgba = (resize(imread(r'./images/transparent-balloon.png', mode='RGBA').to(device=device), scale_factors=0.1) * 2) - 1
#overlay_rgba = (resize(imread(r'./images/transparent-airplane.png', mode='RGBA').to(device=device), scale_factors=0.1) * 2) - 1
main_img = (imread('./images/starry_night.png').to(device=device) * 2) - 1
#main_img = (imread('./images/balloons.png').to(device=device) * 2) - 1

for t in range(80, 180, 10):
    print(t)
    # Noise the overlay's RGB channels to level t, with a continuous sqrt(alpha_cumprod) drawn between steps t-1 and t.
    # The whole RGB overlay is noised, but only its opaque pixels survive the alpha blend below.
    continuous_sqrt_alpha_cumprod = (torch.FloatTensor(np.random.uniform(model.sqrt_alphas_cumprod_prev[t - 1], model.sqrt_alphas_cumprod_prev[t], size=1)).to(device)).view(1, -1)
    addition_img_noisy = model.q_sample(x_start=overlay_rgba[:, :3, :, :], continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1))
    # Noisy RGB plus the opacity channel mapped back from [-1, 1] to [0, 1].
    addition_img = torch.cat([addition_img_noisy[0], (overlay_rgba[0][3].unsqueeze(0) + 1) / 2])

    img = combine_images(main_img[0], addition_img, (25, 100)).unsqueeze(0)
    # Include t in the filename; with a fixed name, every iteration overwrote the previous combined image.
    torchvision.utils.save_image((img + 1) / 2, f'/home/yanivni/data/tmp/combined_starry_night_red_balloon_t={t}.png')
    with torch.no_grad():
        # DDIM and DDPM samples starting from the noisy composite.
        samples = model.sample_ddim(lr=img, sampling_step_size=10)
        torchvision.utils.save_image((samples + 1) / 2, f'/home/yanivni/data/tmp/harmonization_starry_night_balloon_t={t}_ddim.png')
        show_sample(samples)
        print('DDPM sample starting from noisy img')
        samples = model.sample(lr=img)
        torchvision.utils.save_image((samples + 1) / 2, f'/home/yanivni/data/tmp/harmonization_starry_night_balloon_t={t}_ddpm.png')
        show_sample(samples)

## Experiment - Style transfer with CCG-trained model

In [None]:
# Model trained on the style image (starry night), applied to a noised version of the content image (content_image_name)
# Sweep the noise level t. For each t, show the DDIM and DDPM results and save only the DDPM one.
content_image_name = 'buildings.jpg'
style_image_name = 'starry_night.png'
version_name = '1-baseline-ccg'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{style_image_name}/{version_name}/checkpoints/single-level-step=49999.ckpt'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
model = ConditionalDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')

#x_T = torch.randn(size=(batch, 3, 200, 200), device=device)
#samples = model.sample(x_T)
#show_sample(samples)

img = (imread(f'./images/{style_image_name}').to(device=device) * 2) - 1
show_sample(img)
img = (imread(f'./images/{content_image_name}').to(device=device) * 2) - 1  # overwrite: `img` is the content image from here on
show_sample(img)

for t in range(150, 220, 20):
    print(t)
    # Continuous sqrt(alpha_hat) drawn uniformly between steps t-1 and t, used as the noise level for q_sample.
    t_vector = (torch.FloatTensor(np.random.uniform(model.sqrt_alphas_hat_prev[t - 1], model.sqrt_alphas_hat_prev[t], size=1)).to(device)).view(1, -1)
    img_noisy = model.q_sample(img, t_vector.view(-1, 1, 1, 1))
    with torch.no_grad():
        samples = model.sample_ddim(condition=img_noisy, sampling_step_size=10)
        show_sample(samples)
           
        samples = model.sample(condition=img_noisy)
        show_sample(samples)
        
        sample_directory = os.path.join(saved_data_path, 'Style Transfer', f'{style_image_name}/{content_image_name}')
        os.makedirs(sample_directory, exist_ok=True)
        save_diffusion_sample(samples, os.path.join(sample_directory, f't={t}_sample.png'))

# Output generation and preparation for meeting

In [None]:
# Diverse generation: for each trained image, sample unconditionally at the
# image's own size and at double height, double width and double both.
image_names = ['penguins.jpg']

for image_name in image_names:
    version_name = '1-baseline-simple'
    saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
    path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
    model = Diffusion.load_from_checkpoint(
        path, model=NextNet(depth=16), timesteps=500, strict=False
    ).to(device='cuda:0')

    # Generate images from complete noise in several sizes (Diverse Generation)
    base_h, base_w = tuple(imread(f'./images/{image_name}').shape[-2:])
    sizes = [
        (base_h * mult_h, base_w * mult_w)
        for mult_h, mult_w in ((1, 1), (2, 1), (1, 2), (2, 2))
    ]

    batch = 9
    for size in sizes:
        with torch.no_grad():
            samples = model.sample(image_size=size, batch_size=batch)
            show_sample(samples)

            sample_directory = os.path.join(
                saved_data_path, 'Diverse Generation', f'{image_name}/{size[0]}x{size[1]}'
            )
            os.makedirs(sample_directory, exist_ok=True)
            save_diffusion_sample(samples, os.path.join(sample_directory, 'sample.png'))

In [None]:
# Inpainting: load a conditional (CCG) model trained on a single image.

image_name = 'starry_night.png'
version_name = '3-baseline-ccg-full-training-small-net'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
# Alternative (larger) architecture used by other CCG runs in this notebook:
#model = ConditionalDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')
model = ConditionalDiffusion.load_from_checkpoint(
    path,
    model=NextNet(in_channels=6, depth=8),
    timesteps=500,
).to(device='cuda:0')
   
def add_noise_block_to_image(img, block_size, model, t_vector):
    """Noise a randomly placed block of ``img`` with three nested noise levels.

    A ``block_size`` rectangle is placed at a uniformly random position that
    keeps it inside the image. Three nested regions around it are replaced by
    the same region of ``img`` noised with ``noise_img``:

    * ``t_vector[0]``: the block grown by one block height/width on every side,
    * ``t_vector[1]``: the block grown by half a block height/width,
    * ``t_vector[2]``: the block itself.

    Regions are pasted outermost first, so each inner region overwrites the
    outer one. With increasing timesteps (as in the call below) this gives a
    "soft", non-binary noise mask. Regions are clipped to the image bounds.

    Args:
        img: image tensor indexed as ``[..., H, W]`` (presumably ``(1, C, H, W)``
            in [-1, 1], as produced by the callers in this notebook).
        block_size: ``(height, width)`` of the innermost block.
        model: diffusion model forwarded to ``noise_img``.
        t_vector: at least three timesteps, ordered outermost region first.

    Returns:
        A new tensor; ``img`` itself is not modified.
    """
    img_h, img_w = img.shape[-2:]
    block_h, block_w = block_size

    # random.randint is inclusive on both ends, so ``img_h - block_h`` is the
    # last start row at which the block still fits (same for columns). Clamp at
    # 0 so a block larger than the image starts at the corner instead of raising.
    top = random.randint(0, max(0, img_h - block_h))
    left = random.randint(0, max(0, img_w - block_w))

    # (margin_h, margin_w) grown around the block for each noise level.
    margins = ((block_h, block_w), (block_h // 2, block_w // 2), (0, 0))

    noised_img = img.clone()
    for level, (margin_h, margin_w) in enumerate(margins):
        region = (
            Ellipsis,
            slice(max(0, top - margin_h), top + block_h + margin_h),
            slice(max(0, left - margin_w), left + block_w + margin_w),
        )
        noised_img[region] = noise_img(img, model, t_vector[level])[region]
    return noised_img
    
# Generate some sample conditional crops (Inpainting): noise a random block of
# the training image for each block size and resample it `batch` times.
inpainting_crop_sizes = [(32, 32), (64, 64), (64, 128), (128, 64), (140, 140)]
batch = 4
for size in inpainting_crop_sizes:
    # Clean image in [-1, 1], then a random block (plus two softer rings around
    # it) noised to timesteps 150 / 300 / 500.
    img = ((imread(fr'./images/{image_name}') * 2) - 1).to(device=device)
    img = add_noise_block_to_image(img, size, model, (150, 300, 500))
    show_sample(img)

    # One output directory per block size.
    sample_directory = os.path.join(
        saved_data_path,
        'Inpainting',
        'non-binary-mask',
        f'{image_name}/{size[0]}x{size[1]}',
    )
    os.makedirs(sample_directory, exist_ok=True)
    save_diffusion_sample(img, os.path.join(sample_directory, 'noised_image.png'))

    # `batch` separate sampling calls, all conditioned on the same noised image.
    for i in range(batch):
        sample = model.sample(condition=img)
        show_sample(sample)
        save_diffusion_sample(sample, os.path.join(sample_directory, f'{i}_sample.png'))

In [None]:
# Harmonize specific objects into an image: paste an RGBA object onto the
# training image and resample the composite with the conditional model.
image_name = 'seascape.png'
version_name = '1-baseline-ccg'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=99999.ckpt'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs'
model = ConditionalDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')


# Object image -> scale factor applied to it (via resize) before pasting.
additional_img_names_to_sizes_ratios = {
    'transparent-red-balloon.png': 0.1,
    'transparent-balloon.png': 0.1,
    'transparent-spaceship.png': 0.05,
    'transparent-airplane.png': 0.1,
    'transparent-lizard.png': 0.08,
}

# Paste location per target image, passed as-is to combine_images
# (presumably (row, col) of the object's top-left corner — confirm).
harmonization_locations = {
    'starry_night.png': (25, 100),
    'balloons.png': (0, 0),
    'stone.png': (120, 110),
    'seascape.png': (10, 200)
}

for additional_img_name, ratio in additional_img_names_to_sizes_ratios.items():
    # RGBA object resized by `ratio`; all four channels are mapped to [-1, 1].
    addition_img = (resize(imread(fr'./images/{additional_img_name}', mode='RGBA').to(device=device), scale_factors=ratio) * 2) - 1
    main_img = (imread(f'./images/{image_name}').to(device=device) * 2) - 1

    sample_directory = os.path.join(saved_data_path, 'Harmonization', f'{image_name}/{additional_img_name}')
    os.makedirs(sample_directory, exist_ok=True)

    # Naive paste for reference: RGB stays in [-1, 1], the alpha channel is
    # mapped back to [0, 1] (presumably a blending mask for combine_images).
    dummy_additional_img = torch.cat([addition_img[0][:3], (addition_img[0][3].unsqueeze(0) + 1) / 2])
    dummy_img = combine_images(main_img[0], dummy_additional_img, harmonization_locations[image_name]).unsqueeze(0)
    show_sample(dummy_img, 15)
    save_diffusion_sample(dummy_img, os.path.join(sample_directory, 'before_harmonization.png'))

    for t in range(1, 501, 50):
        print(t)
        # Add noise only to non-transparent parts of addition_img
        # NOTE(review): the q_sample call is commented out, so the object is
        # pasted clean and `t_vector` is unused. Every iteration therefore gets
        # the same condition and differs only by whatever randomness
        # sample_ddim uses. Re-enable q_sample with t_vector for a real sweep over t.
        t_vector = (torch.FloatTensor(np.random.uniform(model.sqrt_alphas_hat_prev[t - 1], model.sqrt_alphas_hat_prev[t], size=1)).to(device)).view(1, -1)
        addition_img_noisy = addition_img[:, :3, :, :]#model.q_sample(x_start=addition_img[:, :3, :, :], continuous_sqrt_alpha_hat=t_vector.view(-1, 1, 1, 1))
        addition_img_concat = torch.cat([addition_img_noisy[0], (addition_img[0][3].unsqueeze(0) + 1) / 2])
        img = combine_images(main_img[0], addition_img_concat, harmonization_locations[image_name]).unsqueeze(0)

        with torch.no_grad():
            samples = model.sample_ddim(condition=img, sampling_step_size=10)
            save_diffusion_sample(samples, os.path.join(sample_directory, f't={t}_ddim.png'))
            show_sample(samples, 15)

In [None]:
# Conditional inpainting (Color-guided inpainting): overwrite a fixed crop of the
# image with a noisy solid color and let the conditional model resample it.
image_name = 'fruit.png'
version_name = '2-baseline-ccg-full-training-normal-net'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs'
model = ConditionalDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device='cuda:0')

# RGB colors in [0, 1], shaped (1, 1, 3) so they can be tiled over the crop.
colors = {
    'yellow': torch.tensor([[[1,1,0]]], dtype=torch.float32, device=device),
    'red': torch.tensor([[[1,0,0]]], dtype=torch.float32, device=device),
    'green': torch.tensor([[[0,1,0]]], dtype=torch.float32, device=device),
    'blue': torch.tensor([[[0,0,1]]], dtype=torch.float32, device=device)
}

# Clean image in [-1, 1]; every color starts from a fresh copy of it.
clean_img = ((imread(fr'./images/{image_name}') * 2) - 1).to(device=device)
crop_start_loc = (74, 115)
#crop_size = (26, 52)
crop_size = (50, 110)
crop_rows = slice(crop_start_loc[0], crop_start_loc[0] + crop_size[0])
crop_cols = slice(crop_start_loc[1], crop_start_loc[1] + crop_size[1])

batch = 4
for color_name, color in colors.items():
    # The image lives in [-1, 1] (imread output * 2 - 1), so map the [0, 1]
    # color the same way. Without this, 'red' (1, 0, 0) would display as
    # (1, 0.5, 0.5) after show_sample's (x + 1) / 2 and look pink.
    # Result: (3, crop_h, crop_w) block of the color.
    color_tensor = (color * 2 - 1).repeat((crop_size[0], crop_size[1], 1)).moveaxis(2, 0)

    #t = 250
    #t_tensor = torch.cuda.FloatTensor(np.random.uniform(model.sqrt_alphas_hat_prev[t - 1], model.sqrt_alphas_hat_prev[t], size=1)).to(device='cuda:0')
    #noisy_color_tensor = model.q_sample(x_start=color_tensor, continuous_sqrt_alpha_hat=t_tensor.view(-1, 1, 1, 1))
    # Unit-variance Gaussian noise on top of the solid color.
    noisy_color_tensor = color_tensor + torch.randn_like(color_tensor)
    img = clean_img.clone()
    img[:, :, crop_rows, crop_cols] = noisy_color_tensor
    show_sample(img)
    sample_directory = os.path.join(saved_data_path, 'Conditional Inpainting', f'{image_name}/{color_name}')
    os.makedirs(sample_directory, exist_ok=True)
    save_diffusion_sample(img, os.path.join(sample_directory, 'noised_image.png'))

    with torch.no_grad():
        for i in range(batch):
            sample = model.sample(condition=img)
            show_sample(sample)
            save_diffusion_sample(sample, os.path.join(sample_directory, f'{i}_sample.png'))

In [None]:
# Image Editing: noise a hand-edited copy of the training image to timestep t and
# denoise it with the unconditional model, for a sweep of t values.

image_name = 'stone.png'
edit_name = 'stone_edit.png'
version_name = '3-baseline-simple'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=16), timesteps=500, strict=False).to(device='cuda:0')

edit_img = imread(f'./images/edit/{edit_name}').to(device=device) * 2 - 1

# The output directory and the original edit do not depend on t, so prepare
# them once instead of on every iteration.
sample_directory = os.path.join(saved_data_path, 'Editing', f'{image_name}/{edit_name}')
os.makedirs(sample_directory, exist_ok=True)
save_diffusion_sample(edit_img, os.path.join(sample_directory, 'original_edit.png'))

for t in range(100, 350, 20):
    print(t)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        noisy_edit_img = noise_img(edit_img, model, t)
        show_sample(noisy_edit_img)
        # Presumably starts the reverse process at step t from the noised edit.
        samples = model.sample(custom_initial_img=noisy_edit_img, custom_timesteps=t)
    show_sample(samples)
    save_diffusion_sample(samples, os.path.join(sample_directory, f'sample_t={t}.png'))

In [None]:
# Visual Summary: load the unconditional model used by summarize_image below.

image_name = 'birds_3.jpg'
version_name = '2-baseline-vs'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
model = Diffusion.load_from_checkpoint(
    path,
    model=NextNet(depth=11, filters_per_layer=[32, 64, 64, 128, 128, 256, 128, 128, 64, 64, 32]),
    timesteps=500,
    strict=False,
).to(device='cuda:0')

    
def summarize_image(img, model, n_iteration=10, scale=(0.9, 0.9), t=100):
    """Progressively shrink ``img``, letting ``model`` repair each downscaled step.

    Each iteration resizes the previous result by ``scale``, noises it to
    timestep ``t`` with ``noise_img`` and denoises it with ``model.sample``
    starting from that noised image. Each repaired image is displayed.

    Args:
        img: starting image tensor (in [-1, 1] in the caller below).
        model: diffusion model supporting
            ``sample(custom_initial_img=..., custom_timesteps=...)``.
        n_iteration: number of shrink-and-repair steps.
        scale: per-axis ``(height, width)`` factors applied at every step.
        t: timestep each resized image is noised to before denoising.

    Returns:
        List of ``n_iteration + 1`` tensors: ``img`` followed by each repaired result.
    """
    resized_imgs = [img]
    # Sampling is inference only; don't build autograd graphs.
    with torch.no_grad():
        # Use the parameter, not a same-named global, so the function works
        # regardless of notebook state.
        for _ in range(n_iteration):
            resized_img = resize(resized_imgs[-1], scale_factors=scale)
            fixed_img = model.sample(custom_initial_img=noise_img(resized_img, model, t), custom_timesteps=t)
            resized_imgs.append(fixed_img)
            show_sample(fixed_img)
    return resized_imgs

img = imread(f'./images/{image_name}').to(device=device) * 2 - 1
# Target size per image, expressed as (height divisor, width divisor) of the
# original size (floor division, as before).
final_size_divisors = {
    'tree.png': (4, 4),
    'balloons.png': (4, 4),
    'mountains3.png': (1.2, 4),
    'birds_3.jpg': (1.1, 3),
}
h_div, w_div = final_size_divisors[image_name]
final_size = (img.shape[-2] // h_div, img.shape[-1] // w_div)

n_iterations = 15
# Per-step factors whose n_iterations-th power maps the full size to final_size.
scale = tuple(
    (target / full) ** (1 / n_iterations)
    for target, full in zip(final_size, (img.shape[-2], img.shape[-1]))
)
print(scale)

# Run the visual summary at several noise levels and save every intermediate scale.
for t in [100, 150, 250]:
    print(f't={t}')
    samples = summarize_image(img, model, n_iterations, scale, t)
    # NOTE: the `:2f` format spec means "minimum width 2, default six decimals",
    # not two decimals as `.2f` would give.
    sample_directory = os.path.join(saved_data_path, 'Visual Summary', f'{image_name}/t={t}_scale={scale[0] :2f}_{scale[1] :2f}')
    os.makedirs(sample_directory, exist_ok=True)

    # samples[0] is the input image; samples[i] is the result after i
    # shrink-and-repair steps, i.e. roughly scale**i of the original size.
    for i, s in enumerate(samples):
        print(f'scale={(scale[0] ** i, scale[1] ** i)}')
        show_sample(s)
        save_diffusion_sample(s, os.path.join(sample_directory, f'scale={scale[0] ** i :2f}_{scale[1] ** i :2f}.png'))


In [None]:
# Draw from sketch: noise a rough sketch to timestep t and denoise it with the
# (unconditional) Diffusion model trained on the matching real image.

# NOTE: the three (sketch_name, image_name, version_name) groups below are
# alternatives; only the last assignment (cows) is in effect.
sketch_name = 'starry_night_sketch.png'
image_name = 'starry_night.png'
version_name = '2-simple-diffusion-huge-crops-nextnet-depth-16-part3'

sketch_name = 'balloons_sketch_1.png'
image_name = 'balloons.png'
version_name = '10-simple-diffusion-huge-crops-nextnet-depth-16-part4'

sketch_name = 'cows_sketch.png'
image_name = 'cows.png'
version_name = '1-baseline-simple'

path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=199999.ckpt'
sketch_model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=16), timesteps=500, auto_sample=False, strict=False).to(device='cuda:0')
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'


# Show the real training image and the sketch (both rescaled to [-1, 1]).
show_sample((imread(f'./images/{image_name}').to(device=device) * 2) - 1)
sketch_img = (imread(f'./images/sketch/{sketch_name}').to(device=device) * 2) - 1
show_sample(sketch_img)

# Sweep the starting timestep t = 240, 260, 280, 300.
for t in range(240, 320, 20):
    print(t)

    with torch.no_grad():
        img_noisy = noise_img(sketch_img, sketch_model, t)
        samples = sketch_model.sample(custom_initial_img=img_noisy, custom_timesteps=t, batch_size=1)
        show_sample(samples)

    sample_directory = os.path.join(saved_data_path, 'Draw from sketch', f'{image_name}/{sketch_name}')
    os.makedirs(sample_directory, exist_ok=True)
    # The raw sketch is rewritten to the same file on every iteration.
    save_diffusion_sample(sketch_img, os.path.join(sample_directory, f'raw_sketch.png'))
    save_diffusion_sample(samples, os.path.join(sample_directory, f't={t}_sample.png'))

# Miscellaneous experiments

## Experiment - Using ILVR sampling for better conditioning

In [None]:
# Load the conditional (CCG) model that ILVR guides in the cells below.
image_name = 'balloons.png'
version_name = '8-baseline-ccg'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=49999.ckpt'
model = ConditionalDiffusion.load_from_checkpoint(
    path,
    model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]),
    timesteps=500,
).to(device=device)


def phi_N(img, N):
    """Low-pass ``img`` by downscaling it by ``N`` and upscaling back.

    This is the low-pass operator used by the ILVR samplers in this notebook.
    The output keeps the spatial size (last two dims) of ``img``.
    """
    height, width = img.shape[-2], img.shape[-1]
    downsampled = resize(img, scale_factors=1 / N)
    return resize(downsampled, scale_factors=N, out_shape=(height, width))

@torch.no_grad()
def sample_ilvr(model, reference_img, N=1):
    """Sample from ``model`` with ILVR guidance towards ``reference_img``.

    Starting from Gaussian noise, each reverse step ``i > 0`` is followed by
    replacing the low-frequency part (as measured by ``phi_N``) of the model's
    prediction with that of ``reference_img`` noised by ``model.q_sample``. The
    noise level is drawn uniformly between ``sqrt_alphas_hat_prev[i - 1]`` and
    ``sqrt_alphas_hat_prev[i]``. The final step (``i == 0``) is not guided.

    Args:
        model: conditional diffusion model exposing ``num_timesteps``,
            ``p_sample``, ``q_sample``, ``sqrt_alphas_hat_prev`` and ``device``.
        reference_img: reference batch, presumably ``(B, C, H, W)`` in [-1, 1].
        N: ``phi_N`` down/up-sampling factor; larger N constrains only coarser
            structure.

    Returns:
        A tensor with the same shape as ``reference_img``.
    """
    b = reference_img.shape[0]
    # NOTE(review): the model's condition input is pure noise, so the
    # conditional model gets no meaningful condition and guidance comes only
    # from the ILVR projection below — confirm this is intended.
    condition = torch.randn_like(reference_img)
    img = torch.randn_like(reference_img)
    for i in reversed(range(0, model.num_timesteps)):
        if i > 0:
            img_tag = model.p_sample(img, i, condition_x=condition)

            continuous_sqrt_alpha_hat = torch.FloatTensor(
                    np.random.uniform(model.sqrt_alphas_hat_prev[i - 1], model.sqrt_alphas_hat_prev[i], size=b)).to(model.device)
            # One scalar per sample, shaped (B, 1, 1, 1) so it broadcasts over
            # C, H, W inside q_sample. This is the same shape the other q_sample
            # calls in this notebook pass; a (B, 1) shape only broadcasts
            # correctly when B == 1.
            reference_noised = model.q_sample(reference_img, continuous_sqrt_alpha_hat.view(b, 1, 1, 1))

            # Keep the prediction's high frequencies, take low frequencies from the reference.
            img = phi_N(reference_noised, N) + img_tag - phi_N(img_tag, N)
        else:
            img = model.p_sample(img, i, condition_x=condition)
    return img

    
# Sweep the ILVR low-pass factor N. phi_N downsamples by N, so small N ties the
# sample to fine reference detail and large N only to its coarse layout.
reference_img = ((imread(fr'./images/{image_name}') * 2) - 1).to(device=device)
show_sample(reference_img)

for N in (2, 4, 8, 16, 32, 64):
    print(f'N={N}')
    samples = sample_ilvr(model, reference_img=reference_img, N=N)
    # Show the low-passed reference the sample was constrained to, then the sample.
    low_passed_reference = phi_N(reference_img, N)
    show_sample(low_passed_reference)
    show_sample(samples)

In [None]:
# Combine ILVR with conditioning based on level0 generation from old pyramid model
# Fine-scale model (full-size balloons).
# NOTE(review): TheirsSRDiffusion is not imported explicitly in the notebook's
# import cell (only SRDiffusion is); it may come from the notebook_utils star
# import — confirm where it is defined.
image_name = 'balloons.png'
version_name = '6-same-as-before-100k-training_steps'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=99999.ckpt'
model_big = TheirsSRDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device=device)

# Medium-scale conditional model.
image_name = 'balloons_medium.png'
version_name = 'version_1'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=99999.ckpt'
model_med = ConditionalDiffusion.load_from_checkpoint(path, model=NextNet(in_channels=6, filters_per_layer=[32, 64, 64, 128, 256, 128, 64, 64, 32]), timesteps=500).to(device=device)

# Checkpoint directory of the older pyramid model; its level 0 provides the
# coarse reference used below.
path_to_checkpoints = r'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/pyramid/balloons.png/version_4/checkpoints'
image_name = 'balloons.png'
batch = 1

@torch.no_grad()
def sample_ilvr_2(model, reference_img, N=1):
    """ILVR sampling for models whose schedule is ``sqrt_alphas_cumprod_prev``.

    Starting from Gaussian noise, each reverse step ``i > 0`` is followed by
    replacing the low-frequency part (``phi_N``) of the model's prediction with
    that of ``reference_img`` noised by ``model.q_sample``. The noise level is
    drawn uniformly between ``sqrt_alphas_cumprod_prev[i - 1]`` and
    ``sqrt_alphas_cumprod_prev[i]``; this schedule attribute is what
    distinguishes this helper (used with ``model_big`` below). The final step
    (``i == 0``) is not guided.

    Args:
        model: diffusion model exposing ``num_timesteps``, ``p_sample``,
            ``q_sample``, ``sqrt_alphas_cumprod_prev`` and ``device``.
        reference_img: reference batch, presumably ``(B, C, H, W)`` in [-1, 1].
        N: ``phi_N`` down/up-sampling factor.

    Returns:
        A tensor with the same shape as ``reference_img``.
    """
    b = reference_img.shape[0]
    # NOTE(review): the condition input is pure noise, as in sample_ilvr — confirm intent.
    condition = torch.randn_like(reference_img)
    img = torch.randn_like(reference_img)
    for i in reversed(range(0, model.num_timesteps)):
        if i > 0:
            img_tag = model.p_sample(img, i, condition_x=condition)

            continuous_sqrt_alpha_hat = torch.FloatTensor(
                    np.random.uniform(model.sqrt_alphas_cumprod_prev[i - 1], model.sqrt_alphas_cumprod_prev[i], size=b)).to(model.device)
            # One scalar per sample, shaped (B, 1, 1, 1) so it broadcasts over
            # C, H, W inside q_sample; a (B, 1) shape only broadcasts correctly
            # when B == 1.
            reference_noised = model.q_sample(reference_img, continuous_sqrt_alpha_hat.view(b, 1, 1, 1))

            # Keep the prediction's high frequencies, take low frequencies from the reference.
            img = phi_N(reference_noised, N) + img_tag - phi_N(img_tag, N)
        else:
            img = model.p_sample(img, i, condition_x=condition)
    return img


# Real balloon images at the medium and fine scales, rescaled to [-1, 1].
imgs = [(imread(f'./images/balloons_medium.png').to(device=device) * 2) - 1, (imread(f'./images/balloons.png').to(device=device) * 2) - 1]

# Coarsest level (level 0) of the old pyramid model, loaded as a plain Diffusion model.
level0 = Diffusion.load_from_checkpoint(os.path.join(path_to_checkpoints, 'level=0-step=29999.ckpt'), 
                                                model=NextNet(depth=12), timesteps=500).to(device='cuda:0')


# Step 1: sample a coarse image from level 0 and show it next to the real fine
# image resized to the same shape.
level0_size = (186, 248)
reference_img = level0.sample(image_size=level0_size, batch_size=1)
show_sample(torch.cat([reference_img, resize(imgs[-1], out_shape=reference_img.shape)]), 15)
save_diffusion_sample(reference_img, f'/home/yanivni/data/tmp/coarsest_scale_conditioning_experiment/0_coarsest_scale.png')

# Step 2: upsample the coarse sample to the medium size and use it as the ILVR
# reference for the medium model. N is the medium/coarse height ratio, so
# phi_N low-passes down to roughly the coarse resolution.
reference_img = resize(reference_img, out_shape=imgs[0].shape)
N = imgs[0].shape[-2] / level0_size[0]
samples = sample_ilvr(model_med, reference_img, N)
save_diffusion_sample(samples, f'/home/yanivni/data/tmp/coarsest_scale_conditioning_experiment/1_medium_scale_sample.png')
show_sample(torch.cat([samples, resize(imgs[-1], out_shape=samples.shape)]), 15)

# Step 3: same again one scale up — the medium sample guides the fine model.
samples = resize(samples, out_shape=imgs[1].shape)
N = imgs[1].shape[-2] / imgs[0].shape[-2]
samples = sample_ilvr_2(model_big, samples, N)
show_sample(torch.cat([samples, resize(imgs[-1], out_shape=samples.shape)]), 15)
save_diffusion_sample(samples, f'/home/yanivni/data/tmp/coarsest_scale_conditioning_experiment/2_fine_scale_sample.png')

# Step 4: skip the medium scale and guide the fine model directly with the
# coarse sample. reference_img still holds the medium-size upsampling from
# step 2 and is resized again here.
N = imgs[1].shape[-2] / level0_size[0]
reference_img = resize(reference_img, out_shape=imgs[1].shape)
samples = sample_ilvr_2(model_big, reference_img, N)
show_sample(torch.cat([samples, resize(imgs[-1], out_shape=samples.shape)]), 15)
save_diffusion_sample(samples, f'/home/yanivni/data/tmp/coarsest_scale_conditioning_experiment/3_direct_fine_scale_sample.png')

## Experiment - More attempts at diverse generation

In [None]:
from diffusion.diffusion import Diffusion

# Compare DDIM (sampling_step_size=10) and full DDPM sampling from the same
# model at a 200x350 output size.
image_name = 'starry_night.png'
version_name = 'version_0'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=99999.ckpt' # +75K
model1 = Diffusion.load_from_checkpoint(
    path, model=NextNet(depth=16), timesteps=500
).to(device='cuda:0')

batch = 1
sample_size = (200, 350)
x_T = torch.randn(size=(batch, 3, *sample_size), device=device)

samples = model1.sample_ddim(x_T=x_T, sampling_step_size=10)
show_sample(samples, figsize_mult=10)

samples = model1.sample(image_size=sample_size, batch_size=batch)
show_sample(samples, figsize_mult=10)

## Experiment - Playing around with latent space representations

In [None]:
# Load the balloons model and two initial noises x_T / x_T2 for the
# latent-space experiments below.
image_name = 'balloons.png'
version_name = '10-simple-diffusion-huge-crops-nextnet-depth-16-part4'#9.0-simple-diffusion-half-resolution-huge-crops-nextnet'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/last.ckpt'
#model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=9), timesteps=500).to(device='cuda:0')
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=16), timesteps=500, strict=False).to(device='cuda:0')

batch = 1
sample_size = (186, 250)
# Fresh random noises of shape (batch, 3, H, W)...
x_T = torch.randn(size=(batch, 3, sample_size[0], sample_size[1]), device=device)
x_T2 = torch.randn(size=(batch, 3, sample_size[0], sample_size[1]), device=device)
# ...immediately replaced by previously saved noises; drop these two lines to
# explore new random ones.
x_T = torch.load(f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/high_res/good_xT.pt')
x_T2 = torch.load(f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/high_res/good_xT2.pt')

In [None]:
sss = 5 # sampling step size
# Reference DDIM samples from both stored noises.
samples = model.sample_ddim(x_T=x_T, batch_size=1, sampling_step_size=sss)
show_sample(samples)
samples = model.sample_ddim(x_T=x_T2, batch_size=1, sampling_step_size=sss)
show_sample(samples)

# Per-channel mean of the noise patch x_T2[rows 0:35, cols 75:100]; only used
# by the commented-out variant below.
balloon_means = x_T2[:, :, 0:35, 75:100].mean(dim=(-2, -1)).view(1,3,1,1)

# Zero the top 50 rows of x_T2's noise and look at the resulting sample.
# NOTE: this call uses sampling_step_size=10 rather than `sss`.
clean_xT2 = x_T2.clone()
clean_xT2[:, :, :50, :].fill_(0)
#clean_xT2[:, :, :50, :] = torch.randn_like(clean_xT2[:, :, :50, :]) + balloon_means
samples = model.sample_ddim(x_T=clean_xT2, batch_size=1, sampling_step_size=10)
show_sample(samples)
#save_diffusion_sample(samples, f'/home/yanivni/data/tmp/latent_space_editing/changing_upper_part_of_image/positive_noise/{i}.png')

#clean_xT2[:, :, :45, :] = (torch.randn_like(clean_xT2[:, :, :45, :])) - 0.0438
#samples = model.sample_ddim(x_T=clean_xT2, batch_size=1, sampling_step_size=10)
#show_sample(samples)
#save_diffusion_sample(samples, f'/home/yanivni/data/tmp/latent_space_editing/changing_upper_part_of_image/negative_noise/{i}.png')

# Slide the band x_T[rows 50:150] upwards one row per frame, over a copy of
# x_T whose top 150 rows are zeroed, and animate the DDIM samples.
samples = []
#x_T2_chunk = x_T[:, :, :35, 45:70]
clean_xT = x_T.clone()
clean_xT[:, :, :150, :].fill_(0)
x_T_chunk = x_T[:, :, 50:150, :]
for i in range(0, 50):
    x_T_patched = clean_xT.clone()
    x_T_patched[:, :, 50-i:50-i+x_T_chunk.shape[-2], :x_T_chunk.shape[-1]] = x_T_chunk
    samples.append(model.sample_ddim(x_T=x_T_patched, batch_size=1, sampling_step_size=sss))
samples = torch.cat(samples, dim=0)
show_gif(samples)

In [None]:
# Repeat the latent-interpolation experiment `gif_count` times with fresh random
# noises. For each pair, save both endpoint samples and a GIF of DDIM samples
# along the straight line between the two noises, to check whether what we see
# is accidental or a consistent phenomenon.
gif_count = 10
batch = 1
out_dir = '/home/yanivni/data/tmp/latent_space_editing/balloons.png/medium_res/interp_gifs'
# Everything below writes into out_dir; make sure it exists.
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for gif_index in range(gif_count):
        # Two independent starting noises of shape (batch, 3, H, W).
        x_T = torch.randn((batch, 3, sample_size[0], sample_size[1]), device='cuda:0')
        x_T2 = torch.randn((batch, 3, sample_size[0], sample_size[1]), device='cuda:0')
        s0 = model.sample_ddim(x_T=x_T, sampling_step_size=10)
        s0_2 = model.sample_ddim(x_T=x_T2, sampling_step_size=10)

        # Endpoint samples as stills.
        # NOTE(review): save_sample is not defined in the notebook's visible
        # header — confirm it comes from common_utils.notebook_utils.
        save_sample(s0, output_path=fr'{out_dir}/{gif_index}_a1.png')
        save_sample(s0_2, output_path=fr'{out_dir}/{gif_index}_a2.png')

        # Linear interpolation: weight 0 gives x_T, weight 1 gives x_T2.
        interp_size = 100
        samples = []
        for i in range(interp_size):
            weight = i / (interp_size - 1)
            x_T_interp = x_T * (1 - weight) + x_T2 * weight
            s_interp = model.sample_ddim(x_T=x_T_interp, sampling_step_size=10)
            samples.append(s_interp.cpu())
        samples = torch.cat(samples, dim=0)
        make_gif(samples, f'{out_dir}/{gif_index}_gif.gif')

In [None]:
# Keep the current noises for later reuse.
# NOTE(review): this writes to latent_space_editing/{image_name}/, while the
# balloons loading cell reads from latent_space_editing/{image_name}/high_res/
# — confirm which location is intended (the directory must also exist).
torch.save(x_T, f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT.pt')
torch.save(x_T2, f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT2.pt')

In [None]:
# Load the birds model and two initial noises x_T / x_T2 for the
# latent-space "motion" experiment below.
image_name = 'birds.png'
version_name = '2-simple-diffusion-huge-crops-nextnet'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=49999.ckpt'
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=12), timesteps=500).to(device='cuda:0')

batch = 1
sample_size = (200, 200)
# Fresh random noises of shape (batch, 3, H, W)...
x_T = torch.randn(size=(batch, 3, sample_size[0], sample_size[1]), device=device)
x_T2 = torch.randn(size=(batch, 3, sample_size[0], sample_size[1]), device=device)
# ...immediately replaced by previously saved noises; drop these two lines to
# explore new random ones.
x_T = torch.load(f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT.pt')
x_T2 = torch.load(f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT2.pt')

In [None]:
def small_change(x, m, v):
    """Add ``m`` in place to a small, randomly placed patch of ``x``.

    The patch is ``int(H ** 0.4)`` rows by ``int(W ** 0.4)`` columns and is
    placed uniformly at random within the last two dims of the 4-D tensor
    ``x``. ``v`` is unused; it belonged to a disabled variant that added
    Gaussian noise with variance ``v`` and mean ``m``.

    Returns:
        ``x`` itself, modified in place.
    """
    height, width = x.shape[-2], x.shape[-1]
    patch_h = int(height ** 0.4)
    patch_w = int(width ** 0.4)

    row = random.randint(0, height - patch_h)
    col = random.randint(0, width - patch_w)
    x[:, :, row:row + patch_h, col:col + patch_w] += m
    return x

def chunk_interp(chunks, i, lim):
    """Piecewise-linearly interpolate through ``chunks`` as ``i`` goes 0..lim.

    ``i / lim`` is mapped onto the path chunks[0] -> chunks[1] -> ... ->
    chunks[-1], so ``i == 0`` gives ``chunks[0]`` and ``i == lim`` gives
    ``chunks[-1]``. Works for any values supporting multiplication by a float
    and addition (numbers, tensors).

    Args:
        chunks: non-empty sequence of interpolation keys.
        i: current position, expected in ``[0, lim]``.
        lim: position that maps to the last chunk (must be non-zero).

    Returns:
        The blended value at position ``i``.
    """
    pos = (i / lim) * (len(chunks) - 1)
    idx = int(pos)
    # At the end of the range (or with a single chunk) there is no next chunk
    # to blend with; return the last one instead of indexing past the end.
    if idx >= len(chunks) - 1:
        return chunks[-1]
    alpha = pos % 1
    return (1 - alpha) * chunks[idx] + alpha * chunks[idx + 1]

def rotate_interp(x, chunks, i, lim):
    """Roll ``x`` by ``i`` columns and add a tiled, interpolated offset.

    The offset is ``chunk_interp(chunks, i, lim) - chunks[0]``. It is tiled
    over the spatial dims and cropped to ``x``'s size, so the whole of ``x``
    receives the same periodic perturbation.

    Args:
        x: tensor indexed as ``[B, C, H, W]`` (presumably, given the 4-entry
            ``repeat`` and slicing below).
        chunks, i, lim: forwarded to ``chunk_interp``; chunks share one shape.

    Returns:
        A new tensor shaped like ``x``.
    """
    rotated_x = x.roll(i, dims=-1)

    new_chunk = chunk_interp(chunks, i, lim)
    dir_x = new_chunk - chunks[0]
    # Tile just enough copies (ceil division) to cover x in both spatial dims.
    # Tiling is periodic from the origin, so the cropped result does not
    # depend on how many extra tiles are made.
    reps_h = -(-rotated_x.shape[-2] // dir_x.shape[-2])
    reps_w = -(-rotated_x.shape[-1] // dir_x.shape[-1])
    dir_x = dir_x.repeat((1, 1, reps_h, reps_w))

    return rotated_x + dir_x[:, :, :rotated_x.shape[-2], :rotated_x.shape[-1]]
    
# Latent-space "motion" experiment: interpolate between 30x30 noise patches
# (chunks) and add the resulting offset to three bands of x_T2, shifted one
# column per frame, then animate the DDIM samples.
with torch.no_grad():
    samples = model.sample_ddim(x_T=x_T, sampling_step_size=10)
    show_sample(samples)
    samples = model.sample_ddim(x_T=x_T2, sampling_step_size=10)
    show_sample(samples)

    clean_xT2 = x_T2.clone()
    #clean_xT2[:, :, :, :].fill_(0)
    # NOTE(review): this sample is never shown or saved (`samples` is reset below).
    samples = model.sample_ddim(x_T=clean_xT2, sampling_step_size=10)

    samples = []
    # 30x30 patches of x_T / x_T2. The last three entries repeat the first
    # three, so the interpolation path revisits them.
    chunks = [x_T[:, :, 80:110, 95:125], x_T2[:, :, 80:110, 70:100], x_T[:, :, 55:85, 40:70], x_T[:, :, 55:85, 120:150], x_T[:, :, 50:80, 85:115],
              x_T[:, :, 80:110, 95:125], x_T2[:, :, 80:110, 70:100], x_T[:, :, 55:85, 40:70]]
    #m, v = chunk.mean(), x_T2_chunk.var()
    lim = 120
    for i in range(0, lim):
        x_T_patched = clean_xT2.clone()
        #chunk = small_change(chunk, m, v)
        # Interpolated patch for this frame; only its difference from chunks[0] is added.
        chunk = chunk_interp(chunks, i, lim)
        #x_T_patched[:, :, 20:20+chunk.shape[-2], i:i+chunk.shape[-1]] = chunk
        # Same offset applied to three bands starting at rows 0, 30 and 60, at column i.
        x_T_patched[:, :, :chunk.shape[-2], i:i+chunk.shape[-1]] += (chunk - chunks[0])
        x_T_patched[:, :, 30:30+chunk.shape[-2], i:i+chunk.shape[-1]] += (chunk - chunks[0])
        x_T_patched[:, :, 60:60+chunk.shape[-2], i:i+chunk.shape[-1]] += (chunk - chunks[0])
        #x_T_patched[:, :, :150, :] = rotate_interp(clean_xT2[:, :, :150, :], chunks, i, lim)
        samples.append(model.sample_ddim(x_T=x_T_patched, sampling_step_size=10))
    samples = torch.cat(samples, dim=0)
    show_gif(samples)

In [None]:
# Save the current noises to the same paths the birds loading cell reads from.
torch.save(x_T, f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT.pt')
torch.save(x_T2, f'/home/yanivni/data/tmp/latent_space_editing/{image_name}/good_xT2.pt')

## Experiment - Generate video frames

In [None]:
# Sampling method 1 - generate frame0 from noise and then noise+denoise_with_next_frame to get the next frame
video_name = 'dutch2'

version_name = 'version_3'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{video_name}/{version_name}/checkpoints/last.ckpt'
# NOTE: saved_data_path is not used in this cell.
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
# Frame-conditioned model; the frame index is passed to sample() below.
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=16, frame_conditioned=True), timesteps=500, strict=False).to(device='cuda:0')

# Frame size (from the first frame) and number of frames of the training video.
size = tuple(imread(f'./images/video/{video_name}/1.png').shape[-2:])
total_frame_count = len(os.listdir(f'./images/video/{video_name}'))

# Timestep the previous frame is re-noised to before generating the next one.
t = 200
s0 = model.sample(image_size=size, batch_size=1, frame=0)
show_sample(s0)
samples = [s0]
# Frames 1..total_frame_count are generated after s0, so `samples` ends up with
# total_frame_count + 1 frames.
for frame in range(1, total_frame_count + 1):
    noisy_prev_frame = noise_img(samples[-1], model, t)
    s = model.sample(custom_initial_img=noisy_prev_frame, custom_timesteps=t, frame=frame)
    show_sample(s)
    samples.append(s)

In [None]:
# Animate the frames currently in `samples` (interval=50, presumably ms per frame).
show_gif(torch.cat(samples, dim=0), interval=50)

In [None]:
# Sampling method 2 - Use an existing frame as frame0 and then noise+denoise_with_next_frame to get the next frame
video_name = 'birds4'

version_name = 'version_0'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{video_name}/{version_name}/checkpoints/last.ckpt'

saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'
# Conditional, frame-conditioned model: each new frame is sampled with the
# previous frame as its condition.
model = ConditionalDiffusion.load_from_checkpoint(path,
                                              model=NextNet(in_channels=6, depth=16, frame_conditioned=True),
                                              timesteps=500, strict=False).to(device='cuda:0')

total_frame_count = len(os.listdir(f'./images/video/{video_name}'))
print(total_frame_count)

# Seed the sequence with the real first frame, rescaled to [-1, 1].
s0 = imread(f'./images/video/{video_name}/1.png').to(device=device) * 2 - 1
show_sample(s0)
samples = [s0]
for frame in range(1, total_frame_count + 1):
    print(frame, )
    # NOTE(review): `frame=1` is a constant, not the loop variable; presumably
    # it means "one frame after the condition" — confirm.
    s = model.sample(condition=samples[-1], frame=1)
    show_sample(s)
    samples.append(s)
#show_gif(torch.cat(samples, dim=0), interval=100)

In [None]:
# Animate the generated frames (interval=100, presumably ms per frame).
show_gif(torch.cat(samples, dim=0), interval=100)
# NOTE(review): `size` below is whatever an earlier cell left behind; make sure
# it matches this video before re-enabling the export.
#torchvid2mp4(resize(torch.cat(samples, dim=0), out_shape=(3,(size[0]//2)*2,(size[1]//2)*2)).permute((1, 0, 2, 3)), f'/home/yanivni/data/tmp/video_generation/{video_name}/sample_294k.mp4')

In [None]:
# Ground-truth video for comparison: load every frame of `video_name`, show the
# animation and export it as an mp4.
gt_samples = []
for frame in range(1, len(os.listdir(f'./images/video/{video_name}')) + 1):
    gt_samples.append(imread(f'./images/video/{video_name}/{frame}.png').to(device=device) * 2 - 1)
gt_video = torch.cat(gt_samples, dim=0)
show_gif(gt_video, interval=100)

# Use this video's own frame size; `size` from an earlier cell may belong to a
# different video. H and W are rounded down to even values, presumably because
# the mp4 encoder requires even dimensions.
gt_h, gt_w = gt_video.shape[-2], gt_video.shape[-1]
gt_out_dir = f'/home/yanivni/data/tmp/video_generation/{video_name}'
os.makedirs(gt_out_dir, exist_ok=True)
torchvid2mp4(resize(gt_video, out_shape=(3, (gt_h // 2) * 2, (gt_w // 2) * 2)).permute((1, 0, 2, 3)), f'{gt_out_dir}/gt.mp4')

In [None]:
from html_results.html_utils import create_results_html

# Write an HTML results page (`html_filename`) for the videos under `base_dir`.
# The meaning of each tuple field in `vid_folders` is defined by create_results_html
# (html_results.html_utils, not visible here).
base_dir = '/home/yanivni/data/tmp/organized-outputs/Video Generation/ski_slope'
html_filename = 'result.html'
vid_folders = [
#     (f'scale_{s}', [f'{s}/v', f'{s}/q', f'{s}/r'], 'green', 'flex', 30, True) for s in reversed(range(VGPNN.n_stages))
    # Presumably refers to sub-folders '1'..'4' of base_dir — TODO confirm.
    (f'bla', [f'{i}' for i in range(1, 5)], 'green', 'flex', 30, True),
]
create_results_html(vid_folders, base_dir, html_filename, frame_rate=10)

In [None]:
# Sampling method 3 - Use an existing frame as frame0, generate every N'th frame from it, and
# interpolate the frames in between (conditioned on both the previous frame and the N'th frame).

video_name = 'air_balloons'
version_name = '6-all-frames-frame-diff-embedding-also-negative'
interp_version_name = '7-all-frames-conditioned-on-past-and-future-frames'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{video_name}/{version_name}/checkpoints/last.ckpt'
interp_path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{video_name}/{interp_version_name}/checkpoints/last.ckpt'
saved_data_path = f'/home/yanivni/data/tmp/organized-outputs/'

# Key-frame model: conditioned on a single previous frame (in_channels=6).
model = ConditionalDiffusion.load_from_checkpoint(path, 
                                                  model=NextNet(in_channels=6, depth=16, frame_conditioned=True),
                                                  timesteps=500, strict=False).to(device='cuda:0')
# Interpolation model: conditioned on two frames stacked along dim 1 (in_channels=9,
# presumably 3 channels each for the sample, past frame and future frame — TODO confirm).
interp_model = ConditionalDiffusion.load_from_checkpoint(interp_path, 
                                                         model=NextNet(in_channels=9, depth=16, frame_conditioned=True),
                                                         timesteps=500, strict=False).to(device='cuda:0')

total_frame_count = 100#len(os.listdir(f'./images/video/{video_name}'))
s0 = imread(f'./images/video/{video_name}/1.png').to(device=device) * 2 - 1
show_sample(s0)
samples = [s0]
N = 2  # distance between consecutive key frames produced by `model`

# NOTE(review): frame runs over 0, N, ..., total_frame_count (inclusive) and each iteration appends
# N frames, so with total_frame_count=100 and N=2 this generates 102 new frames (last printed index
# is 102). Use range(0, total_frame_count, N) if exactly total_frame_count new frames were intended.
for frame in range(0, total_frame_count + 1, N):
    # Key frame N steps after the most recent frame.
    s = model.sample(condition=samples[-1], frame=N)
    # Interpolation condition: previous frame and the new key frame, concatenated along dim 1.
    condition_interp = torch.cat((samples[-1], s), dim=1)
    for interp_frame in range(1, N):
        print(frame + interp_frame)
        # frame=(steps after previous frame, steps before key frame) — presumably; TODO confirm.
        s_interp = interp_model.sample(condition=condition_interp, frame=(interp_frame, N - interp_frame))
        show_sample(s_interp)
        samples.append(s_interp)
    # The key frame is appended after its interpolated predecessors, keeping `samples` in temporal order.
    print(frame + N)
    show_sample(s)
    samples.append(s)

In [None]:
# Animate key frames and interpolated frames in the order they were appended.
show_gif(torch.cat(samples, dim=0), interval=100)

## Experiment - Perform DG based on a visual-summary (VS) "pyramid"

In [None]:
image_name = 'balloons.png'
version_name = '11-simple-diffusion-64-crops-nextnet-depth-11-btlnk-for-vs'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=199999.ckpt'
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=11, filters_per_layer=[32, 64, 64, 128, 128, 256, 128, 128, 64, 64, 32]), timesteps=500, strict=False).to(device='cuda:0')

# NOTE(review): this image is never read -- `img` is overwritten by model.sample() inside the
# loop below before any use. Kept so the cell's globals stay the same.
img = imread(f'./images/{image_name}').to(device=device) * 2 - 1


# Grow a freshly sampled initial_size image up to final_size in n_iterations steps. At each
# step: resize, noise to timestep t, and denoise back with the model.
initial_size = (64, 64)
final_size = (186, 240)
n_iterations = 15
# Per-step (height, width) growth factor: initial_size * scale ** n_iterations == final_size.
scale = ((final_size[0] / initial_size[0]) ** (1 / n_iterations), (final_size[1] / initial_size[1]) ** (1 / n_iterations))

for t in [250]:
    print(f't={t}')
    img = model.sample(image_size=initial_size, batch_size=1)
    show_sample(img)

    for i in range(n_iterations):
        print(i)
        # Derive each target size from initial_size instead of repeatedly scaling the previous,
        # already integer-rounded, size. Otherwise rounding error compounds over the iterations
        # and the last step misses final_size. out_shape is applied to the trailing (H, W) dims.
        target_size = (round(initial_size[0] * scale[0] ** (i + 1)),
                       round(initial_size[1] * scale[1] ** (i + 1)))
        resized_img = resize(img, out_shape=target_size)
        show_sample(resized_img)
        # Report the size just produced (previously this printed the stale pre-resize `img.shape`).
        print(resized_img.shape)
        noisy_img = noise_img(resized_img, model, t)
        show_sample(noisy_img)
        img = model.sample(custom_initial_img=noisy_img, custom_timesteps=t)

## Experiment - Collage as visual summary (VS)

In [None]:
# Load the unconditional diffusion model trained on the man/woman collage (64px crops, per the version name).
image_name = 'man.jpg-woman.jpg'
version_name = '2-trained-on-64crops-for-visual-summary-collage'
path = f'/home/yanivni/data/remote_projects/single-image-diffusion/lightning_logs/{image_name}/{version_name}/checkpoints/single-level-step=299999.ckpt'
model = Diffusion.load_from_checkpoint(path, model=NextNet(depth=16, filters_per_layer=64), timesteps=500, strict=False).to(device='cuda:0')
    
def summarize_image(img, model, n_iteration=10, scale=(0.9, 0.9), t=100):
    """Progressively rescale `img`, passing it through the diffusion model after every step.

    Each iteration resizes the previous result by `scale` and noises it to timestep `t`
    with `noise_img`. It then denoises from that timestep with `model.sample`.

    Args:
        img: Starting image tensor, passed straight to `resize`. Presumably (1, C, H, W)
            in [-1, 1] — TODO confirm.
        model: Diffusion model whose `sample` accepts `custom_initial_img` and
            `custom_timesteps`.
        n_iteration: Number of resize + noise + denoise steps to run.
        scale: Per-step scale factors forwarded to `resize(scale_factors=...)`.
        t: Timestep to noise to, and to denoise from, at each step.

    Returns:
        A list of ``n_iteration + 1`` tensors: the input `img` followed by the result
        of every iteration.
    """
    resized_imgs = [img]
    # Iterate `n_iteration` times as requested. The loop used to read the global
    # `n_iterations`, silently ignoring this parameter and its default.
    for _ in range(n_iteration):
        resized_img = resize(resized_imgs[-1], scale_factors=scale)
        noisy_img = noise_img(resized_img, model, t)
        fixed_img = model.sample(custom_initial_img=noisy_img, custom_timesteps=t)
        resized_imgs.append(fixed_img)
    return resized_imgs

# Collage input, mapped to [-1, 1] (assuming imread returns values in [0, 1]).
img = imread(f'./images/collage/man_woman_combined.png').to(device=device) * 2 - 1 
# Target: same height, half the width.
final_size = (img.shape[-2], img.shape[-1] // 2)

n_iterations = 15
# Per-step (height, width) factors so that scale ** n_iterations == final_size / input size
# (the height factor is therefore 1.0).
scale = ((final_size[0] / img.shape[-2]) ** (1 / n_iterations), (final_size[1] / img.shape[-1]) ** (1 / n_iterations))
print(scale)

for t in [150, 250]:
    print(f't={t}')
    samples = summarize_image(img, model, n_iterations, scale, t)
    # `saved_data_path` is set in an earlier cell.
    # NOTE(review): `:2f` means minimum width 2 with the default 6-digit precision (e.g. 0.954993),
    # not 2 decimal places; `.2f` was probably intended. Changing it would rename the output folders/files.
    sample_directory = os.path.join(saved_data_path, 'Visual Summary', f'{image_name}/t={t}_scale={scale[0] :2f}_{scale[1] :2f}')
    os.makedirs(sample_directory, exist_ok=True)
    
    # samples[i] has been rescaled i times, i.e. by a nominal factor of scale ** i.
    for i, s in enumerate(samples):
        print(f'scale={(scale[0] ** i, scale[1] ** i)}')
        show_sample(s, 10)
        save_diffusion_sample(s, os.path.join(sample_directory, f'scale={scale[0] ** i :2f}_{scale[1] ** i :2f}.png'))
