<a href="https://colab.research.google.com/github/tiffanymoran/AlbumGAN/blob/main/StyleGAN3_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StyleGAN3 Training

**Notes**
The code in this Colab notebook is heavily modeled on Shyam BV's [Colab notebook](https://colab.research.google.com/drive/1Nal3M-wjv6BeIgyTgvxhaccPFGEP6cbk?usp=sharing) and adapted to train on a custom album art dataset.

# Setup

In [None]:
!pip install einops ninja gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
     |████████████████████████████████| 108 kB 6.0 MB/s 
Installing collected packages: ninja, einops
Successfully installed einops-0.4.1 ninja-1.10.2.3


In [None]:
# Uninstall the newer JAX preinstalled on Colab
!pip uninstall jax jaxlib -y
# GPU build of JAX 0.3.10 (CUDA 11, cuDNN 8.0.5)
!pip install "jax[cuda11_cudnn805]==0.3.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CPU-only alternative
#!pip install "jax[cpu]==0.3.10"
# Downgrade PyTorch to 1.10.1 (CUDA 11.1 build)
!pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Found existing installation: jax 0.3.14
Uninstalling jax-0.3.14:
  Successfully uninstalled jax-0.3.14
Found existing installation: jaxlib 0.3.14+cuda11.cudnn805
Uninstalling jaxlib-0.3.14+cuda11.cudnn805:
  Successfully uninstalled jaxlib-0.3.14+cuda11.cudnn805
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_cudnn805]==0.3.10
  Downloading jax-0.3.10.tar.gz (939 kB)
     |████████████████████████████████| 939 kB 3.9 MB/s 
Collecting jaxlib==0.3.10+cuda11.cudnn805
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.10%2Bcuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl (175.7 MB)
     |████████████████████████████████| 175.7 MB 4.9 kB/s 
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.3.10-py3-none-any.whl si

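After restarting the runtime (usually needed after downgrading preinstalled packages), a quick check can confirm that the pinned versions from the install cell are the ones Python actually sees. This is a minimal standard-library sketch; `check_pins` and `PINS` are hypothetical names, and the exact version strings (such as the `+cu111` suffix) are assumed to match what pip reports in the log above.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:
    # Python 3.7 (the cp37 wheels above) needs the importlib_metadata backport.
    import importlib_metadata as metadata


def check_pins(expected):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None  # not installed at all
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


# Versions pinned in the install cell (assumed spelling; adjust if you change the pins).
PINS = {
    "jax": "0.3.10",
    "jaxlib": "0.3.10+cuda11.cudnn805",
    "torch": "1.10.1+cu111",
    "torchvision": "0.11.2+cu111",
}

if __name__ == "__main__":
    problems = check_pins(PINS)
    for pkg, (want, have) in problems.items():
        print(f"{pkg}: expected {want}, found {have}")
    if not problems:
        print("All pinned versions are installed.")
```

An empty result means every pin matched; any entry with `None` means that package is missing entirely.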
In [None]:
# Connect Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
# Reuse the StyleGAN3 fork already on Drive if present; otherwise clone it there first.
if os.path.isdir('/content/drive/MyDrive/WIP/stylegan3/'):
    %cd '/content/drive/MyDrive/WIP/stylegan3/'
else:
    !git clone https://github.com/bvshyam/stylegan3.git /content/drive/MyDrive/WIP/stylegan3/
    %cd '/content/drive/MyDrive/WIP/stylegan3/'


# Custom Dataset

Run `dataset_tool.py` with the source folder of images and the destination archive to convert the dataset into the format StyleGAN3 expects. Also pass the target resolution with `--resolution`.


In [None]:
# Extract the album cover images to /content/images
zip_path = "/content/drive/MyDrive/images.zip"

!unzip {zip_path} -d /content/

Archive:  /content/drive/MyDrive/images.zip
   creating: /content/images/
  inflating: /content/__MACOSX/._images  
  inflating: /content/images/I like it when you sleep.jpeg  
  inflating: /content/__MACOSX/images/._I like it when you sleep.jpeg  
  inflating: /content/images/Mellon Collie and the Infinite Sadness.jpeg  
  inflating: /content/__MACOSX/images/._Mellon Collie and the Infinite Sadness.jpeg  
  inflating: /content/images/Alvvays.jpeg  
  inflating: /content/__MACOSX/images/._Alvvays.jpeg  
  inflating: /content/images/The Downward Spiral.jpeg  
  inflating: /content/__MACOSX/images/._The Downward Spiral.jpeg  
  inflating: /content/images/Aquatic Flowers.jpeg  
  inflating: /content/__MACOSX/images/._Aquatic Flowers.jpeg  
  inflating: /content/images/Smile - Brian Wilson.jpeg  
  inflating: /content/__MACOSX/images/._Smile - Brian Wilson.jpeg  
  inflating: /content/images/Great Big Wild Oak.jpeg  
  inflating: /content/__MACOSX/images/._Great Big Wild Oak.jpeg  
  infla

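The archive was zipped on macOS, so `unzip` also extracts `__MACOSX/` resource-fork files (the `._*` entries above). They land outside `/content/images`, so `--source=/content/images` never sees them, but they can be skipped entirely. The sketch below uses Python's `zipfile` module; `extract_images` and `is_macos_metadata` are hypothetical helper names, and the list of metadata patterns is an assumption covering the common Finder artifacts.

```python
import zipfile
from pathlib import Path, PurePosixPath


def is_macos_metadata(member_name):
    """True for Finder artifacts macOS adds when zipping (assumed patterns)."""
    parts = PurePosixPath(member_name).parts
    return (
        "__MACOSX" in parts
        or parts[-1].startswith("._")
        or parts[-1] == ".DS_Store"
    )


def extract_images(zip_path, dest):
    """Extract zip_path into dest, skipping directories and macOS metadata.

    Returns the sorted list of extracted file paths.
    """
    dest = Path(dest)
    extracted = []
    with zipfile.ZipFile(zip_path) as zf:
        for info in zf.infolist():
            if info.is_dir() or is_macos_metadata(info.filename):
                continue
            # extract() creates parent folders and returns the written path.
            extracted.append(Path(zf.extract(info, dest)))
    return sorted(extracted)
```

For example, `extract_images("/content/drive/MyDrive/images.zip", "/content")` would produce the same `/content/images` folder without the `__MACOSX` copies.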
In [None]:
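# Use the absolute script path so it is found even when the working directory is not the repo (a relative call fails as shown below).
!python /content/drive/MyDrive/WIP/stylegan3/dataset_tool.py --source=/content/images --dest=/content/drive/MyDrive/WIP/stylegan3/datasets/images.zip --resolution='512x512'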

python3: can't open file 'dataset_tool.py': [Errno 2] No such file or directory

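The error above happens when the working directory is not the repository root, so `dataset_tool.py` cannot be found by its relative path. A small guard can build the command with an absolute path and reject unusable resolutions before the tool starts scanning images. This is a sketch under two assumptions: the repo lives at the Drive path used earlier, and (as far as I can tell) StyleGAN2/3's `dataset_tool.py` requires square, power-of-two output sizes. `parse_resolution` and `dataset_tool_cmd` are hypothetical names.

```python
import sys
from pathlib import Path


def parse_resolution(spec):
    """Parse 'WxH' and check it is square and a power of two (assumed requirement)."""
    width, height = (int(v) for v in spec.lower().split("x"))
    if width != height:
        raise ValueError(f"resolution must be square, got {spec!r}")
    if width <= 0 or width & (width - 1):
        raise ValueError(f"resolution must be a power of two, got {spec!r}")
    return width, height


def dataset_tool_cmd(repo_dir, source, dest, resolution):
    """Build argv for dataset_tool.py, using an absolute path to the script."""
    script = Path(repo_dir) / "dataset_tool.py"
    if not script.is_file():
        raise FileNotFoundError(f"{script} not found; clone the repo first")
    parse_resolution(resolution)  # fail fast on an unusable size
    return [
        sys.executable, str(script),
        f"--source={source}",
        f"--dest={dest}",
        f"--resolution={resolution}",
    ]


# Example with this notebook's paths:
# import subprocess
# cmd = dataset_tool_cmd("/content/drive/MyDrive/WIP/stylegan3", "/content/images",
#                        "/content/drive/MyDrive/WIP/stylegan3/datasets/images.zip", "512x512")
# subprocess.run(cmd, check=True)
```

Passing an argv list to `subprocess.run` avoids the shell quoting around `'512x512'` in the `!python` line.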

# Model training

You can start from a pre-trained model. Below are some of the pre-trained models released by NVIDIA. I originally started from stylegan3-r-ffhqu-256x256.pkl and then resumed from the most recent pkl file.



```
stylegan3-t-ffhq-1024x1024.pkl, stylegan3-t-ffhqu-1024x1024.pkl, stylegan3-t-ffhqu-256x256.pkl
stylegan3-r-ffhq-1024x1024.pkl, stylegan3-r-ffhqu-1024x1024.pkl, stylegan3-r-ffhqu-256x256.pkl
stylegan3-t-metfaces-1024x1024.pkl, stylegan3-t-metfacesu-1024x1024.pkl
stylegan3-r-metfaces-1024x1024.pkl, stylegan3-r-metfacesu-1024x1024.pkl
stylegan3-t-afhqv2-512x512.pkl
stylegan3-r-afhqv2-512x512.pkl
```
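The file names above follow a `stylegan3-<config>-<dataset>-<N>x<N>.pkl` pattern, so a small parser can pull out the config (`t` or `r`) and resolution, which helps when matching a starting model to the dataset's resolution. The download prefix below is the NVIDIA NGC location given in the official StyleGAN3 README; treat it as an assumption and check the README if a download fails. `describe_model` is a hypothetical helper.

```python
import re

# Download prefix for NVIDIA's pre-trained StyleGAN3 pickles, as listed in the
# official StyleGAN3 README (assumption: the location has not moved).
NGC_BASE = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"

# stylegan3-<t|r>-<dataset>-<N>x<N>.pkl ; the backreference enforces a square size.
_NAME_RE = re.compile(r"^stylegan3-(?P<cfg>[tr])-(?P<data>[a-z0-9]+)-(?P<res>\d+)x(?P=res)\.pkl$")


def describe_model(filename):
    """Split a pre-trained model file name into its parts plus a download URL."""
    m = _NAME_RE.match(filename)
    if m is None:
        raise ValueError(f"unrecognised model name: {filename!r}")
    return {
        # t: translation equivariant; r: translation and rotation equivariant
        "config": "stylegan3-" + m.group("cfg"),
        "dataset": m.group("data"),
        "resolution": int(m.group("res")),
        "url": NGC_BASE + filename,
    }


if __name__ == "__main__":
    info = describe_model("stylegan3-r-ffhqu-256x256.pkl")
    print(info["config"], info["dataset"], info["resolution"])
    print(info["url"])
```

The README's fine-tuning example passes such a URL straight to `--resume`; downloading the file to Drive first with `wget` is the alternative if you want a local copy.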



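Resuming from the most recent pkl file means finding the newest snapshot that `train.py` wrote. Because `--outdir=~/training-runs` points at the Colab VM's local disk (`~` is `/root`), those snapshots disappear when the runtime is reset, so it helps to copy the chosen one to Drive, where `final-network-snapshot-000080.pkl` lives in the command below. This sketch assumes snapshot files contain `network-snapshot-` in their names and picks the newest by modification time; `latest_snapshot` and `keep_on_drive` are hypothetical helpers.

```python
import shutil
from pathlib import Path


def latest_snapshot(root):
    """Newest *network-snapshot-*.pkl under root (searched recursively), or None."""
    root = Path(root).expanduser()
    if not root.is_dir():
        return None
    snapshots = list(root.rglob("*network-snapshot-*.pkl"))
    if not snapshots:
        return None
    # Use modification time rather than the kimg in the name, since a resumed
    # run may restart its numbering in a new run directory.
    return max(snapshots, key=lambda p: p.stat().st_mtime)


def keep_on_drive(snapshot, drive_dir):
    """Copy a snapshot to a Drive folder so it survives a runtime reset."""
    dest_dir = Path(drive_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / Path(snapshot).name
    shutil.copy2(snapshot, dest)
    return dest
```

For example, `keep_on_drive(latest_snapshot("~/training-runs"), "/content/drive/MyDrive/WIP/stylegan3")` copies the newest local snapshot next to the repo, ready for `--resume`.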
## Using a pre-trained model
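The fine-tuning cell in this section sets several schedule-related flags, and a little arithmetic shows what they imply. With `--batch=32` and `--batch-gpu=16` on one GPU, each optimizer step accumulates gradients over 2 rounds of 16 images; `--kimg=500` means 500,000 real images shown, about 15,625 steps at batch 32; and `--snap=2` saves a snapshot every 2 ticks. The tick length of 4 kimg is an assumption based on `train.py`'s default `--tick`, which this notebook does not override. `schedule` is a hypothetical helper.

```python
def schedule(kimg, batch, batch_gpu, gpus=1, tick_kimg=4, snap_ticks=2):
    """Back-of-the-envelope numbers for a train.py run (approximate).

    tick_kimg=4 assumes train.py's default --tick value.
    """
    per_round = batch_gpu * gpus
    if batch % per_round:
        # Each optimizer step must split into whole rounds of batch_gpu images per GPU.
        raise ValueError("--batch must be a multiple of --gpus * --batch-gpu")
    snapshot_every_kimg = tick_kimg * snap_ticks
    return {
        # Gradient-accumulation rounds per optimizer step.
        "accum_rounds": batch // per_round,
        # Optimizer steps needed to show kimg * 1000 real images.
        "iterations": kimg * 1000 // batch,
        "snapshot_every_kimg": snapshot_every_kimg,
        # Approximate: ticks end when the image count crosses a threshold,
        # and train.py also writes a snapshot when training finishes.
        "snapshots": kimg // snapshot_every_kimg,
    }


if __name__ == "__main__":
    # Values from the fine-tuning cell in this section.
    print(schedule(kimg=500, batch=32, batch_gpu=16, gpus=1))
```

With these settings that works out to roughly 62 periodic snapshots (one every 8 kimg) plus the final one, which is worth keeping in mind given the space each pkl takes on Drive.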

In [None]:

# Fine-tune on the custom album cover dataset using 1 GPU, resuming from a previous pkl file.
# Note: --cfg=stylegan2 selects the StyleGAN2 architecture (networks_stylegan2 in the options below), not StyleGAN3-R.
!python /content/drive/MyDrive/WIP/stylegan3/train.py --outdir=~/training-runs --cfg=stylegan2 --data=/content/drive/MyDrive/WIP/stylegan3/datasets/albumcovers.zip \
    --gpus=1 --batch=32 --batch-gpu=16 --gamma=6.6 --mirror=1 --kimg=500 --metrics=none --snap=2 \
    --resume=/content/drive/MyDrive/WIP/stylegan3/final-network-snapshot-000080.pkl



Training options:
{
  "G_kwargs": {
    "class_name": "training.networks_stylegan2.Generator",
    "z_dim": 512,
    "w_dim": 512,
    "mapping_kwargs": {
      "num_layers": 8
    },
    "channel_base": 32768,
    "channel_max": 512,
    "fused_modconv_default": "inference_only"
  },
  "D_kwargs": {
    "class_name": "training.networks_stylegan2.Discriminator",
    "block_kwargs": {
      "freeze_layers": 0
    },
    "mapping_kwargs": {},
    "epilogue_kwargs": {
      "mbstd_group_size": 4
    },
    "channel_base": 32768,
    "channel_max": 512
  },
  "G_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.002
  },
  "D_opt_kwargs": {
    "class_name": "torch.optim.Adam",
    "betas": [
      0,
      0.99
    ],
    "eps": 1e-08,
    "lr": 0.002
  },
  "loss_kwargs": {
    "class_name": "training.loss.StyleGAN2Loss",
    "r1_gamma": 6.6,
    "style_mixing_prob": 0.9,
    "pl_weight": 2,
    "pl_no_weight_grad