In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [12]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [13]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                     std = [0.229, 0.224, 0.225])])

train_data = datasets.MNIST(root='../../../../data/', train=True, transform=transform, download=True)
test_data  = datasets.MNIST(root='../../../../data/', train=False, transform=transform, download=True)

In [14]:
batch_size = 128
train_dataloader = DataLoader(train_data, batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size, shuffle=True)

In [17]:
hidden = 256
d_noise = 100

In [18]:
def sample_z(batch_size=1, d_noise=100):
    return torch.randn(batch_size, d_noise, device=device)

In [20]:
Generator = nn.Sequential(nn.Linear(d_noise, hidden),
                          nn.ReLU(),
                          nn.Dropout(0.3),
                          nn.Linear(hidden, hidden),
                          nn.ReLU(),
                          nn.Dropout(0.3),
                          nn.Linear(hidden, 28*28),
                          nn.Tanh()
                         ).to(device)

In [21]:
z = sample_z()

In [24]:
Discriminator = nn.Sequential(nn.Linear(28*28, hidden),
                  nn.ReLU(),
                  nn.Dropout(0.3),
                  nn.Linear(hidden, hidden),
                  nn.ReLU(),
                  nn.Dropout(0.3),
                  nn.Linear(hidden, 1),
                  nn.Sigmoid()
                 ).to(device)

In [25]:
criterion = nn.BCELoss()

In [28]:
Generator.train()
Discriminator.train()
_optimizer_G = optim.Adam(Generator.parameters(), lr=0.0002)
_optimizer_D = optim.Adam(Discriminator.parameters(), lr=0.0002)

In [29]:
for batch_idx, (img, label) in enumerate(train_dataloader):
    img = img.to(device)
    label = label.to(device)
    
    _optimizer_D.zero_grad()
    
    p_real = Discriminator(img.view(-1, 28*28))
    p_fake = Discriminator(Generator(sample_z(batch_size, d_noise)))
    
    loss_real = -1*torch.log(p_real)
    loss_fake = -1*torch.log(1 - p_fake)
    loss_D = (loss_real + loss_fake).mean()
    
    loss_D.backward()
    _optimizer_D.step()
    
    _optimizer_G.zero_grad()
    
    p_fake = Discriminator(Generator(sample_z(batch_size, d_noise)))
    
    loss_G = -1*torch.log(p_fake).mean()
    
    loss_G.backward()
    _optimizer_G.step()
    
    
    

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]