In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Re-import edited modules automatically before each cell runs, so changes to
# the local `src.*` helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os, sys, math, random
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import kornia
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from tqdm.autonotebook import tqdm
from pytorch_pretrained_biggan import BigGAN, convert_to_images, save_as_images
from pathlib import Path
from datetime import datetime
import random, json

from src.notebook_utils import imshow, imgrid, pltshow, draw_tensors
from src.face_loss import DlibFaceLoss
from src.pytorch_utils import augment
from src.palette import random_biggan, load_directory, load_images
from src.collage import Collager
from src.collage_save import CollageSaver

In [4]:
# Side length (pixels) of the working-resolution images: the palette is resized
# to (img_size, img_size) and the same value is handed to the Collager used
# during optimization.
img_size = 512

In [5]:
# Loss term minimized by the optimization loop below.
# NOTE(review): the meaning of filter_index=1 is defined in src.face_loss —
# presumably it selects one of the detector's filters; confirm there.
face_loss = DlibFaceLoss(filter_index=1)

# Mask generator

In [6]:
from src.gan import Generator

# Checkpoint restored into the mask generator.
# https://drive.google.com/file/d/1IhoB6lxbKxL66F0X99ntL-t3-XKnxDPZ/view?usp=sharing
# (presumably the download location of this checkpoint — confirm)
model_path = './saved_models/dcgan_gen_128'

# Build the generator with img_size=128, latent_size=100, channels=1, move it
# to the GPU, then restore the saved weights.
mask_generator = Generator(img_size=128, latent_size=100, channels=1)
mask_generator = mask_generator.cuda()
mask_generator.load_state_dict(torch.load(model_path))

# Switch to eval mode; the trailing semicolon keeps the notebook from echoing
# the module repr as the cell output.
mask_generator.eval();

# Make or load the palette

In [10]:
# Source images for the palette. To use whole directories instead of the
# explicit list below, build the list like this:
# img_paths = []
# for img_dir in ('./datasets/ab_biggan/', './datasets/sci-bio-art/', './datasets/eyes_closed/'):
#     img_paths += [ os.path.join(img_dir, n) for n in os.listdir(img_dir) ]
img_paths = [
    './datasets/ab_biggan/2db513d411406270f1ee_hires.jpeg',
    './datasets/ab_biggan/2304b8bce5b78c75893e_hires.jpeg',
]

# Two copies of the palette are kept: the large one feeds the 1024px export
# collager, the downscaled one is what the optimization loop works on.
# NOTE(review): the second argument of load_images is presumably the target
# image size — confirm in src.palette.
all_palette_imgs_large = load_images(img_paths, 1024)
all_palette_imgs = F.interpolate(
    all_palette_imgs_large,
    size=(img_size, img_size),
    mode='bilinear',
    # Explicitly the documented default: identical output, and no
    # "align_corners not specified" warning on PyTorch versions that emit one.
    align_corners=False,
)
# print(len(all_palette_imgs_large))

# Optimization

In [None]:
# Endless generation loop. Each pass:
#   1. picks a random subset of the palette images,
#   2. builds a random collage from them and optimizes its parameters with
#      Adam against the face loss,
#   3. saves the palette, a video of the optimization, a 1024px render of the
#      final collage and the names of the source images.
# There is no exit condition: interrupt the kernel to stop.
n_steps = 600
lr = 2e-2
max_refs = 4  # upper bound on how many palette images go into one collage

n_palette = len(all_palette_imgs_large)
if n_palette == 0:
    raise ValueError("Palette is empty: load at least one image before optimizing.")

while True:
    # random.sample raises ValueError when asked for more items than the
    # population holds, so never request more references than the palette has.
    n_refs = random.choice(range(1, min(max_refs, n_palette) + 1))
    indices = random.sample(range(n_palette), n_refs)
    # Split a random total patch budget (8..28) evenly across the references.
    patch_per_img = random.randint(8, 28) // n_refs
    palette_imgs_large = all_palette_imgs_large[indices]
    palette_imgs = all_palette_imgs[indices]
    collager = Collager(palette_imgs, mask_generator, img_size, patch_per_img)

    # Preview the chosen palette as thumbnails.
    draw_tensors(F.interpolate(palette_imgs, size=(200, 200)))

    frames = []
    collage_data = collager.makeRandom(trans_scale=.2)
    # First collage parameter; its norm is penalized below.
    # NOTE(review): presumably the mask generator's latent codes — confirm in src.collage.
    Z = collage_data[0]

    # Every collage parameter is optimized.
    for x in collage_data:
        x.requires_grad_(True)

    opt = torch.optim.Adam(collage_data, lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=lr, total_steps=n_steps)

    loss_history = []

    # Context manager so the bar is closed after every run (and on interrupt)
    # instead of leaving one dangling bar per outer iteration.
    with tqdm(total=n_steps) as pbar:
        for i in range(n_steps):
            pbar.update()
            opt.zero_grad()
            norm_loss = .25 * Z.norm()
            img, _ = collager(*collage_data)
            aug = augment(img, n=3)
            # (aug + 1) * .5 rescales the augmented views before the face loss
            # (presumably from [-1, 1] to [0, 1] — confirm against DlibFaceLoss).
            fl = face_loss(((aug+1)*.5)).mean()
            # Subtracting the image mean slightly rewards a higher mean pixel value.
            loss = fl + norm_loss - .01*img.mean()
            loss_history.append(loss.item())
            # NOTE(review): the graph is rebuilt on every step, so retain_graph
            # looks unnecessary; kept in case Collager caches graph tensors
            # across calls — confirm before removing.
            loss.backward(retain_graph=True)
            opt.step()
            scheduler.step()
            pbar.set_description(f"fl: {fl.item():.3f}")
            frames.append(
                np.array(convert_to_images(img.detach().cpu())[0])
            )

    # Export results
    saver = CollageSaver()
    saver.save_palette(palette_imgs)
    print(saver.path)
    saver.save_video(frames)

    draw_tensors(img)

    # Re-render the optimized parameters with the large palette at 1024px.
    export_collager = Collager(palette_imgs_large, mask_generator, 1024, patch_per_img)
    with torch.no_grad():
        hires, data = export_collager(*collage_data, return_data=True)
        saver.save(hires, data, final=True)

    # Record which source images were used (JSON list, despite the .txt name).
    with open(saver.path / 'image_names.txt', 'w') as outfile:
        json.dump([img_paths[i] for i in indices], outfile)


In [None]:
# Shell escape: list the contents of ./results.
# NOTE(review): presumably where CollageSaver writes its output — confirm
# against the saver.path printed by the optimization loop.
!ls ./results