In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.insert(0, "../nohomers")

In [2]:
from tqdm.auto import tqdm
import torch
import json
from pathlib import Path
from lightweight_gan import Trainer
from lightweight_gan.lightweight_gan import slerp
from uuid import uuid4
from PIL import Image
import torchvision
import tempfile
from torchvision import transforms
import numpy as np
import ffmpeg
import shutil
import copy
import random

In [3]:
def get_trainer(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=256,
    optimizer='adam',
    fmap_max=512,
    transparent=False,
    batch_size=10,
    gradient_accumulate_every=4,
    num_train_steps=150000,
    learning_rate=2e-4,
    save_every=1000,
    evaluate_every=1000,
    generate=False,
    generate_interpolation=False,
    attn_res_layers=[32],
    sle_spatial=False,
    disc_output_size=1,
    antialias=False,
    interpolation_num_steps=100,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    multi_gpus=False,
    calculate_fid_every=None,
    seed=42,
    amp=False
):
    """Construct a ``Trainer`` from keyword options and call ``load`` on it.

    Only the options copied into ``model_args`` below reach ``Trainer``.
    The parameters ``data``, ``new``, ``num_train_steps``, ``generate``,
    ``generate_interpolation``, ``interpolation_num_steps``, ``save_frames``,
    ``num_image_tiles``, ``multi_gpus`` and ``seed`` are accepted but ignored
    by this function.

    Args:
        load_from: value passed straight to ``Trainer.load``
            (presumably a checkpoint number, with -1 meaning "latest" --
            TODO confirm against lightweight_gan).
        learning_rate: forwarded to ``Trainer`` under the name ``lr``.
        attn_res_layers, aug_types: a single value, a list or a tuple; each
            is normalised to a fresh list before being forwarded.
        (all other forwarded options are passed through unchanged)

    Returns:
        The ``Trainer`` instance after ``load(load_from)`` has been called.
    """
    def as_fresh_list(el):
        # Always hand Trainer a NEW list: the defaults in the signature are
        # shared between calls, so passing them by reference would let any
        # in-place modification leak into later get_trainer() calls.
        if isinstance(el, (list, tuple)):
            return list(el)
        return [el]

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        attn_res_layers=as_fresh_list(attn_res_layers),
        sle_spatial=sle_spatial,
        disc_output_size=disc_output_size,
        antialias=antialias,
        image_size=image_size,
        optimizer=optimizer,
        fmap_max=fmap_max,
        transparent=transparent,
        lr=learning_rate,
        save_every=save_every,
        evaluate_every=evaluate_every,
        trunc_psi=trunc_psi,
        aug_prob=aug_prob,
        aug_types=as_fresh_list(aug_types),
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
        amp=amp
    )

    ret = Trainer(**model_args)
    ret.load(load_from)
    return ret

In [4]:
@torch.no_grad()
def generate_image_with_latents(self, num=1):
    """Sample ``num`` random latent vectors and render one image for each.

    This is a free function that takes the trainer as its first argument.

    Args:
        self: trainer-like object exposing ``GAN`` (with ``latent_dim`` and
            ``GE``), ``rank`` and ``generate_truncated``.
        num: number of latent/image pairs to produce.

    Returns:
        A list of ``(latent, image)`` tuples. ``latent`` is row ``i`` of the
        sampled ``(num, latent_dim)`` batch (it stays on the CUDA device);
        ``image`` is the result of ``transforms.ToPILImage()`` applied to the
        matching generated image after moving it to the CPU.
    """
    self.GAN.eval()
    latents = torch.randn((num, self.GAN.latent_dim)).cuda(self.rank)
    generated_images = self.generate_truncated(self.GAN.GE, latents)

    # Build the converter once instead of once per image.
    to_pil = transforms.ToPILImage()
    return [
        (latents[i, :], to_pil(generated_images[i, :, :, :].cpu()))
        for i in range(num)
    ]

