<a href="https://colab.research.google.com/github/ykitaguchi77/GAN/blob/master/FixNoise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **FixNoise**

StyleGANのドメイン変換度合いを調整する

Implementation: http://cedro3.com/ai/fixnoise/

GitHub: https://github.com/LeeDongYeun/FixNoise (original), https://github.com/cedro3/FixNoise.git (Japanese version, used below)

Paper: https://arxiv.org/pdf/2204.14079.pdf


In [22]:
#@title **1.セットアップ**

# get code from github
! git clone https://github.com/cedro3/FixNoise.git
%cd FixNoise

# install library
! pip install legacy pyspng ninja imageio-ffmpeg==0.4.3 lpips

# download pretrained_models
! mkdir pretrained
import gdown
gdown.download('https://drive.google.com/uc?id=1YHa_g5xC_VM5MbHsr3VSfco1_PX1sRkA', 'pretrained/wikiart-fm0.05-004032.pkl', quiet=False)
gdown.download('https://drive.google.com/uc?id=1Eo4T9KjkzRYdnENXgTpqIUOvaY4-SDeD', 'pretrained/metfaces-fm0.05-001612.pkl', quiet=False)
gdown.download('https://drive.google.com/uc?id=1GzM3icWaSOSGcKfYoidjEaloqc_MyAxX', 'pretrained/aahq-fm0.05-010886.pkl', quiet=False)

# import library
from torchvision.utils import make_grid
import os
import torch
import PIL.Image
import imageio
import numpy as np
#from IPython.display import Video
from IPython.core.display import Video
#from legacy import load_network
from legacy import *

# Initial settings passed to every load_network() call in the model-selection cell:
# c_dim = 0 means no conditioning labels (the label tensor c is an empty (1, 0) tensor),
# generated images have resolution 256 and 3 channels.
c_dim, img_resolution, img_channels = 0, 256, 3

