In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import Generator, Discriminator, weights_init
from dataset import MyDataset
from utils import print_image, make_video, get_grid_image

from tqdm.notebook import tqdm

In [None]:
# ---- Hyperparameters -------------------------------------------------------
batch_size, z_size = 128, 100   # z_size is the latent dimension sampled for G
out_chnl, in_chnl = 3, 3        # channel counts handed to G and D respectively
d_chnl, g_chnl = 32, 32         # width arguments for D and G (see model.py)
lr = 0.0005                     # shared Adam learning rate

# ---- Device and data -------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dataloader = DataLoader(dataset=MyDataset(), batch_size=batch_size, shuffle=True)

# ---- Networks --------------------------------------------------------------
# Both networks are constructed first and only then re-initialised, generator
# before discriminator, via the project-provided weights_init hook.
G = Generator(z_size, g_chnl, out_chnl).to(device)
D = Discriminator(d_chnl, in_chnl).to(device)
for net in (G, D):
  net.apply(weights_init)

# ---- Optimizers ------------------------------------------------------------
# Identical Adam settings for both players.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), **adam_kwargs)
opt_D = torch.optim.Adam(D.parameters(), **adam_kwargs)

# ---- Fixed latent batch ----------------------------------------------------
# 100 latent vectors drawn once from N(0, 1) and reused every epoch so that
# successive sample grids come from the same inputs.
fixed_shape = (100, z_size, 1, 1)
fixed = torch.normal(0, 1, size=fixed_shape).to(device)


In [None]:
# Adversarial training loop.
#
# Each batch performs one generator update followed by one discriminator
# update, both using binary cross-entropy on the discriminator's output.
# At the end of every epoch the generator is run on the `fixed` latent batch:
# the result is turned into a grid image and stored in `img_list`, and every
# 20th epoch it is also displayed with `print_image`.
img_list = []
num_epochs = 500

for epoch in tqdm(range(num_epochs)):

  for imgs in dataloader:
    imgs = imgs.to(device)
    n = imgs.size(0)  # the last batch may be smaller than batch_size

    # Targets are created directly on the training device, avoiding a
    # host allocation plus copy on every batch.
    real_label = torch.ones(n, 1, device=device)
    fake_label = torch.zeros(n, 1, device=device)

    # ---- Generator step: push D's prediction on fakes towards "real". ----
    opt_G.zero_grad()
    z = torch.randn(n, z_size, 1, 1, device=device)
    fake_imgs = G(z)
    g_loss = F.binary_cross_entropy(D(fake_imgs), real_label)
    g_loss.backward()
    opt_G.step()

    # ---- Discriminator step ----
    # zero_grad() also discards the gradients that g_loss.backward() left on
    # D's parameters. The fakes are detached so this backward pass does not
    # reach into G.
    opt_D.zero_grad()
    real_loss = F.binary_cross_entropy(D(imgs), real_label)
    fake_loss = F.binary_cross_entropy(D(fake_imgs.detach()), fake_label)
    d_loss = (real_loss+fake_loss)/2
    d_loss.backward()
    opt_D.step()

  # print loss each epoch (values come from the last batch of the epoch)
  print(f'[Epoch {epoch+1:3d}/{num_epochs:3d}] [G_loss {g_loss.item():2.4f}] [D_loss {d_loss.item():2.4f}]')

  # Sample the fixed latent batch once per epoch with autograd disabled: these
  # images are only visualised, so no graph needs to be built (or kept alive
  # through img_list, should get_grid_image hold on to its input). G's
  # train/eval mode is intentionally left untouched.
  with torch.no_grad():
    fixed_samples = G(fixed)

  # save fixed image
  img_list.append(get_grid_image(fixed_samples))
  # print sample images
  if epoch % 20 == 0:
    print_image(fixed_samples)