In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import cv2
import os
import torch
import torchvision
import torchinfo
import skimage
import random

from torch.utils.data import Dataset, DataLoader

# Training hyperparameters shared by the cells below.
EPOCHS = 10  # short run; the trailing "#0" toggle in the original hinted at 100
BATCH_SIZE = 256
LR = 2e-4

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {device}")

In [None]:
class ImageDataset(Dataset):
  """Dataset that pairs two index-aligned sequences as (X[i], Y[i]) items."""

  def __init__(self, X, Y):
    # Both sequences are stored as given; nothing is copied or validated.
    self.X, self.Y = X, Y

  def __len__(self):
    # The length is taken from X alone.
    return len(self.X)

  def __getitem__(self, index):
    sample = self.X[index]
    target = self.Y[index]
    return sample, target

In [None]:
class Discriminator(torch.nn.Module):
  """Convolutional binary classifier for single-channel images.

  Two stride-2 3x3 convolutions (each followed by batch norm and LeakyReLU)
  are flattened into 3136 features (64 * 7 * 7 for a 28x28 input) and mapped
  by one linear unit plus a sigmoid to a value in (0, 1) per sample.
  """

  def __init__(self) -> None:
    super().__init__()
    # Layers are registered in a fixed order so state_dict keys stay stable.
    self.first_conv2d = torch.nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
    self.second_conv2d = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
    self.first_batchnorm2d = torch.nn.BatchNorm2d(64)
    self.second_batchnorm2d = torch.nn.BatchNorm2d(64)
    self.leakyrelu = torch.nn.LeakyReLU(0.2)
    self.flatten = torch.nn.Flatten()
    self.linear = torch.nn.Linear(3136, 1)
    self.sigmoid = torch.nn.Sigmoid()

  def forward(self, x):
    """Return one sigmoid score per sample for the image batch `x`."""
    # conv -> batch norm -> LeakyReLU, twice.
    features = self.leakyrelu(self.first_batchnorm2d(self.first_conv2d(x)))
    features = self.leakyrelu(self.second_batchnorm2d(self.second_conv2d(features)))
    # Flatten the feature maps and squash the single linear output.
    score = self.linear(self.flatten(features))
    return self.sigmoid(score)

In [None]:
class Generator(torch.nn.Module):
  """Maps a 100-dimensional vector to a single-channel 28x28 image.

  A linear layer expands the input to 6272 values, which are reshaped to
  128 feature maps of 7x7. Two stride-2 transposed convolutions (each with
  batch norm and ReLU) upsample to 28x28, and a 7x7 convolution followed by
  a sigmoid produces one output channel with values in [0, 1].
  """

  def __init__(self) -> None:
    super().__init__()
    # Layers are registered in a fixed order so state_dict keys stay stable.
    self.linear = torch.nn.Linear(100, 6272)
    self.relu = torch.nn.ReLU()
    self.first_convtranspose2d = torch.nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
    self.second_convtranspose2d = torch.nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
    self.first_batchnorm2d = torch.nn.BatchNorm2d(128)
    self.second_batchnorm2d = torch.nn.BatchNorm2d(128)
    self.conv2d = torch.nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=3)
    self.sigmoid = torch.nn.Sigmoid()

  def forward(self, x):
    """Generate a batch of images from the batch of input vectors `x`."""
    # Project the vector and fold it into 128 maps of 7x7 (128 * 7 * 7 = 6272).
    seed_maps = self.relu(self.linear(x)).reshape(-1, 128, 7, 7)
    # Each transposed-conv stage doubles the spatial size: 7 -> 14 -> 28.
    upsampled = self.relu(self.first_batchnorm2d(self.first_convtranspose2d(seed_maps)))
    upsampled = self.relu(self.second_batchnorm2d(self.second_convtranspose2d(upsampled)))
    # Collapse to one channel and squash into [0, 1].
    return self.sigmoid(self.conv2d(upsampled))