@torch.no_grad()
def generate_interpolation_frames(self, latents_low, latents_high, num_frames):
    """Render the frames of an interpolation between two latent batches.

    Works on a whole batch at once: row ``k`` of ``latents_low`` is
    interpolated towards row ``k`` of ``latents_high``.

    Args:
        self: trainer-like object exposing ``GAN`` (with ``GE``) and
            ``generate_truncated``.
        latents_low: start latents; ``latents_low.shape[0]`` is taken as the
            batch size.
        latents_high: end latents, paired row-by-row with ``latents_low``.
        num_frames: number of frames per sequence (``torch.linspace`` from
            0 to 1 with this many steps).

    Returns:
        A list with one entry per batch row; each entry is the list of
        ``num_frames`` images (as produced by ``transforms.ToPILImage()``)
        for that row, in order from ``latents_low`` to ``latents_high``.
    """
    self.GAN.eval()

    ratios = torch.linspace(0., 1., num_frames)
    last_frame = len(ratios) - 1
    batch_size = latents_low.shape[0]

    # Build the converter once instead of once per frame per sample.
    to_pil = transforms.ToPILImage()
    ret = [[] for _ in range(batch_size)]

    for frame_idx, ratio in enumerate(ratios):
        # The first and last frames use the given latents verbatim rather
        # than going through slerp at ratio 0 / 1.
        if frame_idx == 0:
            interp_latents = latents_low
        elif frame_idx == last_frame:
            interp_latents = latents_high
        else:
            interp_latents = slerp(ratio, latents_low, latents_high)
        generated_images = self.generate_truncated(self.GAN.GE, interp_latents)

        # Distinct index name: the original reused `i` here, shadowing the
        # frame index of the enclosing loop.
        for sample_idx in range(batch_size):
            ret[sample_idx].append(to_pil(generated_images[sample_idx, :, :, :].cpu()))

    return ret

