In [5]:
#%matplotlib inline
import argparse
import os
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

# Seed every random number generator the notebook imports so that a
# Restart Kernel -> Run All reproduces the same results.
manualSeed = 999
# For fresh results on each run, use instead: manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)        # Python's built-in RNG
np.random.seed(manualSeed)     # NumPy's global RNG (previously left unseeded)
torch.manual_seed(manualSeed)  # PyTorch's default generator
# Make PyTorch raise an error instead of silently running a nondeterministic kernel.
torch.use_deterministic_algorithms(True)

Random Seed:  999


In [None]:
# Device used for tensors and models. Prefer Apple-silicon MPS (the original
# hard-coded choice), then CUDA, and fall back to the CPU so the notebook still
# runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Batch size during training
batch_size = 16

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 20

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

### Dataloader

In [None]:
class MNISTDataset(Dataset):
    """Single-class image dataset backed by an MNIST-style CSV file.

    The first CSV column holds the class label; the remaining columns hold
    the pixel values of one image, which ``__getitem__`` reshapes to
    ``(28, 28, 1)``.

    Args:
        csv_file: Path of the CSV file, read with ``pandas.read_csv``.
        transform: Optional callable applied to each ``(28, 28, 1)`` uint8
            image array before it is returned.
        label: Class label to keep. Defaults to 7 (sneakers), the only class
            the original code selected. Pass ``None`` to keep every row.

    Attributes:
        data: The (filtered) ``pandas.DataFrame``. Labels and pixels are
            copied into NumPy arrays at construction time, so changing
            ``data`` afterwards does not affect the samples returned.
    """

    def __init__(self, csv_file, transform=None, label=7):
        data = pd.read_csv(csv_file)
        if label is not None:
            # Keep only the rows of the requested class.
            data = data[data.iloc[:, 0] == label]
        self.data = data
        self.transform = transform
        # Convert once here instead of slicing the DataFrame and casting the
        # dtype again for every sample that is fetched.
        self._labels = self.data.iloc[:, 0].values
        self._pixels = self.data.iloc[:, 1:].values.astype('uint8')

    def __len__(self):
        """Return the number of samples left after label filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        ``image`` is a ``(28, 28, 1)`` uint8 array, or whatever ``transform``
        returns for it when a transform was supplied.
        """
        # Copy so callers never receive a view into the cached pixel array.
        image = self._pixels[idx].reshape((28, 28, 1)).copy()
        label = self._labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Per-sample preprocessing: ToTensor turns the (28, 28, 1) uint8 array into a
# (1, 28, 28) float tensor scaled to [0, 1]; Normalize with mean 0.5 and std 0.5
# then maps it to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = MNISTDataset("../data/fashion-mnist_train.csv", transform=transform)

# NOTE(review): the config cell sets batch_size = 16 but this loader uses 64 —
# confirm which value is intended before wiring the constant in here.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

# Preview one shuffled batch of training images as a grid. The batch comes out
# of the loader on the CPU and matplotlib needs CPU data, so the grid is built
# on the CPU instead of being moved to the accelerator and straight back.
real_batch, _ = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[:64], padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(preview_grid, (1, 2, 0)));