In [None]:
def display_image(image, title, cmap="viridis"):
  """Show `image` with matplotlib under `title`, with the axes hidden.

  `cmap` is passed straight to `plt.imshow`; the figure is displayed
  immediately through `plt.show()`.
  """
  plt.imshow(image, cmap=cmap)
  plt.axis("off")
  plt.title(title)
  plt.show()

In [None]:
def fix_image(image):
  """Scale pixel values by 1/255, clamp them to [0, 1] and cast to float32.

  The division produces a new array, so the array passed in is left
  untouched.
  """
  scaled = image / 255
  return np.clip(scaled, 0, 1).astype(np.float32)

In [None]:
def expand(images):
  """Insert a channel axis of size 1 right after the first axis.

  A trailing axis is appended and then moved to position 1, so an array of
  shape (N, H, W) comes back as (N, 1, H, W).
  """
  with_channel_last = images[..., None]
  return with_channel_last.transpose(0, 3, 1, 2)

In [None]:
def print_dataset_info(images, labels):
  """Print a short summary of the image collection and plot sample images.

  Args:
    images: array indexed as (image, row, column); each `images[i]` is drawn
      with `plt.imshow`, which treats the first axis of an image as rows.
    labels: sequence whose i-th entry is used as the title of `images[i]`.

  Up to 25 samples are drawn on a 5x5 grid; fewer are drawn when `images`
  or `labels` holds fewer than 25 entries.
  """
  # Rows are the image height and columns the width.
  num_images, height, width = images.shape[:3]

  print("==================================")
  print("PODATKI O ZBIRKI SLIK")
  print("==================================")
  print("IME ZBIRKE: MNIST")
  print(f"ŠTEVILO SLIK: {num_images}")
  print(f"ŠIRINA SLIK: {width}")
  print(f"VIŠINA SLIK: {height}")
  print("ŠTEVILO KANALOV SLIK: 1")
  print("PRIMERI SLIK:")

  # Never index past the data that was actually supplied.
  preview_count = min(25, len(images), len(labels))

  plt.figure(figsize=(10, 10))
  for i in range(preview_count):
    plt.subplot(5, 5, i + 1)
    plt.imshow(images[i], cmap="gray")
    plt.axis("off")
    plt.title(labels[i])
  # Lay the grid out once, after every subplot exists.
  plt.tight_layout()

  print("==================================")

In [None]:
def train(discriminator, generator, images):
  """Adversarially train `generator` against `discriminator` on `images`.

  Each step first updates the discriminator on one batch of real images
  (target 1) and one equally sized batch of generated images (target 0),
  then updates the generator so that the discriminator scores its output
  as real. After `EPOCHS` passes both modules are written to
  `discriminator{EPOCHS}.pt` and `generator{EPOCHS}.pt` with `torch.save`.

  Args:
    discriminator: module mapping an image batch to one probability per
      sample, shape (batch, 1), as required by `BCELoss` with these targets.
    generator: module mapping a (batch, 100) noise tensor to an image batch
      the discriminator accepts.
    images: indexable collection of training images; each batch is fed
      directly to the discriminator (presumably the float32 (N, 1, 28, 28)
      array produced by `expand` — confirm against the caller).

  Returns:
    None. Both modules are moved to `device` and updated in place.
  """
  discriminator = discriminator.to(device)
  generator = generator.to(device)

  discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=LR)
  generator_optimizer = torch.optim.Adam(generator.parameters(), lr=LR)

  discriminator_loss_fn = torch.nn.BCELoss()
  generator_loss_fn = torch.nn.BCELoss()

  discriminator.train()
  generator.train()

  latent_dim = 100  # must match the generator's expected input width
  half_batch = BATCH_SIZE // 2

  # ImageDataset needs a target sequence; the targets are never used here.
  dataset = ImageDataset(images, images)
  dataloader = DataLoader(dataset=dataset, batch_size=half_batch, shuffle=True)

  for epoch in range(EPOCHS):
    print(f"TRAINING EPOCH {epoch + 1} OF {EPOCHS}")

    for input_batch, _ in dataloader:
      input_batch = input_batch.to(device)

      # Size targets and noise from the actual batch so that a short final
      # batch does not cause a shape mismatch in BCELoss.
      current_batch = input_batch.shape[0]
      real_target = torch.ones((current_batch, 1), device=device)
      fake_target = torch.zeros((current_batch, 1), device=device)

      # ---- Discriminator update: real images -> 1, generated images -> 0.
      discriminator_optimizer.zero_grad()

      pred = discriminator(input_batch)
      loss = discriminator_loss_fn(pred, real_target)
      loss.backward()

      # Fresh noise every step; detach so this pass does not backpropagate
      # into the generator.
      noise = torch.randn((current_batch, latent_dim), device=device)
      fake_images = generator(noise).detach()

      pred = discriminator(fake_images)
      loss = discriminator_loss_fn(pred, fake_target)
      loss.backward()

      discriminator_optimizer.step()

      # ---- Generator update: make the discriminator call its output real.
      generator_optimizer.zero_grad()

      noise = torch.randn((current_batch, latent_dim), device=device)
      pred = discriminator(generator(noise))
      loss = generator_loss_fn(pred, real_target)
      loss.backward()

      # Apply the update. Gradients this pass leaves on the discriminator
      # are cleared by its zero_grad() at the start of the next step.
      generator_optimizer.step()

  torch.save(discriminator, f"discriminator{EPOCHS}.pt")
  torch.save(generator, f"generator{EPOCHS}.pt")

