In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

from dataset import MyDataset
from model import Generator, Discriminator, weights_init
from train import train
from utils import print_image, show_interact_image, interpolate

import os
from IPython.display import display

In [None]:
# Configuration: hyperparameters, data, models, optimizers and fixed noise.
batch_size = 128
z_size = 100      # length of the latent noise vector fed to G (see fixed/z shapes below)
out_chnl = 3      # passed to Generator as its output-channel count
in_chnl = 3       # passed to Discriminator as its input-channel count
d_chnl = 32       # width argument for Discriminator
g_chnl = 32       # width argument for Generator
lr = 0.0005
n_fixed = 100     # number of fixed latent vectors handed to train()
seed = 0          # set before any random op so init, shuffling and noise are reproducible

# torch.manual_seed seeds the RNG for all devices (CPU and CUDA).
torch.manual_seed(seed)

dataloader = DataLoader(MyDataset(), batch_size=batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = Generator(z_size, g_chnl, out_chnl).to(device)
D = Discriminator(d_chnl, in_chnl).to(device)

G.apply(weights_init)
D.apply(weights_init)

# beta1 = 0.5 is the usual DCGAN choice for stabilising adversarial training.
opt_G = Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Standard-normal noise created directly on `device` (no host allocation + copy).
# NOTE(review): presumably train() uses this to render progress samples — confirm.
fixed = torch.randn(n_fixed, z_size, 1, 1, device=device)

In [None]:
# Show the first (shuffled) batch as a quick sanity check of the data pipeline.
# `imgs` is only rebound when the loader actually yields a batch.
batch_iter = iter(dataloader)
first_batch = next(batch_iter, None)
if first_batch is not None:
    imgs = first_batch
    print_image(imgs)
del batch_iter, first_batch

In [None]:
# Resume from the newest checkpoint in the working directory (if one exists), then train.
ckpt_dir = "./"
checkpoints = [
    os.path.join(ckpt_dir, name)
    for name in os.listdir(ckpt_dir)
    if name.endswith(".pt")
]
if checkpoints:
    # os.listdir returns entries in arbitrary order, so choose by modification time.
    latest_ckpt = max(checkpoints, key=os.path.getmtime)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa);
    # it also places the restored `fixed` tensor on `device`.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    state = torch.load(latest_ckpt, map_location=device)
    G.load_state_dict(state["G"])
    D.load_state_dict(state["D"])
    fixed = state["fixed"]
    # NOTE(review): optimizer states are not restored — confirm whether the checkpoint stores them.

# Re-enter training mode in case a sampling cell below already called G.eval().
G.train()
D.train()

# NOTE(review): the final argument is presumably the number of epochs — confirm against train().
train(G, D, opt_G, opt_D, z_size, fixed, dataloader, device, 2)

In [None]:
# Generate a grid of images from fresh random latent vectors.
n_samples = 100
G.eval()  # inference mode (affects dropout / batch-norm layers if G has any)
with torch.no_grad():
    # Sample directly on `device` rather than allocating on CPU and copying.
    # `z` is reused by the interpolation cell below.
    z = torch.randn(n_samples, z_size, 1, 1, device=device)
    imgs = G(z)  # G and z are both on `device`, so no extra .to() is needed

show_interact_image(imgs)

In [None]:
# Interpolate between two latent vectors from the previous cell's `z` and render the result.
# NOTE(review): interpolate() presumably returns a batch of latents between the two
# endpoints — see utils.py.
start_idx, end_idx = 6, 50  # arbitrary pair of samples from `z`
G.eval()  # don't rely on an earlier cell having set inference mode
with torch.no_grad():
    # Keep .to(device): interpolate()'s output device isn't visible from here.
    new_z = interpolate(z[start_idx], z[end_idx]).to(device)
    imgs = G(new_z)

show_interact_image(imgs)