# StyleGAN3 from NVIDIA

**Notes**
* The original code from NVIDIA is available [here](https://github.com/NVlabs/stylegan3).
* We are using a pretrained model and fine-tuning on top of it.

# Setup

# Generate images

In [10]:
import torch
import pickle
from pathlib import Path

# Trained network snapshot to sample from, and the folder that receives the rendered frames.
model_path = '/Users/ted/dev/stylegan3/00022-stylegan3-r-240523_01-rose01-YFlips-6min-gpus1-batch8-gamma6.6/network-snapshot-000036.pkl'
outdir = '/Volumes/2024-May-Ted-Moore/videos/rose/stylegan/240525_02-rose01-blobs-interpolated-cos-pkl=36'

# Strict creation: raises FileExistsError if the folder is already there and
# FileNotFoundError if its parent is missing.
# NOTE(review): presumably strict on purpose so frames from an earlier run are
# never overwritten -- confirm before relaxing.
Path(outdir).mkdir(parents=False, exist_ok=False)

In [11]:
import random

# Seeds 1000..1014 in a random visiting order.
seeds = [*range(1000, 1015)]
random.shuffle(seeds)

# Repeat the first seed at the end so the sequence finishes where it started.
seeds += seeds[:1]

print(seeds)
print(len(seeds))

[1002, 1004, 1012, 1006, 1010, 1005, 1000, 1013, 1001, 1011, 1014, 1003, 1007, 1008, 1009, 1002]
16


In [12]:
import PIL.Image
from IPython.display import Image
import matplotlib.pyplot as plt
import IPython.display

"""Generate images using pretrained network pickle."""

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy

# Prefer Apple's Metal backend (the original hard-coded choice), but fall back
# to CUDA or the CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# SECURITY: pickle.load executes arbitrary code stored in the file -- only
# open network snapshots that come from a trusted source.
# NOTE(review): G is loaded a second time further down in this cell through
# legacy.load_network_pkl, so this direct load looks redundant -- confirm
# before removing it.
with open(model_path, 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module

#----------------------------------------------------------------------------

def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are written 'a-b' and include both end points. Whitespace around
    each comma separated item is ignored, for single numbers and ranges alike.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Args:
        s: The string to parse. A list is returned unchanged.

    Returns:
        The numbers in the order they appear in the string.

    Raises:
        ValueError: If an item is neither an integer nor an 'a-b' range.
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        # int() already tolerates surrounding blanks but the range pattern does
        # not, so strip first to treat '1, 2' and '1, 5-7' the same way.
        p = p.strip()
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Turn a string of the form 'a,b' into a pair of floats.

    A tuple argument is handed back untouched.
    Example:
        '0,1' returns (0.0, 1.0)

    Raises ValueError when the string does not split into exactly two
    comma separated fields.
    '''
    if isinstance(s, tuple):
        return s
    fields = s.split(',')
    if len(fields) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')
    first, second = fields
    return (float(first), float(second))

#----------------------------------------------------------------------------

def make_transform(translate: Tuple[float,float], angle: float):
    '''Build a 3x3 homogeneous 2D transform from a translation and a rotation.

    Args:
        translate: (x, y) offsets placed in the last column.
        angle: Rotation angle in degrees.

    Returns:
        A 3x3 float64 numpy array.
    '''
    sin_a = np.sin(angle/360.0*np.pi*2)
    cos_a = np.cos(angle/360.0*np.pi*2)
    return np.array(
        [
            [cos_a, sin_a, translate[0]],
            [-sin_a, cos_a, translate[1]],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )

#----------------------------------------------------------------------------

# @click.command()
# @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
# @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
# @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
# @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
# @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
# @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
# @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
# @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
# def generate_images(
#     network_pkl: str,
#     seeds: List[int],
#     truncation_psi: float,
#     noise_mode: str,
#     outdir: str,
#     translate: Tuple[float,float],
#     rotate: float,
#     class_idx: Optional[int]
# ):
"""Generate images using pretrained network pickle.
Examples:
\b
# Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
\b
# Generate uncurated images with truncation using the MetFaces-U dataset
python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
"""

# =============== setting params manually instead of with click ==================
# network_pkl = '/content/drive/MyDrive/dev/west-rock/training/00011-stylegan3-r-speed=12x-cut6-1024x1024-gpus1-batch16-gamma50/network-snapshot-000160.pkl'
network_pkl = model_path
class_idx = None        # class label index; None means unconditional
translate = (0,0)       # XY translation handed to make_transform
noise_mode = 'const'    # forwarded to the generator call
rotate = 0              # rotation angle in degrees handed to make_transform
truncation_psi = 1      # forwarded to the generator call
# ================================================================================

print('Loading networks from "%s"...' % network_pkl)

with dnnlib.util.open_url(network_pkl) as pkl_file:
    networks = legacy.load_network_pkl(pkl_file)
G = networks['G_ema'].to(device) # type: ignore

# os.makedirs(outdir, exist_ok=True)

# Labels: a one-hot row for conditional networks, an empty row otherwise.
label = torch.zeros([1, G.c_dim], device=device)
if G.c_dim == 0:
    if class_idx is not None:
        print ('warn: --class=lbl ignored when running on an unconditional network')
elif class_idx is None:
    raise click.ClickException('Must specify class label with --class when using a conditional network')
else:
    label[:, class_idx] = 1


Loading networks from "/Users/ted/dev/stylegan3/00022-stylegan3-r-240523_01-rose01-YFlips-6min-gpus1-batch8-gamma6.6/network-snapshot-000036.pkl"...


In [13]:
# Bounds for the number of frames rendered between two consecutive seeds.
# When min < max a random count is drawn for every transition; otherwise
# n_interp_min is used as a fixed count.
n_interp_min = 240 # in frames
n_interp_max = 240 # in frames
# Running index of the next frame to save; it becomes the zero-padded PNG file name.
frame_counter = 0

def lerp(v1,v2,amt):
    '''Linearly interpolate from v1 to v2.

    amt = 0 gives v1, amt = 1 gives v2, and values in between blend the two
    proportionally.
    '''
    return v1 + ((v2 - v1) * amt)

# Record the run configuration next to the rendered frames.
with open(os.path.join(outdir,'_log.txt'),'w') as f:
    f.write(f'network_pkl:  {network_pkl}\n')
    f.write(f'n_interp min: {n_interp_min}\n')
    f.write(f'n_interp max: {n_interp_max}\n')
    f.write(f'seeds:        {seeds}\n')

# seeds = seeds[:4]

# Construct an inverse rotation/translation matrix to pass to the generator.  The
# generator expects this matrix as an inverse to avoid potentially failing numerical
# operations in the network.  `translate` and `rotate` never change during the run,
# so the matrix is built once here instead of once per frame.
has_input_transform = hasattr(G.synthesis, 'input')
if has_input_transform:
    inv_transform = torch.from_numpy(np.linalg.inv(make_transform(translate, rotate)))

# Latent vector of the first seed: the starting point of the first transition.
prev_z = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim))

for i, seed in enumerate(seeds[1:]):
    # Latent vector this transition moves towards.
    curr_z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim))

    if n_interp_min < n_interp_max:
        # randint excludes its upper bound; +1 makes n_interp_max reachable.
        n_interp = np.random.randint(n_interp_min, n_interp_max + 1)
    else:
        n_interp = n_interp_min

    # Blend amounts in [0, 1): the end point is left out because it is the
    # first frame of the next transition.
    lin_interps = np.linspace(0,1,n_interp,False)
    cos_interps = (np.cos(lin_interps * np.pi) * -0.5) + 0.5

    # Average of the linear ramp and the cosine ease-in/ease-out curve.
    interps = (lin_interps + cos_interps) * 0.5

    for j, interp_amt in enumerate(interps):

        # The latents are float64 (numpy default); cast to float32 before
        # moving them to the device, as the original did.
        z = lerp(prev_z,curr_z,interp_amt).to(torch.float32).to(device)

        # Written before every generator call, exactly as before; only the
        # matrix computation was hoisted out of the loop.
        if has_input_transform:
            G.synthesis.input.transform.copy_(inv_transform)

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        # Generator output in NCHW -> NHWC uint8 pixel values.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(os.path.join(outdir,f'{frame_counter:06d}.png'))
        print(f'image {i+1} / {len(seeds)-1} --- interp {j+1} / {len(interps)} --- (frame {frame_counter})')
        frame_counter += 1

    prev_z = curr_z

    # display:
    # plt.axis('off')
    # img = img.reshape((img.shape[1],img.shape[2],img.shape[3]))
    # plt.imshow(img.cpu())
    # plt.show()

image 1 / 15 --- interp 1 / 240 --- (frame 0)
image 1 / 15 --- interp 2 / 240 --- (frame 1)
image 1 / 15 --- interp 3 / 240 --- (frame 2)
image 1 / 15 --- interp 4 / 240 --- (frame 3)
image 1 / 15 --- interp 5 / 240 --- (frame 4)
image 1 / 15 --- interp 6 / 240 --- (frame 5)
image 1 / 15 --- interp 7 / 240 --- (frame 6)
image 1 / 15 --- interp 8 / 240 --- (frame 7)
image 1 / 15 --- interp 9 / 240 --- (frame 8)
image 1 / 15 --- interp 10 / 240 --- (frame 9)
image 1 / 15 --- interp 11 / 240 --- (frame 10)
image 1 / 15 --- interp 12 / 240 --- (frame 11)
image 1 / 15 --- interp 13 / 240 --- (frame 12)
image 1 / 15 --- interp 14 / 240 --- (frame 13)
image 1 / 15 --- interp 15 / 240 --- (frame 14)
image 1 / 15 --- interp 16 / 240 --- (frame 15)
image 1 / 15 --- interp 17 / 240 --- (frame 16)
image 1 / 15 --- interp 18 / 240 --- (frame 17)
image 1 / 15 --- interp 19 / 240 --- (frame 18)
image 1 / 15 --- interp 20 / 240 --- (frame 19)
image 1 / 15 --- interp 21 / 240 --- (frame 20)
image 1 / 1

In [19]:
import os
import shlex
import subprocess

# Use the full directory name for the video.  Path.stem would cut the name at
# its last dot, which truncates folder names such as '...-gamma6.6'.
filename = Path(outdir).name

# Assemble the numbered PNG frames into an H.264 video at 30 fps, written next
# to the frame folder.  The arguments are passed as a list so that no shell is
# involved and paths containing spaces, quotes or '$' need no escaping.
command = [
    'ffmpeg',
    '-r', '30',
    '-i', os.path.join(outdir, '%06d.png'),
    '-c:v', 'libx264',
    '-pix_fmt', 'yuv420p',
    os.path.join(outdir, '..', f'{filename}.mp4'),
]
print(shlex.join(command))

# check=True raises CalledProcessError when ffmpeg exits with a non-zero
# status instead of silently returning the code.
subprocess.run(command, check=True)

ffmpeg -r 30 -i "/Volumes/2024-May-Ted-Moore/videos/rose/stylegan/240525_02-rose01-blobs-interpolated-cos-pkl=36/%06d.png" -c:v libx264 -pix_fmt yuv420p "/Volumes/2024-May-Ted-Moore/videos/rose/stylegan/240525_02-rose01-blobs-interpolated-cos-pkl=36/../240525_02-rose01-blobs-interpolated-cos-pkl=36.mp4"


ffmpeg version 7.0 Copyright (c) 2000-2024 the FFmpeg developers
  built with Apple clang version 15.0.0 (clang-1500.1.0.2.5)
  configuration: --prefix=/opt/homebrew/Cellar/ffmpeg/7.0_1 --enable-shared --enable-pthreads --enable-version3 --cc=clang --host-cflags= --host-ldflags='-Wl,-ld_classic' --enable-ffplay --enable-gnutls --enable-gpl --enable-libaom --enable-libaribb24 --enable-libbluray --enable-libdav1d --enable-libharfbuzz --enable-libjxl --enable-libmp3lame --enable-libopus --enable-librav1e --enable-librist --enable-librubberband --enable-libsnappy --enable-libsrt --enable-libssh --enable-libsvtav1 --enable-libtesseract --enable-libtheora --enable-libvidstab --enable-libvmaf --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxml2 --enable-libxvid --enable-lzma --enable-libfontconfig --enable-libfreetype --enable-frei0r --enable-libass --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libopenjpeg --enable-libspeex --

0