In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
batch_size = 64
latent_size = 64
image_size = 28 * 28
missing_rate = 0.5
num_epochs = 5
learning_rate = 0.0002

In [19]:
transform = transforms.Compose([
    #transforms.Resize((28,28),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [20]:
train_dataset = FashionMNIST(root='./data/', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)


In [21]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, image_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [22]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [24]:
d_losses = []
g_losses = []

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Move images to device
        images = images.to(device)
        
        # Mask input images
        mask = torch.rand_like(images) > missing_rate
        masked_images = images.clone()
        masked_images[~mask] = 0
        
        # Create labels for real and fake images
        real_labels = torch.ones(images.size(0), 1, device=device)
        fake_labels = torch.zeros(images.size(0), 1, device=device)
        
        # Train discriminator
        optimizer_D.zero_grad()
        outputs_real = discriminator(images.view(-1, image_size))
        real_loss = criterion(outputs_real, real_labels)
        
        outputs_fake = discriminator(masked_images.view(-1, image_size))
        fake_loss = criterion(outputs_fake, fake_labels)
        
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()
        
        # Train generator
        optimizer_G.zero_grad()
        inpainted_images = generator(masked_images)
        outputs = discriminator(inpainted_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()
        
        # Append losses to lists
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1792x28 and 64x128)

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss')
plt.plot(g_losses, label='Generator Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('GAN Loss Curve')
plt.legend()
plt.show()

In [None]:
# Save real and reconstructed images
images_concat = torch.cat([images.view(-1, 1, 28, 28), inpainted_images.view(-1, 1, 28, 28)], dim=3)
save_image(images_concat, 'gan_images/{}_{}.png'.format(epoch+1, i+1), nrow=8, normalize=True)