In [1]:
!git clone https://github.com/NVlabs/stylegan3.git

Cloning into 'stylegan3'...
remote: Enumerating objects: 207, done.[K
remote: Total 207 (delta 0), reused 0 (delta 0), pack-reused 207[K
Receiving objects: 100% (207/207), 4.16 MiB | 6.53 MiB/s, done.
Resolving deltas: 100% (101/101), done.


In [2]:
!git clone https://github.com/tspyrosk/sefa.git
!cp -r ./stylegan3/torch_utils ./sefa
!cp -r ./stylegan3/dnnlib ./sefa

Cloning into 'sefa'...
remote: Enumerating objects: 132, done.[K
remote: Counting objects: 100% (41/41), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 132 (delta 33), reused 23 (delta 23), pack-reused 91[K
Receiving objects: 100% (132/132), 47.51 MiB | 23.24 MiB/s, done.
Resolving deltas: 100% (54/54), done.


In [3]:
import sys
sys.path.insert(0, './stylegan3')
import legacy
import torch
import pickle
import os

# Directory that receives the converted .pth checkpoints.
os.makedirs('stylegan_pth_model', exist_ok=True)

# Source StyleGAN3 network pickles, one per trained model.
GANS = [
    "/kaggle/input/nework-double-print-final/network-snapshot-000040.pkl",
    "/kaggle/input/network-good-80iters/network-good-80iters.pkl",
    "/kaggle/input/network-interrupted-final/network-snapshot-interrupted.pkl"
]

# Converted checkpoints, index-aligned with GANS (LOADED_GANS[i] is written from GANS[i]).
LOADED_GANS = [
    'stylegan_pth_model/stylegan3_shavershelldoubleprint256.pth',
    'stylegan_pth_model/stylegan3_shavershellgood256.pth',
    'stylegan_pth_model/stylegan3_shavershellinterrupted256.pth'
]
assert len(GANS) == len(LOADED_GANS), "GANS and LOADED_GANS must be index-aligned"

# Model names passed to load_generator: checkpoint file name without directory or
# extension. basename/splitext stays correct for nested directories and for names
# containing dots, unlike split("/")[1].split(".")[0].
GAN_NAMES = [os.path.splitext(os.path.basename(g))[0] for g in LOADED_GANS]
    
def load_gan(class_idx, gan_paths=None, out_paths=None):
    """Convert one StyleGAN3 network pickle into a ``.pth`` checkpoint of state dicts.

    Unpickles ``gan_paths[class_idx]`` and saves the state dicts of its ``G_ema``,
    ``G`` and ``D`` entries to ``out_paths[class_idx]`` under the keys
    ``'generator_smooth'``, ``'generator'`` and ``'discriminator'``.

    Args:
        class_idx: Index into ``gan_paths`` / ``out_paths``.
        gan_paths: Source pickle paths. Defaults to the module-level ``GANS``.
        out_paths: Destination checkpoint paths, index-aligned with ``gan_paths``.
            Defaults to the module-level ``LOADED_GANS``.

    Returns:
        The unpickled object (indexable by ``'G'``, ``'D'`` and ``'G_ema'``).

    Raises:
        IndexError: if ``class_idx`` is out of range for either path list.
    """
    if gan_paths is None:
        gan_paths = GANS
    if out_paths is None:
        out_paths = LOADED_GANS

    # SECURITY: pickle.load can execute arbitrary code; only load trusted snapshots.
    # NOTE(review): unpickling StyleGAN3 snapshots presumably needs StyleGAN3's
    # torch_utils importable (./stylegan3 on sys.path) — confirm.
    with open(gan_paths[class_idx], 'rb') as pickle_file:
        stylegan3 = pickle.load(pickle_file)

    out_dict = {
        'generator_smooth': stylegan3['G_ema'].state_dict(),
        'generator': stylegan3['G'].state_dict(),
        'discriminator': stylegan3['D'].state_dict(),
    }

    out_path = out_paths[class_idx]
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # Make the function usable without relying on a caller-created directory.
        os.makedirs(out_dir, exist_ok=True)
    torch.save(out_dict, out_path)
    return stylegan3

In [4]:
# Display the model names derived from the converted checkpoint file names.
GAN_NAMES

['stylegan3_shavershelldoubleprint256',
 'stylegan3_shavershellgood256',
 'stylegan3_shavershellinterrupted256']

In [5]:
# Convert every source pickle to its .pth checkpoint; the returned networks are discarded.
gan_indices = range(len(GANS))
for gan_idx in gan_indices:
    load_gan(gan_idx)

In [6]:
!add-apt-repository ppa:ubuntu-toolchain-r/test -y

#!apt-get update

!apt-get install gcc-4.9

!apt-get upgrade -y libstdc++6

Get:1 http://security.ubuntu.com/ubuntu focal-security InRelease [114 kB]
Get:2 http://packages.cloud.google.com/apt gcsfuse-focal InRelease [5389 B]
Hit:3 http://archive.ubuntu.com/ubuntu focal InRelease
Get:4 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]
Get:5 https://packages.cloud.google.com/apt cloud-sdk InRelease [6751 B]
Get:6 https://packages.cloud.google.com/apt google-fast-socket InRelease [5405 B]
Get:7 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]
Get:8 http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu focal InRelease [17.5 kB]
Get:9 http://packages.cloud.google.com/apt gcsfuse-focal/main amd64 Packages [1578 B]
Get:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  InRelease [1581 B]
Get:11 https://packages.cloud.google.com/apt cloud-sdk/main amd64 Packages [361 kB]
Get:12 https://packages.cloud.google.com/apt google-fast-socket/main amd64 Packages [429 B]
Get:13 http://security.ubuntu.

In [7]:
!pip install Ninja

[0m

In [8]:
# !rm -r sefa
# !rm -r checkpoints
# !rm -r results

In [9]:
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch

sys.path.insert(0, './sefa')
from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight

from PIL import Image

In [10]:
GLOBAL_SEED = 442                        # numpy/torch seed, reset before sampling codes for each model
SAVE_DIR = 'sefa_pcl_shaver_aug_images'  # output directory for the generated PNGs

# Sample budget: TOTAL_GEN * MULT_COEFF is split evenly (floor division) across all
# GANs and semantic directions, giving the number of latent samples drawn per GAN.
TOTAL_GEN = 2500
MULT_COEFF = 4
NUM_SEMANTICS = 10   # how many rows of the factorized boundaries are traversed
NUM_SAMPLES = TOTAL_GEN * MULT_COEFF // (NUM_SEMANTICS * len(GANS))

# Latent traversal settings.
TRUNC_PSI = 0.9      # truncation_psi passed to the mapping network
START_DIST = -15.0   # first offset along each direction
END_DIST = 15.0      # last offset along each direction
STEP = 11            # number of offsets (np.linspace `num`), not the spacing between them
LAYERS = '2,3,4,5'   # layer selection string passed to factorize_weight
LAYERS = '2,3,4,5'

In [11]:
os.makedirs(SAVE_DIR, exist_ok=True)
# NOTE(review): CUDA_VISIBLE_DEVICES is only read when CUDA is first initialized; if an
# earlier cell already touched the GPU, this assignment has no effect. Setting it in the
# first cell makes device selection reliable.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Offsets applied along each direction; the same for every model, so computed once.
distances = np.linspace(START_DIST, END_DIST, STEP)
# File names encode each offset as int(round(d)). Fail loudly rather than silently
# overwrite images if the configured offsets do not round to distinct integers.
assert len({int(round(d)) for d in distances}) == len(distances), (
    'START_DIST/END_DIST/STEP give offsets that collide after rounding; '
    'image files would overwrite each other')

# Inference only: skip building autograd graphs for every synthesis call.
with torch.no_grad():
    for i, gan_name in enumerate(GAN_NAMES):
        # Factorize weights; `boundaries` holds the directions traversed below.
        generator = load_generator(gan_name)
        gan_type = parse_gan_type(generator)
        layers, boundaries, values = factorize_weight(generator, LAYERS)

        # Reset seeds per model so every GAN starts from the same random z draws.
        np.random.seed(GLOBAL_SEED)
        torch.manual_seed(GLOBAL_SEED)

        # Prepare codes: sample z, map with truncation, then move to numpy for editing.
        codes = torch.randn(NUM_SAMPLES, generator.z_space_dim).cuda()
        # NOTE(review): torch.empty(0, 3) is passed as the conditioning argument;
        # presumably ignored for these unconditional models — confirm.
        codes = generator.mapping(codes, torch.empty(0, 3), truncation_psi=TRUNC_PSI)
        codes = codes.detach().cpu().numpy()

        for sam_id in tqdm(range(NUM_SAMPLES), desc='Sample ', leave=False):
            code = codes[sam_id:sam_id + 1]
            for sem_id in tqdm(range(NUM_SEMANTICS), desc='Semantic ', leave=False):
                boundary = boundaries[sem_id:sem_id + 1]
                for d in distances:
                    # Shift only the selected layers (axis 1) along this direction.
                    temp_code = code.copy()
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))
                    image = postprocess(image)[0]
                    # `image` intentionally stays a notebook-level name (inspected later).
                    Image.fromarray(image).save(
                        f'{SAVE_DIR}/im_{i}_{sam_id}_{sem_id}_{int(round(d))}.png')

Building generator for model `stylegan3_shavershelldoubleprint256` ...
Finish building generator.
Loading checkpoint from `checkpoints/stylegan3_shavershelldoubleprint256.pth` ...
Local Path: `./stylegan_pth_model/stylegan3_shavershelldoubleprint256.pth`
 Fetching checkpoint from local path `./stylegan_pth_model/stylegan3_shavershelldoubleprint256.pth` ...
  Finish copying to checkpoint.
Finish loading checkpoint.
Setting up PyTorch plugin "bias_act_plugin"... Done.


Sample :   0%|          | 0/333 [00:00<?, ?it/s]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A

Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.



Semantic :  10%|█         | 1/10 [01:37<14:40, 97.82s/it][A
Semantic :  20%|██        | 2/10 [01:38<05:26, 40.75s/it][A
Semantic :  30%|███       | 3/10 [01:39<02:37, 22.51s/it][A
Semantic :  40%|████      | 4/10 [01:40<01:23, 13.92s/it][A
Semantic :  50%|█████     | 5/10 [01:40<00:45,  9.18s/it][A
Semantic :  60%|██████    | 6/10 [01:41<00:25,  6.31s/it][A
Semantic :  70%|███████   | 7/10 [01:42<00:13,  4.49s/it][A
Semantic :  80%|████████  | 8/10 [01:43<00:06,  3.30s/it][A
Semantic :  90%|█████████ | 9/10 [01:43<00:02,  2.51s/it][A
Semantic : 100%|██████████| 10/10 [01:44<00:00,  1.98s/it][A
Sample :   0%|          | 1/333 [01:44<9:39:41, 104.76s/it]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A
Semantic :  10%|█         | 1/10 [00:00<00:06,  1.40it/s][A
Semantic :  20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
Semantic :  30%|███       | 3/10 [00:02<00:05,  1.23it/s][A
Semantic :  40%|████      | 4/10 [00:03<00:04,  1.25it/s][A
Semantic :  50%|█████     | 5/1

Building generator for model `stylegan3_shavershellgood256` ...
Finish building generator.
Loading checkpoint from `checkpoints/stylegan3_shavershellgood256.pth` ...
Local Path: `./stylegan_pth_model/stylegan3_shavershellgood256.pth`
 Fetching checkpoint from local path `./stylegan_pth_model/stylegan3_shavershellgood256.pth` ...
  Finish copying to checkpoint.
Finish loading checkpoint.


Sample :   0%|          | 0/333 [00:00<?, ?it/s]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A
Semantic :  10%|█         | 1/10 [00:00<00:07,  1.24it/s][A
Semantic :  20%|██        | 2/10 [00:01<00:06,  1.28it/s][A
Semantic :  30%|███       | 3/10 [00:02<00:05,  1.31it/s][A
Semantic :  40%|████      | 4/10 [00:03<00:04,  1.32it/s][A
Semantic :  50%|█████     | 5/10 [00:03<00:03,  1.31it/s][A
Semantic :  60%|██████    | 6/10 [00:04<00:03,  1.32it/s][A
Semantic :  70%|███████   | 7/10 [00:05<00:02,  1.32it/s][A
Semantic :  80%|████████  | 8/10 [00:06<00:01,  1.31it/s][A
Semantic :  90%|█████████ | 9/10 [00:06<00:00,  1.31it/s][A
Semantic : 100%|██████████| 10/10 [00:07<00:00,  1.31it/s][A
Sample :   0%|          | 1/333 [00:07<42:19,  7.65s/it]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A
Semantic :  10%|█         | 1/10 [00:00<00:06,  1.39it/s][A
Semantic :  20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
Semantic :  30%|███       | 3/10 [00:02<00:05,  1.32it/

Building generator for model `stylegan3_shavershellinterrupted256` ...
Finish building generator.
Loading checkpoint from `checkpoints/stylegan3_shavershellinterrupted256.pth` ...
Local Path: `./stylegan_pth_model/stylegan3_shavershellinterrupted256.pth`
 Fetching checkpoint from local path `./stylegan_pth_model/stylegan3_shavershellinterrupted256.pth` ...
  Finish copying to checkpoint.
Finish loading checkpoint.


Sample :   0%|          | 0/333 [00:00<?, ?it/s]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A
Semantic :  10%|█         | 1/10 [00:00<00:07,  1.14it/s][A
Semantic :  20%|██        | 2/10 [00:01<00:06,  1.20it/s][A
Semantic :  30%|███       | 3/10 [00:02<00:05,  1.25it/s][A
Semantic :  40%|████      | 4/10 [00:03<00:04,  1.28it/s][A
Semantic :  50%|█████     | 5/10 [00:03<00:03,  1.30it/s][A
Semantic :  60%|██████    | 6/10 [00:04<00:03,  1.32it/s][A
Semantic :  70%|███████   | 7/10 [00:05<00:02,  1.33it/s][A
Semantic :  80%|████████  | 8/10 [00:06<00:01,  1.33it/s][A
Semantic :  90%|█████████ | 9/10 [00:06<00:00,  1.34it/s][A
Semantic : 100%|██████████| 10/10 [00:07<00:00,  1.35it/s][A
Sample :   0%|          | 1/333 [00:07<42:18,  7.65s/it]
Semantic :   0%|          | 0/10 [00:00<?, ?it/s][A
Semantic :  10%|█         | 1/10 [00:00<00:07,  1.18it/s][A
Semantic :  20%|██        | 2/10 [00:01<00:06,  1.24it/s][A
Semantic :  30%|███       | 3/10 [00:02<00:05,  1.27it/

In [12]:
# Shape of the last image produced by the generation loop. Its variable leaks into the
# notebook namespace from that loop.
# NOTE(review): this only works if the generation cell ran in this session; loading a
# saved PNG from SAVE_DIR would survive a kernel restart.
image.shape

(256, 256, 3)

In [13]:
!rm -r ./sefa
!rm -r ./checkpoints
!rm -r ./stylegan3
!rm -r ./stylegan_pth_model