def frames_to_video(frames, output_path, fps=30, bitrate="1M"):
    """Encode a sequence of images into a video file with ffmpeg.

    Each frame is written as a zero-padded, numbered JPEG into a temporary
    directory, and ffmpeg then reads them back through a glob pattern. The
    temporary directory is removed when the function returns.

    Args:
        frames: iterable of image objects exposing ``save(path)``.
        output_path: destination of the encoded video; an existing file is
            overwritten.
        fps: frame rate passed to ffmpeg for the input image sequence.
        bitrate: video bitrate passed to ffmpeg for the output.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch = Path(scratch_dir)
        for index, image in enumerate(frames):
            image.save(scratch / f"{index:06d}.jpg")

        # Same pipeline as a single fluent chain, spelled out step by step.
        stream = ffmpeg.input(f'{scratch_dir}/*.jpg', pattern_type='glob', framerate=fps)
        stream = stream.output(filename=output_path, video_bitrate=bitrate)
        stream = stream.overwrite_output()
        stream.run()


def gen_images_and_manifest(trainer, output_base_dir, num=10):
    """Generate ``num`` images, save them, and return a manifest describing them.

    Images are written to ``<output_base_dir>/images/<uuid4>.jpg``.

    Args:
        trainer: trainer object passed to ``generate_image_with_latents``.
        output_base_dir: base output directory; it and the ``images``
            sub-directory are created if they do not exist yet.
        num: number of images to generate.

    Returns:
        A list of dicts, one per image, in generation order, with keys
        ``"image_name"`` (file name inside the ``images`` directory) and
        ``"latent"`` (the latent vector as a plain list of Python floats, so
        the manifest can be serialised to JSON).
    """
    image_output_dir = Path(output_base_dir) / "images"
    # parents=True so a not-yet-existing base directory does not raise.
    image_output_dir.mkdir(parents=True, exist_ok=True)

    image_objects = []
    for latent, image in generate_image_with_latents(trainer, num=num):
        name = f"{uuid4()}.jpg"
        image.save(str(image_output_dir / name))
        image_objects.append({
            "image_name": name,
            "latent": latent.cpu().tolist(),
        })

    return image_objects


def gen_interpolation_videos(trainer, manifest, output_base_dir, per_edge=1, video_duration=3.0, video_fps=30):
    """Render interpolation videos between manifest entries.

    For every entry in ``manifest``, ``per_edge`` distinct other entries are
    picked at random and one video interpolating from the entry's latent to
    each picked entry's latent is written to
    ``<output_base_dir>/videos/<src image_name>_to_<dst image_name>.mp4``.

    Args:
        trainer: trainer object passed to ``generate_interpolation_frames``;
            its ``rank`` selects the CUDA device for the latents.
        manifest: list of dicts with ``"image_name"`` and ``"latent"`` keys.
            It is not modified.
        output_base_dir: base output directory; it and the ``videos``
            sub-directory are created if they do not exist yet.
        per_edge: number of outgoing transitions per entry; must be smaller
            than ``len(manifest)``.
        video_duration: length of each video in seconds.
        video_fps: frame rate of each video.

    Returns:
        A deep copy of ``manifest`` in which every entry has an additional
        ``"transitions"`` list of dicts with keys ``"dest_index"``,
        ``"dest_name"`` and ``"video_name"``.
    """
    assert len(manifest) > per_edge, "per_edge must be smaller than the number of manifest entries"

    num_frames = int(video_fps * video_duration)

    videos_path = Path(output_base_dir) / "videos"
    # parents=True so a not-yet-existing base directory does not raise.
    videos_path.mkdir(parents=True, exist_ok=True)

    def latent_batch(entry):
        # (1, latent_dim) batch on the trainer's device, matching the
        # `.cuda(self.rank)` used when the latents were first sampled.
        return torch.tensor(entry["latent"]).unsqueeze(0).cuda(trainer.rank)

    out_manifest = copy.deepcopy(manifest)
    for src_i, src in enumerate(out_manifest):
        # random.sample draws distinct destinations directly. Unlike a
        # rejection loop it cannot spin forever if the assert above is
        # stripped (python -O): it raises ValueError instead. A negative
        # per_edge is treated as "no transitions".
        candidates = [i for i in range(len(out_manifest)) if i != src_i]
        # Sorted so the transitions are listed in ascending index order.
        dest_indices = sorted(random.sample(candidates, max(per_edge, 0)))

        transition_items = []

        src_latent = latent_batch(src)
        for dst_i in dest_indices:
            dst = out_manifest[dst_i]
            dst_latent = latent_batch(dst)

            video_name = f"{src['image_name']}_to_{dst['image_name']}.mp4"

            # The frame generator works in batches; here the batch has a
            # single row, so only element 0 of the result is used.
            video_frames = generate_interpolation_frames(
                trainer,
                latents_low=src_latent,
                latents_high=dst_latent,
                num_frames=num_frames,
            )

            frames_to_video(video_frames[0], output_path=videos_path / video_name, fps=video_fps, bitrate="1M")

            transition_items.append({
                "dest_index": dst_i,
                "dest_name": dst["image_name"],
                "video_name": video_name,
            })

        src["transitions"] = transition_items
    return out_manifest

In [9]:
# Build the trainer for the named model and load checkpoint 24 from the
# models directory below.
trainer = get_trainer(
    name="simpsons_bart_homer_new_cleaned_1024",
    models_dir="/mnt/evo/projects/metapedia/tmp/stylegan2/models",
    load_from=24,
)

loading from version 0.12.4


In [10]:
output_dir = Path("/mnt/evo/projects/nohomers/tmp")
# Start from an empty output directory. Only remove it when it is actually
# there: rmtree on a missing path raises FileNotFoundError, which made this
# cell fail on the very first run.
if output_dir.exists():
    shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

manifest = gen_images_and_manifest(trainer, output_dir)

In [76]:
# Two outgoing interpolation videos per generated image.
video_manifest = gen_interpolation_videos(
    trainer=trainer,
    manifest=manifest,
    output_base_dir=output_dir,
    per_edge=2,
)

In [77]:
# Persist the manifest (images plus their transitions) next to the outputs.
manifest_path = Path(output_dir).joinpath("manifest.json")
with manifest_path.open("w") as f:
    f.write(json.dumps(video_manifest))

In [24]:
# Interpolate between the first two generated images.
# The endpoint latents are rebuilt from `manifest`: the original cell read
# `image_and_latents`, which only exists as a local variable inside
# gen_images_and_manifest and is therefore undefined on a fresh kernel.
test = generate_interpolation_frames(
    trainer,
    latents_low=torch.tensor(manifest[0]["latent"]).unsqueeze(0).cuda(trainer.rank),
    latents_high=torch.tensor(manifest[1]["latent"]).unsqueeze(0).cuda(trainer.rank),
    num_frames=120,
)[0]

In [169]:
# Encode the test interpolation at a lower bitrate than the default.
frames_to_video(
    frames=test,
    output_path=output_dir / "test.mp4",
    fps=30,
    bitrate="0.5M",
)

In [160]:
# Sanity check: the test video was actually written.
print(output_dir.joinpath("test.mp4").exists())

True
