In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
mnist_train = torchvision.datasets.MNIST('./files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,), (0.5,))
                             ]))

mnist_test = torchvision.datasets.MNIST('./files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,), (0.5,))
                             ]))

In [4]:
train_dataloader = DataLoader(mnist_train, 32, shuffle=True)

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 11)
    
    # forward method
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, 0.5)
        x = F.softmax(self.fc4(x), dim=1)
        return x

In [6]:
D = Discriminator().to(device)

criterion = nn.CrossEntropyLoss()

# optimizer
D_optimizer = torch.optim.Adam(D.parameters(), lr = 0.0002)

In [7]:
def train(x, y) :
    D.zero_grad()

    batch_size = len(x)

    # Training on real data
    x_real, y_real = x.view(-1, 784), y.view(-1)
    x_real, y_real = x_real.to(device), y_real.to(device)

    D_output = D(x_real)
    D_loss = criterion(D_output, y_real)

    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item()

In [8]:
n_epoch = 10
for epoch in range(1, n_epoch+1):           
    D_losses, G_losses = [], []
    for batch_idx, (x, y) in enumerate(train_dataloader):
        D_losses.append(train(x, y))

    print(f'{epoch}/{n_epoch}: loss_d: {torch.mean(torch.FloatTensor(D_losses))}')

1/10: loss_d: 1.7819561958312988
2/10: loss_d: 1.6515966653823853
3/10: loss_d: 1.6321799755096436
4/10: loss_d: 1.622334361076355
5/10: loss_d: 1.617701768875122
6/10: loss_d: 1.6131260395050049
7/10: loss_d: 1.6114306449890137
8/10: loss_d: 1.6073827743530273
9/10: loss_d: 1.6055384874343872
10/10: loss_d: 1.6068254709243774


In [12]:
# Calculating accuracy for each digit
accuracy = np.zeros(10)
correct = np.zeros(10)
total = np.zeros(10)

for x, y in mnist_test :
    output = D(x.view(-1, 784).to(device))

    if(torch.argmax(output.cpu().detach()) == y) :
        correct[y] += 1
    
    total[y] += 1

In [13]:
accuracy = correct/total

print(accuracy)

[0.97653061 0.98590308 0.89147287 0.90891089 0.95519348 0.92600897
 0.95407098 0.95719844 0.89322382 0.91674926]
