In [None]:
from torch import optim



from model import Generator, Discriminator
from dataloader import get_loader
from utils import train_fn, generate_examples

from math import log2

import config

In [None]:
# Build the generator and the critic (discriminator) on the configured device.
gen = Generator(
    config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)

# Initialize optimizers.
# The generator's mapping sub-network (`gen.map`) is trained with its own, much
# smaller learning rate; every other generator parameter uses config.LEARNING_RATE.
MAPPING_NET_LR = 1e-5

# Select the non-mapping parameters by the "map." attribute prefix rather than a
# substring test: `"map" not in name` would also drop any other parameter whose
# name merely contains "map", leaving it in *neither* optimizer group (never
# trained). The prefix test puts every parameter in exactly one group.
gen_main_params = [
    param for name, param in gen.named_parameters() if not name.startswith("map.")
]
opt_gen = optim.Adam(
    [
        {"params": gen_main_params},
        {"params": gen.map.parameters(), "lr": MAPPING_NET_LR},
    ],
    lr=config.LEARNING_RATE,
    betas=(0.0, 0.99),
)

opt_critic = optim.Adam(
    critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)



In [None]:

gen.train()
critic.train()

# Progressive growing: stage `step` trains at image size 4 * 2**step, so start at
# the stage that corresponds to the image size set in config.
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
# int(log2(...)) truncates silently; make sure the configured size is exactly one
# of the stage sizes (4, 8, 16, ...) instead of quietly starting at the wrong one.
assert 4 * 2 ** step == config.START_TRAIN_AT_IMG_SIZE, (
    "config.START_TRAIN_AT_IMG_SIZE must be 4 times a power of two"
)

for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5  # start each stage with a very low alpha (presumably the fade-in factor)
    img_size = 4 * 2 ** step
    loader, dataset = get_loader(img_size)
    print(f"Current image size: {img_size}")
    # Save sample images roughly 20 times per stage. Clamp to at least 1: for a
    # stage with fewer than 20 epochs `num_epochs // 20` is 0 and `epoch % 0`
    # would raise ZeroDivisionError.
    img_save_freq = max(1, num_epochs // 20)

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        # train_fn hands back the updated alpha, which carries into the next epoch.
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen,
        )
        # Always save on the last epoch of the stage as well.
        if epoch % img_save_freq == 0 or epoch == num_epochs - 1:
            generate_examples(gen, step, f"{epoch}_{num_epochs}")
    step += 1  # progress to the next img size