In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
transform = transforms.Compose([
    transforms.ToTensor()])
    

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

In [4]:
orig_data = datasets.MNIST('./data/', train=True, download=True, transform=transform)

In [5]:
dloader = DataLoader(dataset=orig_data, batch_size=10)

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.l1 = nn.Linear(in_features=128, out_features=512)
        self.l2 = nn.Linear(in_features=512, out_features=784)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        # Expecting x to be of 128 * 1
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.sigmoid(x)
        return x

In [7]:
def getZ(device):
    var = torch.rand((10, 128))
    var = var.view(10, 1, 128)
    var = var.to(device)
    return var

In [8]:
var = getZ(device)

In [9]:
gen = Generator()

In [10]:
gen = gen.to(device)

In [11]:
x = gen(var)

In [None]:
x.size()

torch.Size([10, 1, 784])

In [None]:
class Disc(nn.Module):
    def __init__(self):
        super(Disc, self).__init__()
        self.l1 = nn.Linear(in_features=784, out_features=256)
        self.l2 = nn.Linear(in_features=256, out_features=1)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.sigmoid(x)
        return x

In [None]:
disc = Disc()

In [None]:
disc = disc.to(device)

In [None]:
op = disc(x)

In [None]:
op.size()

torch.Size([10, 1, 1])

In [None]:
discLoss = nn.BCELoss()

In [None]:
genLoss = nn.BCELoss()

In [None]:
gen = gen.to(device)

In [None]:
disc = disc.to(device)

In [None]:
dOptim = torch.optim.Adam(disc.parameters(), lr=0.0002)
gOptim = torch.optim.Adam(gen.parameters(), lr=0.0002)

In [None]:
def getLabels(bsize, device, orig=True):
    op = None
    if orig:
        op = torch.ones(bsize, 1)
    else:
        op = torch.zeros(bsize, 1)
    op = op.to(device)
    return op

In [None]:
epochs = 108

In [None]:
posLabel = getLabels(10, device)
fakeLabel = getLabels(10, device, False)

In [None]:
for e in range(epochs):
    for batch, _ in dloader:
        batch = batch.to(device)
        batch = batch.view(10, 1, -1)
        
        # Train Discriminator
        
        dOptim.zero_grad()
        
        z = getZ(device)
        z = gen(z)
        
        fakeD = disc(z)
        origD = disc(batch)
        
        dFakeLoss = discLoss(fakeD, fakeLabel)
        dOrigLoss = discLoss(origD, posLabel)
        
        dLoss = -(dFakeLoss + dOrigLoss)
        dLoss.backward()
        
        # Train Generator
        
        gOptim.zero_grad()
        z = getZ(device)
        fakeD_gen = gen(z)
        
        disc_op = disc(fakeD_gen)
        gLoss = genLoss(fakeD, posLabel)
        
        gOptim.step()

  "Please ensure they have the same size.".format(target.size(), input.size()))
