In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

from my_models import GeneratorNet, DiscriminatorNet
from helper import generator_loss, discriminator_loss, show_images

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np

In [2]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Floating-point type that image batches are cast to in the training loop.
dtype = torch.float32

# The only preprocessing step is the conversion of each image to a tensor.
to_tensor = transforms.ToTensor()
trans = transforms.Compose([to_tensor])

# MNIST training split, stored under ./data (downloaded if not already there).
train_dataset = dset.MNIST(
    './data',
    train=True,
    transform=trans,
    download=True,
)

In [3]:
# Training hyperparameters.
batch_size = 64      # images per real batch and per generated batch
n_epochs = 10        # full passes over the training set
print_every = 100    # log losses / preview samples every this many iterations

# Seed PyTorch's RNG so shuffling and any sampling done through torch are
# repeatable across "Restart & Run All".
torch.manual_seed(13)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    # Visit the training images in a new random order every epoch.
    shuffle=True,
    # The training loop reshapes generated samples to exactly `batch_size`
    # images and both networks are constructed with that fixed size, so a
    # smaller trailing batch (the 60,000 MNIST training images are not a
    # multiple of 64) would be inconsistent with the fake batch. Drop it.
    drop_last=True,
    # Page-locked host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine is pointless and makes PyTorch emit a warning.
    pin_memory=torch.cuda.is_available(),
)

In [4]:
# Build the two adversaries. Both are given the same `batch_size`; the
# generator is constructed first, then the discriminator.
generator_model, discriminator_model = (
    GeneratorNet(batch_size=batch_size),
    DiscriminatorNet(batch_size=batch_size),
)

In [5]:
# Move both networks to the target device before building the optimizers, as
# the torch.optim documentation recommends, so the optimizers are created over
# the parameters in their final location.
generator_model.to(device)
discriminator_model.to(device)

# One Adam optimizer per network, with identical settings.
generator_optimizer = optim.Adam(generator_model.parameters(), lr=1e-3, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator_model.parameters(), lr=1e-3, betas=(0.5, 0.999))

generator_model.train()
discriminator_model.train()

for epoch in range(1, n_epochs + 1):
    # `idx` restarts at 0 every epoch, so announce the epoch to keep the
    # per-iteration log lines below unambiguous.
    print('Epoch: {}/{}'.format(epoch, n_epochs))

    for idx, (data, _) in enumerate(train_loader):
        # Labels are ignored; only the images are used.
        data = data.to(device, dtype=dtype)

        # ---- Discriminator update -------------------------------------------
        discriminator_optimizer.zero_grad()

        out_real = discriminator_model(data)

        # Detach the generated batch: this step only updates the discriminator,
        # so there is no reason to backpropagate into the generator. (Without
        # the detach those generator gradients were computed and then thrown
        # away by `generator_optimizer.zero_grad()` below.)
        fake_data = generator_model().detach()
        out_fake = discriminator_model(fake_data.view(batch_size, 1, 28, 28))

        d_total_error = discriminator_loss(out_real, out_fake)
        d_total_error.backward()
        discriminator_optimizer.step()

        # ---- Generator update -----------------------------------------------
        generator_optimizer.zero_grad()

        # Fresh samples, this time kept attached to the generator's graph so
        # the loss can reach the generator's parameters.
        fake_data = generator_model()
        out_fake = discriminator_model(fake_data.view(batch_size, 1, 28, 28))

        g_error = generator_loss(out_fake)
        g_error.backward()
        generator_optimizer.step()

        # ---- Periodic logging / sample preview ------------------------------
        if idx % print_every == 0:
            print('Iter: {}, D: {:.4}, G:{:.4}'.format(idx, d_total_error.item(), g_error.item()))
            # `.detach()` replaces the legacy `.data` accessor; flatten each
            # generated sample to one row and preview the first 16.
            imgs_numpy = fake_data.detach().cpu().numpy().reshape(batch_size, -1)
            show_images(imgs_numpy[0:16])
            plt.show()
            print()
        

NameError: name 'torch' is not defined