# difine function
def generate_blended_img(G_s, G_t, z=None, blend_weights=[0,0.25,0.5,0.75,1], truncation_psi=0.7, truncation_cutoff=8):
    all_images = []
    
    #1*512の乱数を作成
    if z == None:
        z = torch.randn([1,512]).cuda()
    assert z.shape == torch.Size([1, 512]) #テンソルサイズが違ったらエラーが出る
    
    c = torch.zeros(1,0).cuda() #size(1,0)の空行列

    #source, 
    img = G_s(z, c, truncation_psi, truncation_cutoff, noise_mode='const')
    all_images.append(img)

    for weight in blend_weights:
        img = G_t(z, c, truncation_psi, truncation_cutoff, noise_mode='interpolate', blend_weight=weight)
        all_images.append(img)

    all_images = torch.cat(all_images)
    images = make_grid(all_images, nrow=len(blend_weights)+1, padding=5, pad_value=0.99999)
    images = (images.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
    images = PIL.Image.fromarray(images, 'RGB')
    return images


from IPython.display import display, HTML

def display_mp4(path, width=700):
    """Embed an MP4 file inline in the notebook output as a base64 data URL.

    Args:
        path: path to the .mp4 file to show.
        width: value of the <video> element's width attribute (default 700).
    """
    from base64 import b64encode
    # Use a context manager so the file handle is closed even if reading fails.
    with open(path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display(HTML("""
    <video width=%s controls>
        <source src="%s" type="video/mp4">
    </video>
    """ % (width, data_url)))

Cloning into 'FixNoise'...
remote: Enumerating objects: 166, done.[K
remote: Counting objects:   5% (1/19)[Kremote: Counting objects:  10% (2/19)[Kremote: Counting objects:  15% (3/19)[Kremote: Counting objects:  21% (4/19)[Kremote: Counting objects:  26% (5/19)[Kremote: Counting objects:  31% (6/19)[Kremote: Counting objects:  36% (7/19)[Kremote: Counting objects:  42% (8/19)[Kremote: Counting objects:  47% (9/19)[Kremote: Counting objects:  52% (10/19)[Kremote: Counting objects:  57% (11/19)[Kremote: Counting objects:  63% (12/19)[Kremote: Counting objects:  68% (13/19)[Kremote: Counting objects:  73% (14/19)[Kremote: Counting objects:  78% (15/19)[Kremote: Counting objects:  84% (16/19)[Kremote: Counting objects:  89% (17/19)[Kremote: Counting objects:  94% (18/19)[Kremote: Counting objects: 100% (19/19)[Kremote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 166 (delta 5), reused 13 (

Downloading...
From: https://drive.google.com/uc?id=1YHa_g5xC_VM5MbHsr3VSfco1_PX1sRkA
To: /content/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/pretrained/wikiart-fm0.05-004032.pkl
100%|██████████| 357M/357M [00:00<00:00, 374MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Eo4T9KjkzRYdnENXgTpqIUOvaY4-SDeD
To: /content/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/pretrained/metfaces-fm0.05-001612.pkl
100%|██████████| 296M/296M [00:00<00:00, 441MB/s]
Downloading...
From: https://drive.google.com/uc?id=1GzM3icWaSOSGcKfYoidjEaloqc_MyAxX
To: /content/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/FixNoise/pretrained/aahq-fm0.05-010886.pkl
100%|██████████| 296M/296M [00:00<00:00, 348MB/s]


In [16]:
#@title **2.セレクト・モデル**

target_dataset = 'metfaces' #@param ['metfaces', 'aahq', 'wikiart']

# Source-domain generator pickles (loaded as G_s below).
FFHQ256_PKL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl'
CHURCH_PKL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-f.pkl'

# target_dataset -> (network config, source pickle, target pickle)
MODEL_CONFIGS = {
    'metfaces': ('paper256', FFHQ256_PKL, 'pretrained/metfaces-fm0.05-001612.pkl'),
    'aahq': ('paper256', FFHQ256_PKL, 'pretrained/aahq-fm0.05-010886.pkl'),
    'wikiart': ('stylegan2', CHURCH_PKL, 'pretrained/wikiart-fm0.05-004032.pkl'),
}
# Fail fast with a clear message instead of a NameError on cfg/source_pkl later.
if target_dataset not in MODEL_CONFIGS:
    raise ValueError(f'Unknown target_dataset {target_dataset!r}; choose one of {list(MODEL_CONFIGS)}')
cfg, source_pkl, target_pkl = MODEL_CONFIGS[target_dataset]

# load_network is defined in legacy.py (star-imported in the setup cell);
# c_dim, img_resolution and img_channels come from the setup cell.
G_s = load_network(cfg, source_pkl, img_resolution, img_channels, c_dim).cuda()
G_t = load_network(cfg, target_pkl, img_resolution, img_channels, c_dim).cuda()

Loading networks from "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl"
Loading networks from "pretrained/metfaces-fm0.05-001612.pkl"


In [None]:
#@title **3.補間画像**
# Source image followed by target images at blend weights 0, 0.25, 0.5, 0.75, 1
# (spelled out here; these are the function's default weights).
generate_blended_img(G_s, G_t, z=None, blend_weights=[0, 0.25, 0.5, 0.75, 1])

In [None]:
#@title **4.補間動画**
num_step = 201
truncation_psi = 0.7
truncation_cutoff = 8
fps = 50  # 201 frames at 50 fps is roughly 4 seconds of video

blend_weights = np.linspace(0, 1, num_step)

outdir = 'results'
os.makedirs(outdir, exist_ok=True)
# Build the output path once, so writing and displaying always use the same file.
video_path = os.path.join(outdir, f'noise_interpolation_{target_dataset}00.mp4')

z = torch.randn([1, 512]).cuda()
c = torch.zeros(1, 0).cuda()

# no_grad: inference only. The writer's context manager closes the file even if a frame fails.
with torch.no_grad(), imageio.get_writer(video_path, mode='I', fps=fps, codec='libx264', bitrate='16M') as video:
    # The source frame does not depend on the blend weight; render it once.
    img_source = G_s(z, c, truncation_psi, truncation_cutoff, noise_mode='const')
    for weight in blend_weights:
        img = G_t(z, c, truncation_psi, truncation_cutoff, noise_mode='interpolate', blend_weight=weight)
        # Source (left) and blended target (right) side by side, no padding.
        frame = make_grid(torch.cat([img_source, img]), nrow=2, padding=0)
        frame = (frame.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
        video.append_data(frame)

display_mp4(video_path)