In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torchvision

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
])

dataset = datasets.MNIST(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4802036.06it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 280546.32it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2801143.40it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3896610.51it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [3]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x.view(-1, 784))


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = Generator().to(device)
D = Discriminator().to(device)

loss_fn = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=2e-4)
optimizer_D = optim.Adam(D.parameters(), lr=2e-4)

epochs = 30
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        
        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # --------------------
        # Train Discriminator
        # --------------------
        z = torch.randn(batch_size, 100).to(device)
        fake_imgs = G(z)

        D_real = D(real_imgs)
        D_fake = D(fake_imgs.detach())
        loss_D = loss_fn(D_real, real_labels) + loss_fn(D_fake, fake_labels)

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ----------------
        # Train Generator
        # ----------------
        D_fake = D(fake_imgs)
        loss_G = loss_fn(D_fake, real_labels)  # Fool D → want labels = 1

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

    print(f"Epoch {epoch+1}/{epochs} | Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")


Epoch 1/30 | Loss D: 0.0192, Loss G: 5.8029
Epoch 2/30 | Loss D: 0.0720, Loss G: 7.2054
Epoch 3/30 | Loss D: 0.4671, Loss G: 5.6137
Epoch 4/30 | Loss D: 0.2893, Loss G: 3.8379
Epoch 5/30 | Loss D: 0.1067, Loss G: 4.1739
Epoch 6/30 | Loss D: 0.1209, Loss G: 5.2269
Epoch 7/30 | Loss D: 0.0534, Loss G: 6.8594
Epoch 8/30 | Loss D: 0.1024, Loss G: 5.7673
Epoch 9/30 | Loss D: 0.1736, Loss G: 4.4611
Epoch 10/30 | Loss D: 0.1244, Loss G: 7.1674
Epoch 11/30 | Loss D: 0.0548, Loss G: 7.3097
Epoch 12/30 | Loss D: 0.0777, Loss G: 7.5549
Epoch 13/30 | Loss D: 0.1175, Loss G: 7.4537
Epoch 14/30 | Loss D: 0.1628, Loss G: 3.6677
Epoch 15/30 | Loss D: 0.2820, Loss G: 5.5298
Epoch 16/30 | Loss D: 0.1627, Loss G: 7.1455
Epoch 17/30 | Loss D: 0.1411, Loss G: 5.8220
Epoch 18/30 | Loss D: 0.3167, Loss G: 4.0849
Epoch 19/30 | Loss D: 0.2420, Loss G: 4.7673
Epoch 20/30 | Loss D: 0.1435, Loss G: 4.2653
Epoch 21/30 | Loss D: 0.3541, Loss G: 4.2606
Epoch 22/30 | Loss D: 0.4636, Loss G: 4.8695
Epoch 23/30 | Loss 

In [None]:
def show_samples(generator):
    generator.eval()
    z = torch.randn(64, 100).to(device)
    samples = generator(z).cpu().detach()
    grid = np.transpose(torchvision.utils.make_grid(samples, nrow=8, normalize=True), (1, 2, 0))
    plt.figure(figsize=(8,8))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

show_samples(G)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x65 and 100x256)