In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from tqdm import tqdm 
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import figure

In [15]:
# Integer class index -> human-readable FashionMNIST class name.
# Indices follow list position, so 0 is "T-Shirt" and 9 is "Ankle Boot".
labels_map = dict(
    enumerate(
        [
            "T-Shirt",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle Boot",
        ]
    )
)

In [16]:
class Discriminator(nn.Module):
    """Small convolutional network that squashes each image to sigmoid scores.

    The layer sizes assume 28x28 inputs: two stride-2 convolutions reduce the
    spatial size 28 -> 14 -> 7, and the final 7x7 kernel collapses it to 1x1.

    Args:
        in_channels: number of channels in the input images (1 for FashionMNIST).
        out_channels: number of sigmoid scores produced per image. Defaults to 64
            so the default architecture is unchanged.

    NOTE(review): a GAN discriminator usually emits a single real/fake
    probability per image; with the default of 64 every image gets 64
    independent scores. Pass ``out_channels=1`` if a scalar score is intended —
    TODO confirm against the adversarial training code (not shown here).
    """

    def __init__(self, in_channels=1, out_channels=64):
        super().__init__()
        # N, in_channels, 28, 28
        self.disc = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1),  # -> N, 16, 14, 14
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> N, 32, 7, 7
            nn.ReLU(),
            nn.Conv2d(32, out_channels, 7),  # -> N, out_channels, 1, 1
            nn.Sigmoid(),  # scores in [0, 1]
        )

    def forward(self, x):
        """Return per-image scores in [0, 1] of shape (N, out_channels, 1, 1)."""
        return self.disc(x)

In [17]:
class Generator(nn.Module):
    """Convolutional autoencoder that reconstructs (N, 1, 28, 28) images.

    The encoder compresses each image to a (64, 1, 1) code; the decoder mirrors
    it back to (1, 28, 28). The final Sigmoid keeps reconstructed pixels in [0, 1].
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # (N, 1, 28, 28) -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> (N, 32, 7, 7)
            nn.ReLU(),
            nn.Conv2d(32, 64, 7),                       # -> (N, 64, 1, 1) bottleneck code
        ]
        # output_padding=1 restores the even sizes; without it the stride-2
        # transposed convs would give 13x13 and then 27x27.
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, 7),                                       # -> (N, 32, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> (N, 1, 28, 28)
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck code and decode it back to image space."""
        return self.decoder(self.encoder(x))

In [18]:
# Hyperparameters etc.
SEED = 42
# Seed torch's RNGs (all devices) so weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All.
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One learning rate is used for every optimizer in this notebook; split it into
# separate generator/discriminator constants if they need independent tuning.
LEARNING_RATE_2 = 1e-3
BATCH_SIZE = 32
IMAGE_SIZE = 28  # FashionMNIST is 28x28; the conv stacks' shape comments assume this size


In [19]:
# Preprocessing for FashionMNIST: ToTensor() converts images to float tensors
# with pixels scaled to [0, 1].
# No Normalize(0.5, 0.5) here: it would map pixels to [-1, 1], but the Generator
# ends in nn.Sigmoid (outputs in [0, 1]) and is trained with nn.BCELoss, whose
# targets must lie in [0, 1] — negative targets make the loss meaningless.
# The result is bound to `transforms` because later cells pass
# `transform=transforms`. That name shadows the `torchvision.transforms` module
# imported above, so the module is referenced through `torchvision` to keep this
# cell re-runnable.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

In [20]:
# Load the FashionMNIST train and test splits (downloaded into dataset/ on first
# run); both splits share the same preprocessing pipeline.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root="dataset/", train=is_train, transform=transforms, download=True)
    for is_train in (True, False)
)

In [21]:
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)  # reshuffled every epoch

In [22]:
# One autoencoder-style generator and two independently initialised
# discriminators (names suggest "normal" vs "anomaly" — presumably trained on
# different data; not shown in this notebook). Construction order is
# gen, disc_Normal, disc_Anomaly.
gen, disc_Normal, disc_Anomaly = (
    Generator().to(device),
    Discriminator().to(device),
    Discriminator().to(device),
)


In [23]:
# Each network gets its own Adam optimizer with the same learning rate and a
# small L2 weight decay.
opt_gen, opt_disc_N, opt_disc_A = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE_2, weight_decay=1e-5)
    for net in (gen, disc_Normal, disc_Anomaly)
)
# BCELoss expects predictions and targets in [0, 1] (the networks end in Sigmoid).
criterion = nn.BCELoss()

In [24]:
NUM_EPOCHS = 10
# One (epoch, input batch, reconstruction) snapshot per epoch. Tensors are
# detached from the autograd graph and moved to the CPU so they can be plotted
# directly (e.g. via .numpy()) and do not keep device memory alive.
outputs = []
for epoch in range(NUM_EPOCHS):
    gen.train()
    running_loss = 0.0
    num_batches = 0
    for img, _ in tqdm(train_loader):
        img = img.to(device)
        recon = gen(img)
        # Reconstruction loss: the target is the input image itself.
        loss = criterion(recon, img)
        opt_gen.zero_grad()
        loss.backward()
        opt_gen.step()
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; nothing to train on")
    # Report the mean loss over the epoch rather than the noisy last-batch value.
    print(f'Epoch:{epoch+1}, Loss:{running_loss / num_batches:.6f}')
    outputs.append((epoch, img.detach().cpu(), recon.detach().cpu()))

  0%|          | 0/1875 [00:00<?, ?it/s]


RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR