In [1]:
import sys
sys.path.append("models/stylegan2/")

from sgan2_enc_trainer import Trainer

import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from util import NanException

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import numpy as np

from setup import proj_dir, image_dir, out_dir, model_dir

%load_ext autoreload
%autoreload 2

In [2]:
# --- Paths and run names ------------------------------------------------------
data = image_dir + "zoom15/"      # folder of training images
results_dir = out_dir
models_dir = model_dir
stylegan_name = '2303-2'          # StyleGAN2 run whose checkpoint is loaded
encoder_name = '2001'             # encoder run name

# --- Checkpoints ----------------------------------------------------------------
gan_load_from = 99                # GAN checkpoint number to load
new = True                        # True: start a fresh encoder; False: resume from enc_load_from
enc_load_from = -1

# --- Architecture ---------------------------------------------------------------
image_size = 64
network_capacity = 16
fmap_max = 512
transparent = False
attn_layers = []
no_const = False
fq_layers = []
fq_dict_size = 256

# --- Optimisation ---------------------------------------------------------------
batch_size = 2
gradient_accumulate_every = 6
num_train_steps = 150_000
learning_rate = 1e-5
lr_mlp = 0.1
ttur_mult = 1.5
rel_disc_loss = False
mixed_prob = 0.9
fp16 = False
no_pl_reg = False
cl_reg = False
top_k_training = False
generator_top_k_gamma = 0.99
generator_top_k_frac = 0.5
dual_contrast_loss = False

# --- Augmentation ---------------------------------------------------------------
aug_prob = 0.0
aug_types = ['translation', 'cutout']
dataset_aug_prob = 0.0

# --- Data loading / hardware / reproducibility ------------------------------------
num_workers = None
# NOTE(review): multi_gpus is not referenced by the cells below; the process
# count is set explicitly via world_size in the last cell.
multi_gpus = True
seed = 42

# --- Checkpointing, evaluation and logging ----------------------------------------
save_every = 1000
evaluate_every = 1000
num_image_tiles = 8
trunc_psi = 0.75
calculate_fid_every = None
calculate_fid_num_images = 12_800
clear_fid_cache = False
log = False
tensorboard_dir = None            # None disables TensorBoard logging

# --- Generation-mode options (not passed to the Trainer in model_args) -----------
generate = False
num_generate = 1
generate_interpolation = False
interpolation_num_steps = 100
save_frames = False

# --- Encoder / StylEx options ---------------------------------------------------
# Global weights applied to the custom KL and reconstruction losses.
kl_scaling = 1
rec_scaling = 1

# Encoder architecture. If left unspecified, the Discriminator is used as the
# encoder (as in the original paper); see debug_encoders.py for the class names
# available when a different encoder is wanted.
encoder_class = 'GHFeat'

kl_rec_during_disc = False

# When True, sample images are produced by the image -> encoder -> generator
# pipeline. Use False for a standard GAN or to see samples from noise vectors.
sample_from_encoder = True

# When True, alternate between the StylEx loss and the regular StyleGAN loss;
# when False, train using the encoder only.
alternating_training = False

rank = 0

In [3]:
# Keyword arguments for sgan2_enc_trainer.Trainer, gathered from the config cell.
# encoder_training() supplies is_ddp, rank and world_size at call time.
# Classifier options (classifier_path, num_classes, classifier_name) are
# intentionally not passed.
model_args = {
    # names and locations
    'stylegan_name': stylegan_name,
    'encoder_name': encoder_name,
    'results_dir': results_dir,
    'models_dir': models_dir,
    # batching
    'batch_size': batch_size,
    'gradient_accumulate_every': gradient_accumulate_every,
    # architecture
    'image_size': image_size,
    'network_capacity': network_capacity,
    'fmap_max': fmap_max,
    'transparent': transparent,
    # optimisation
    'lr': learning_rate,
    'lr_mlp': lr_mlp,
    'ttur_mult': ttur_mult,
    'rel_disc_loss': rel_disc_loss,
    'num_workers': num_workers,
    # checkpoint / evaluation cadence
    'save_every': save_every,
    'evaluate_every': evaluate_every,
    'num_image_tiles': num_image_tiles,
    'trunc_psi': trunc_psi,
    # regularisation and layer options
    'fp16': fp16,
    'no_pl_reg': no_pl_reg,
    'cl_reg': cl_reg,
    'fq_layers': fq_layers,
    'fq_dict_size': fq_dict_size,
    'attn_layers': attn_layers,
    'no_const': no_const,
    # augmentation and top-k training
    'aug_prob': aug_prob,
    'aug_types': aug_types,
    'top_k_training': top_k_training,
    'generator_top_k_gamma': generator_top_k_gamma,
    'generator_top_k_frac': generator_top_k_frac,
    'dual_contrast_loss': dual_contrast_loss,
    'dataset_aug_prob': dataset_aug_prob,
    # FID
    'calculate_fid_every': calculate_fid_every,
    'calculate_fid_num_images': calculate_fid_num_images,
    'clear_fid_cache': clear_fid_cache,
    'mixed_prob': mixed_prob,
    'log': log,
    # encoder / StylEx
    'kl_scaling': kl_scaling,
    'rec_scaling': rec_scaling,
    'encoder_class': encoder_class,
    'sample_from_encoder': sample_from_encoder,
    'alternating_training': alternating_training,
    'kl_rec_during_disc': kl_rec_during_disc,
    'tensorboard_dir': tensorboard_dir,
    'rank': rank,
}

