In [1]:
# train
# train
import torch
import torch.optim as optim
import models

from dataloader import get_dataloader


In [87]:
from easydict import EasyDict
args = EasyDict({
    'model_name': 'dcgan',
    'dataset_name': 'cifar10',
    'data_root': '../data/',
    'resolution': 32,
    'classes_to_include': [0], #[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    'batch_size': 128,
    'device': 'cuda',
    'nz': 100,
    'ngf': 64,
    'ndf': 64,
    'nc': 3,
})


In [None]:
device = args.device

tot_itr = 10000
model_classes = [m for m in dir(models) 
                   if m.lower() == args.model_name.lower()]
assert model_classes, f'Model name {args.model_name} does not exist'
components = getattr(models, model_classes[0]).components

generator = components['generator'](args).train().to(device)
discriminator = components['discriminator'](args).train().to(device)
criterion = components['criterion'](args)
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

losses_d, losses_g = [], []

train_loader = get_dataloader(args, train=True)

train_iter = iter(train_loader)

for itr in range(tot_itr):
    try:
        imgs, _ = next(train_iter)
    except StopIteration:
        train_loader = get_dataloader(args, train=True)
        train_iter = iter(train_loader)
        imgs, _ = next(train_iter)
    imgs = imgs.to(device)

    # 1. Update Discriminator

    optimizer_d.zero_grad()
    real_pred = discriminator(imgs)

    noise = torch.randn(args.batch_size, args.nz, 1, 1).to(device)
    fake_imgs = generator(noise)
    fake_pred = discriminator(fake_imgs.detach())
    loss_d = criterion((fake_pred, real_pred), mode='discriminator_loss')
    losses_d.append(loss_d.item())

    loss_d.backward()
    optimizer_d.step()

    # 2. Update Generator
    optimizer_g.zero_grad()
    fake_pred = discriminator(fake_imgs)
    loss_g = criterion((fake_pred, real_pred), mode='generator_loss')
    losses_g.append(loss_g.item())

    loss_g.backward()
    optimizer_g.step()
    
    if (itr + 1) % int(tot_itr / 10) == 0:
        print('-' * 20)
        print(f'Iteration: {itr + 1}')
        print(f'Discriminator Loss: {loss_d}')
        print(f'Generator Loss: {loss_g}')


--------------------
Iteration: 1000
Discriminator Loss: 0.4078954756259918
Generator Loss: 1.654304027557373
--------------------
Iteration: 2000
Discriminator Loss: 0.33238404989242554
Generator Loss: 1.978642225265503
--------------------
Iteration: 3000
Discriminator Loss: 0.2976579964160919
Generator Loss: 1.9868347644805908
--------------------
Iteration: 4000
Discriminator Loss: 0.28929203748703003
Generator Loss: 2.0308284759521484
--------------------
Iteration: 5000
Discriminator Loss: 0.3177807927131653
Generator Loss: 2.748940944671631
--------------------
Iteration: 6000
Discriminator Loss: 0.19926047325134277
Generator Loss: 3.305755615234375
--------------------
Iteration: 7000
Discriminator Loss: 0.17522764205932617
Generator Loss: 3.7408151626586914
--------------------
Iteration: 8000
Discriminator Loss: 0.10158485174179077
Generator Loss: 3.0705204010009766


In [None]:
import matplotlib.pyplot as plt
plt.plot(losses_d)
plt.plot(losses_g)

In [None]:
def plot_images(imgs):
    import matplotlib.pyplot as plt
    import torchvision
    imgs = torchvision.utils.make_grid(imgs)
    plt.imshow(imgs.permute(1, 2, 0))

In [None]:
generator.eval().cpu()
with torch.no_grad():
    imgs = generator(torch.randn(1, args.nz, 1, 1))
plot_images(imgs)

In [75]:
import time
import os
from datetime import datetime


discriminator.cpu()
generator.train()
ckpt = {
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict()
}

torch.save(ckpt, f'checkpoints/{args.model_name}.{datetime.now().strftime('%d%m%y.%H%M%S')}.pt')
print(os.walk())

SyntaxError: invalid syntax (<ipython-input-75-5e178223a5ec>, line 13)

In [2]:
import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.MNIST('../data/')
dataloader = DataLoader(dataset)
next(iter(dataloader))

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>