In [None]:
# Open MNIST under ./datasets, downloading it first when it is not there yet.
dataset = torchvision.datasets.MNIST("datasets", download=True)

# Keep only as many images as fill whole half-batches, so every batch the
# training loop draws has the same size. With BATCH_SIZE = 256 and 60000
# samples this is the 59904 that used to be hard-coded.
half_batch = BATCH_SIZE // 2
usable_count = (len(dataset) // half_batch) * half_batch
images = np.array([fix_image(np.array(dataset[i][0])) for i in range(usable_count)])

# Only the preview grid needs labels, so read just the first 25 instead of
# decoding every sample a second time.
preview_count = min(25, len(dataset))
labels = np.array([dataset[i][1] for i in range(preview_count)])

print_dataset_info(images, labels)

# Add the channel axis the convolutional networks expect: (N, H, W) -> (N, 1, H, W).
images = expand(images)

In [None]:
# Fresh, untrained networks; train() moves them to `device` itself.
discriminator = Discriminator()
generator = Generator()

In [None]:
# Print a torchinfo summary of the discriminator for an input of size (128, 1, 28, 28).
print(torchinfo.summary(discriminator, (128, 1, 28, 28)))

In [None]:
# Print a torchinfo summary of the generator for an input of size (1, 100).
print(torchinfo.summary(generator, (1, 100)))

In [None]:
# Run the full training loop; train() also saves both networks to .pt files.
train(discriminator, generator, images)

In [None]:
# Load the generator written by train(); the file name is built the same way
# as there, so changing EPOCHS cannot leave this cell reading a stale file.
# SECURITY: the file holds a whole pickled module, so weights_only=False is
# required, and unpickling can run arbitrary code — only load files this
# notebook produced.
model = torch.load(f"generator{EPOCHS}.pt", map_location=device, weights_only=False)

# The module was saved in training mode. Switch to eval mode so BatchNorm uses
# its running statistics rather than the statistics of this single sample.
model.eval()

inputs = torch.from_numpy(np.random.normal(size=(1, 100)).astype(np.float32)).to(device)

# No gradients are needed for sampling.
with torch.no_grad():
  pred = model(inputs)

# Take the only sample and move channels last for plotting: (1, H, W) -> (H, W, 1).
pred = pred.cpu().numpy()[0].transpose(1, 2, 0)

# The value range shows whether the output uses the full [0, 1] interval.
print(np.min(pred))
print(np.max(pred))

display_image(pred, 'output', 'gray')