In [4]:
def _set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed also seeds the RNGs of all CUDA devices.
    torch.manual_seed(seed)


def encoder_training(rank, world_size, model_args, data, gan_load_from, enc_load_from, new,
                     num_train_steps, seed, log_every=500):
    """Train the encoder against a loaded StyleGAN2 checkpoint.

    Builds a ``Trainer`` from ``model_args`` and loads GAN checkpoint
    ``gan_load_from``. It then either resumes the encoder from
    ``enc_load_from`` or calls ``model.clear()``. Next it runs
    ``model.train_encoder_only`` until ``model.steps`` reaches
    ``num_train_steps``, and finally saves ``model.checkpoint_num``.

    Args:
        rank: index of this process (0 for a single-process run).
        world_size: total number of processes; > 1 enables DDP (NCCL backend).
        model_args: keyword arguments for ``Trainer``. The dict is NOT
            modified; ``is_ddp``/``rank``/``world_size`` are set on a copy.
        data: image folder passed to ``set_data_src`` and ``set_test_data_src``.
        gan_load_from: checkpoint number loaded with ``model.load('model', ...)``.
        enc_load_from: checkpoint number loaded with ``model.load('enc', ...)``;
            only used when ``new`` is False.
        new: if True, call ``model.clear()`` instead of loading the encoder.
        num_train_steps: training stops once ``model.steps`` reaches this.
        seed: seed for the random/numpy/torch global RNGs; ``None`` skips seeding.
        log_every: on the main process (rank 0), call ``model.print_log()``
            whenever ``model.steps`` is a multiple of this value. A falsy
            value disables logging.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    # Seed before the Trainer is built so initialisation is reproducible and,
    # under DDP, identical on every rank.
    if seed is not None:
        _set_seed(seed)

    if is_ddp:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        print(f"{rank + 1}/{world_size} process initialized.")

    # Copy rather than update in place so the caller's (notebook-global) dict
    # is left untouched across repeated runs.
    trainer_args = {**model_args, 'is_ddp': is_ddp, 'rank': rank, 'world_size': world_size}

    progress_bar = None
    try:
        model = Trainer(**trainer_args)

        model.load('model', gan_load_from)
        if new:
            model.clear()
        else:
            model.load('enc', enc_load_from)

        model.set_data_src(data)
        model.set_test_data_src(data, 8)

        progress_bar = tqdm(initial=model.steps, total=num_train_steps, mininterval=10.,
                            desc=f'<{data}>', position=0, leave=True)
        while model.steps < num_train_steps:
            # Up to 3 attempts per step when it raises NanException.
            retry_call(model.train_encoder_only, tries=3, exceptions=NanException)
            progress_bar.n = model.steps
            progress_bar.refresh()
            if is_main and log_every and model.steps % log_every == 0:
                model.print_log()

        model.save(model.checkpoint_num)
    finally:
        # Release the bar and the process group even if training raises.
        if progress_bar is not None:
            progress_bar.close()
        if is_ddp:
            dist.destroy_process_group()

In [None]:
# Single-process run: this process is rank 0 of a world of size 1 (no DDP).
world_size = 1
rank = 0

encoder_training(
    rank=rank,
    world_size=world_size,
    model_args=model_args,
    data=data,
    gan_load_from=gan_load_from,
    enc_load_from=enc_load_from,
    new=new,
    num_train_steps=num_train_steps,
    seed=seed,
)

Results directory: /home/jtl/Dropbox (MIT)/project_image_demand/results/sGAN2/2303-2/2001
Model directory: /dreambig/qingyi/image_chicago/models/sGAN2/2303-2
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/jtl/anaconda3/envs/qingyi/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth
Loading /dreambig/qingyi/image_chicago/models/sGAN2/2303-2/model_99.pt


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   0%|          | 500/150000 [18:06<90:13:23,  2.17s/it]

G: 0.00 | D: 0.11 | GP: 0.29 | Rec: 2.23 | Rec_w: 0.39 | Rec_i: 0.19 | Rec_pips: 0.06


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   1%|          | 1000/150000 [36:11<89:52:55,  2.17s/it]

G: 0.00 | D: 0.22 | GP: 0.06 | Rec: 2.09 | Rec_w: 0.36 | Rec_i: 0.20 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   1%|          | 1500/150000 [54:17<89:34:57,  2.17s/it]

G: 0.00 | D: 0.40 | GP: 0.38 | Rec: 1.93 | Rec_w: 0.33 | Rec_i: 0.18 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   1%|▏         | 2000/150000 [1:12:22<89:16:21,  2.17s/it]

G: 0.00 | D: 0.15 | GP: 0.22 | Rec: 2.14 | Rec_w: 0.37 | Rec_i: 0.19 | Rec_pips: 0.06


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   2%|▏         | 2500/150000 [1:30:29<88:59:22,  2.17s/it]

G: 0.00 | D: 0.13 | GP: 0.20 | Rec: 1.87 | Rec_w: 0.29 | Rec_i: 0.19 | Rec_pips: 0.04


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   2%|▏         | 3000/150000 [1:48:30<88:37:13,  2.17s/it]

G: 0.00 | D: 0.22 | GP: 0.13 | Rec: 1.90 | Rec_w: 0.33 | Rec_i: 0.16 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   2%|▏         | 3500/150000 [2:06:33<88:17:01,  2.17s/it]

G: 0.00 | D: 0.26 | GP: 0.33 | Rec: 1.50 | Rec_w: 0.22 | Rec_i: 0.17 | Rec_pips: 0.04


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   3%|▎         | 4000/150000 [2:24:34<87:56:55,  2.17s/it]

G: 0.00 | D: 0.34 | GP: 0.16 | Rec: 1.51 | Rec_w: 0.21 | Rec_i: 0.17 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   3%|▎         | 4500/150000 [2:42:36<87:37:53,  2.17s/it]

G: 0.00 | D: 0.25 | GP: 0.15 | Rec: 1.65 | Rec_w: 0.21 | Rec_i: 0.21 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   3%|▎         | 5000/150000 [3:00:39<87:18:52,  2.17s/it]

G: 0.00 | D: 0.65 | GP: 0.36 | Rec: 1.26 | Rec_w: 0.14 | Rec_i: 0.19 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   4%|▎         | 5500/150000 [3:18:41<87:00:17,  2.17s/it]

G: 0.00 | D: 0.56 | GP: 0.22 | Rec: 1.28 | Rec_w: 0.16 | Rec_i: 0.17 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   4%|▍         | 6000/150000 [3:36:43<86:41:28,  2.17s/it]

G: 0.00 | D: 0.59 | GP: 0.12 | Rec: 1.30 | Rec_w: 0.16 | Rec_i: 0.18 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   4%|▍         | 6500/150000 [3:54:45<86:22:52,  2.17s/it]

G: 0.00 | D: 0.54 | GP: 0.28 | Rec: 1.33 | Rec_w: 0.14 | Rec_i: 0.18 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   5%|▍         | 7000/150000 [4:12:47<86:04:00,  2.17s/it]

G: 0.00 | D: 0.42 | GP: 0.17 | Rec: 1.28 | Rec_w: 0.15 | Rec_i: 0.17 | Rec_pips: 0.05


</dreambig/qingyi/image_chicago/data/images/satellite/zoom15/>:   5%|▍         | 7356/150000 [4:25:37<85:50:44,  2.17s/it]