This notebook is designed to transform a pretrained StyleGAN2 model to work with conditional labels.

It simply initialises a generator network with the right parameters and copies over the pretrained parameters which both networks have in common.

In other words, you can take the unconditional FFHQ512 pretrained model, then create a conditional version of the same model ready for effective transfer learning with conditional labels.

In [None]:
%load_ext autoreload

%autoreload 2


import pandas as pd
from pathlib import Path

import torch
import torchvision
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm

import numpy as np
from google.colab import drive
import json
import copy
import pickle

# Mount Google Drive at /content/drive so files under it are reachable from this Colab session.
drive.mount("/content/drive")
# Trailing expression of the cell: the notebook echoes the number of CUDA devices torch can see.
torch.cuda.device_count()

In [None]:
def transfer_weights(model, model_pretrained):
    """Build a copy of ``model`` seeded with matching weights from ``model_pretrained``.

    Every entry of ``model.state_dict()`` is sorted into one of three buckets:

    * ``"common"`` -- the key also exists in the pretrained state dict with an
      identical shape; the pretrained tensor is used.
    * ``"common_diff_shape"`` -- the key exists in both but the shapes differ;
      the tensor from ``model`` is kept.
    * ``"other"`` -- the key is absent from the pretrained state dict; the
      tensor from ``model`` is kept.

    Keys that only exist in ``model_pretrained`` are ignored. Neither input
    module is modified: the result is a deep copy of ``model`` that has had
    the merged state dict loaded into it.

    Args:
        model: Module whose architecture (and state-dict keys) the result has.
        model_pretrained: Module supplying the pretrained tensors.

    Returns:
        A ``(new_model, info)`` tuple, where ``info`` maps each bucket name
        listed above to the list of state-dict keys that fell into it, in
        ``model.state_dict()`` order.
    """
    target_params = model.state_dict()
    donor_params = model_pretrained.state_dict()

    report = {"common": [], "common_diff_shape": [], "other": []}
    merged = {}

    for name, tensor in target_params.items():
        if name not in donor_params:
            bucket = "other"
        elif donor_params[name].shape != tensor.shape:
            bucket = "common_diff_shape"
        else:
            bucket = "common"
            # Only identically shaped tensors are taken from the pretrained model.
            tensor = donor_params[name]
        merged[name] = tensor
        report[bucket].append(name)

    # Load into a copy so the caller's ``model`` keeps its original weights.
    result = copy.deepcopy(model)
    result.load_state_dict(merged)
    return result, report

In [None]:
# Work from /content and fetch the stylegan2-ada-pytorch fork that the later cells cd into.
%cd /content/
!git clone https://github.com/snakch/stylegan2-ada-pytorch.git

In [None]:
# Download the pretrained FFHQ 512 StyleGAN2 pickle (wget saves it into the current
# working directory) and record the path it is later loaded from.
# NOTE(review): the path below assumes the working directory is still /content — confirm
# if the cells are run out of order.

!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl
original_model_path = "/content/ffhq-res512-mirror-stylegan2-noaug.pkl"

In [None]:
# Enter the cloned repository; presumably this is what makes `dnnlib` and `legacy`
# importable below — verify if the imports fail.
%cd stylegan2-ada-pytorch
import dnnlib
import legacy

# Load the pretrained model
with dnnlib.util.open_url(original_model_path) as f:
    network = legacy.load_network_pkl(f)
# Keep handles on the three pretrained networks stored in the loaded pickle.
G_pretrained = network["G"]
D_pretrained = network["D"]
G_ema_pretrained = network["G_ema"]

In [None]:
# Set the dimensionality of the labels and the resolution of the network.
# You may have to modify some of the other parameters if dealing with a 1024 network, for instance.
C_DIM = 7
RESOLUTION = 512

# Where the converted model is written.
# Change this to a folder in your drive folder if you want the model to persist
OUTPUT_PATH = "/content/cond_model.pkl"

In [None]:
# Settings passed to both networks: label dimensionality and image format.
common_kwargs = dict(c_dim=C_DIM, img_resolution=RESOLUTION, img_channels=3)

# Generator configuration.
# NOTE(review): these values presumably mirror the pretrained 512 network so that
# tensor shapes line up for the weight transfer — confirm before changing them.
G_kwargs = dict(
    class_name="training.networks.Generator",
    z_dim=512,
    w_dim=512,
    mapping_kwargs=dict(num_layers=8),
    synthesis_kwargs=dict(
        channel_base=32768,
        channel_max=512,
        num_fp16_res=4,
        conv_clamp=256,
    ),
)

# Discriminator configuration (same caveat as for the generator).
D_kwargs = dict(
    class_name="training.networks.Discriminator",
    block_kwargs=dict(),
    mapping_kwargs=dict(),
    epilogue_kwargs=dict(mbstd_group_size=8),
    channel_base=32768,
    channel_max=512,
    num_fp16_res=4,
    conv_clamp=256,
)

# Instantiate the label-conditioned networks; G_ema starts as an independent copy of G.
G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs)
D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs)
G_ema = copy.deepcopy(G)

In [None]:
# Transfer weights for each (fresh conditional network, pretrained counterpart) pair,
# in the order D, G, G_ema, printing how many state-dict entries landed in each bucket.
converted = []
for fresh_net, pretrained_net in (
    (D, D_pretrained),
    (G, G_pretrained),
    (G_ema, G_ema_pretrained),
):
    converted_net, info = transfer_weights(fresh_net, pretrained_net)
    print(
        f'{len(info["common"])} common modules, {len(info["common_diff_shape"])} differently shaped modules {len(info["other"])} other modules'
    )
    converted.append(converted_net)
new_D, new_G, new_G_ema = converted


# Assemble the new checkpoint: reuse the pretrained pickle's training-set settings and
# augmentation pipe, and swap in the converted networks.
new_network = {
    "training_set_kwargs": network["training_set_kwargs"],
    "augment_pipe": network["augment_pipe"],
    "G": new_G,
    "D": new_D,
    "G_ema": new_G_ema,
}

In [None]:
# Serialise the assembled checkpoint dict to OUTPUT_PATH as a binary pickle.
with Path(OUTPUT_PATH).open("wb") as checkpoint_file:
    pickle.dump(new_network, checkpoint_file)