In [2]:
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
import pickle
# Path of the pickled dataset, resolved relative to the current working directory.
train_data_file = "./data/mnist.pkl"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a dataset file from a trusted source.
with open(train_data_file, "rb") as f:
    data = pickle.load(f)

# The loaded object is indexed by the "train" and "test" keys; each entry is
# converted with torch.Tensor(...), which yields a tensor of the default dtype.
train_data, test_data = (
    torch.Tensor(data["train"]),
    torch.Tensor(data["test"]),
)

In [4]:
class Discriminator(nn.Module):
    """MLP that maps a flattened sample to a single score in [0, 1].

    Architecture:
        Linear(nin -> hidden_size) -> LeakyReLU(0.1)
        -> Linear(hidden_size -> 1) -> Sigmoid

    Args:
        in_shape: Shape of one un-flattened sample; the product of its
            entries is the input width ``nin``.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, in_shape, hidden_size):
        super().__init__()
        self.in_shape = in_shape
        # np.prod returns a NumPy scalar; keep a plain int for the layer sizes.
        self.nin = int(np.prod(in_shape))
        self.hidden_size = hidden_size

        self.disc = nn.Sequential(
            nn.Linear(self.nin, self.hidden_size),
            # The activation lives in torch.nn, not on the module instance.
            nn.LeakyReLU(0.1),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, X):
        """Score a batch of samples.

        Args:
            X: Tensor whose last dimension is ``nin``, or a batch with more
                than two dimensions (e.g. image-shaped), which is flattened
                to ``(batch, -1)`` before scoring.

        Returns:
            Tensor of sigmoid outputs; shape ``(batch, 1)`` for batched input.
        """
        # Inputs with 1 or 2 dimensions are passed through untouched.
        if X.dim() > 2:
            X = X.flatten(start_dim=1)
        return self.disc(X)


class Generator(nn.Module):
    """MLP that maps a latent vector to a flattened sample in [-1, 1].

    Architecture:
        Linear(latent_dim -> hidden_size) -> LeakyReLU(0.1)
        -> Linear(hidden_size -> nin) -> Tanh

    Args:
        in_shape: Shape of one un-flattened sample; the product of its
            entries is the output width ``nin``.
        hidden_size: Width of the single hidden layer.
        latent_dim: Size of the latent input vector.
    """

    def __init__(self, in_shape, hidden_size, latent_dim):
        super().__init__()
        self.in_shape = in_shape
        self.hidden_size = hidden_size
        # np.prod returns a NumPy scalar; keep a plain int for the layer sizes.
        self.nin = int(np.prod(in_shape))
        self.latent_dim = latent_dim

        self.gen = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(self.hidden_size, self.nin),
            nn.Tanh(),
        )

    def forward(self, X):
        """Generate flattened samples from latent vectors.

        Args:
            X: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of tanh outputs whose last dimension is ``nin``.
        """
        return self.gen(X)


class GAN(nn.Module):
    """Generator followed by a discriminator.

    Args:
        in_shape: Shape of one un-flattened sample.
        hidden_size: Hidden width used by both sub-networks.
        latent_dim: Size of the generator's latent input.

    Attributes are exposed as ``gen`` (Generator) and ``disc`` (Discriminator).
    """

    def __init__(self, in_shape, hidden_size, latent_dim):
        super().__init__()
        self.in_shape = in_shape
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim

        self.gen = Generator(self.in_shape, self.hidden_size, self.latent_dim)
        # The critic must be a Discriminator: it consumes the generator's
        # nin-wide output and emits one score per sample.
        self.disc = Discriminator(self.in_shape, self.hidden_size)

    def forward(self, X):
        """Generate samples from ``X`` and return the discriminator's scores.

        Args:
            X: Batch of generator inputs. It is flattened to ``(batch, -1)``,
                so the flattened width must equal ``latent_dim``.

        Returns:
            Tensor of shape ``(batch, 1)`` with the discriminator's sigmoid
            output for each generated sample.
        """
        X = X.view(X.shape[0], -1)

        generated = self.gen(X)

        out = self.disc(generated)

        return out

        

In [None]:
# Training configuration for the cells that follow.
num_epochs = 20

# Mean-squared-error criterion instance.
rec_loss = nn.MSELoss()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

