<a href="https://colab.research.google.com/github/shahaman06/CMPE297-ShortStory/blob/main/Colab_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Initial file setup (IPython shell commands; paths assume a Colab runtime).
# WARNING: `rm -r *` deletes everything in the current working directory
# (presumably /content on Colab), including any data uploaded earlier.
!rm -r *
# Fetch the project repository.
!git clone "https://github.com/shahaman06/CMPE297-ShortStory/"
# Move the repo contents into /content so the later cells can find them there
# (e.g. the `modules` import and the default '/content/paintings/' directory).
!mv /content/CMPE297-ShortStory/* /content/
# Remove the leftover clone directory (it still holds dotfiles such as .git,
# which `*` does not match) plus the README and notebook copies moved above.
!rm -r CMPE297-ShortStory README.md Colab_Notebook.ipynb

In [None]:
# import torch.nn as nn
# import torchvision
import os
import cv2
import numpy as np

from torch.cuda import is_available
from torch.utils.data import DataLoader,Dataset
from torch.optim import Adam
from torch.nn import BCELoss
from torch import ones, randn, tensor
from torchvision import transforms
from modules import Discriminator, Generator
from tqdm import tqdm
from glob import glob

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Global hyperparameters shared by the cells below.
BATCH_SIZE: int = 25           # images per DataLoader mini-batch
IMG_SHAPE: int = 208           # side length (px) each painting is resized to
CHANNELS: int = 3              # NOTE(review): not referenced in the visible cells; presumably colour channels used by modules.py
FEATURES: int = 16             # NOTE(review): not referenced in the visible cells; presumably a model width used by modules.py
RANDOM_NOISE_SHAPE: int = 256  # channel count of the (batch, C, 1, 1) noise fed to the Generator

In [None]:
class PaintingDataset(Dataset):
    """Dataset of paintings read from the ``*.jpg`` files in one directory.

    Only files that OpenCV can decode are kept. Each item is a float tensor in
    channel-first layout ``(C, img_shape, img_shape)``, already moved to the
    module-level ``device``.
    """

    def __init__(self, loc='/content/paintings/', img_shape=208):
        """
        Args:
            loc: Directory containing the ``*.jpg`` files. A trailing path
                separator is optional.
            img_shape: Side length in pixels that every image is resized to.
        """
        self.loc = loc
        self.img_shape = img_shape
        # Decode every candidate once up front and keep only the paths OpenCV
        # can read (cv2.imread returns None on failure), so __getitem__ never
        # receives an unreadable file.
        self.paintings = [
            path for path in self._candidate_paths(loc)
            if cv2.imread(path) is not None
        ]

    @staticmethod
    def _candidate_paths(loc):
        """Return the sorted ``*.jpg`` paths directly inside directory *loc*."""
        # os.path.join adds the separator when *loc* lacks a trailing slash;
        # plain string concatenation would glob "<dir>*.jpg" in the parent.
        # Sorting makes dataset indices reproducible across runs.
        return sorted(glob(os.path.join(loc, '*.jpg')))

    def __len__(self):
        """Number of readable images found in ``loc``."""
        return len(self.paintings)

    def __getitem__(self, idx):
        """Load image *idx*, resize it square, and return it as a float tensor on ``device``."""
        img = cv2.imread(self.paintings[idx])
        img = cv2.resize(img, (self.img_shape, self.img_shape))
        # PyTorch expects channel-first (C, H, W); OpenCV yields (H, W, C).
        img = np.moveaxis(img, -1, 0)
        # NOTE(review): pixels stay in OpenCV's BGR order and the raw 0-255
        # range (no normalization). Confirm this matches the value range the
        # Generator produces, or the Discriminator can separate real from fake
        # by scale alone.
        return tensor(img).float().to(device)

In [None]:
# Build the painting dataset and a shuffled mini-batch loader over it.
ds = PaintingDataset(img_shape=IMG_SHAPE)
if len(ds) == 0:
    # Fail fast with the offending directory in the message, instead of a
    # generic error from DataLoader's random sampler on an empty dataset.
    raise ValueError(f'No readable .jpg images found in {ds.loc!r}')
loader = DataLoader(ds, shuffle=True, batch_size=BATCH_SIZE)
# Sanity check: pull one batch so loading/decoding problems surface before training.
batch = next(iter(loader))

In [None]:
# GAN training: build models and optimizers, then alternate D and G updates.
disc = Discriminator().to(device)
gen = Generator().to(device)

lr = 0.0002
epochs = 2

optimD = Adam(disc.parameters(), lr = lr, betas = (0.5, 0.99))
optimG = Adam(gen.parameters(), lr = lr, betas = (0.5, 0.99))

# NOTE(review): BCELoss expects inputs in [0, 1], so this assumes the
# Discriminator ends in a sigmoid; confirm in modules.py.
criterion = BCELoss()

real_label = 1
fake_label = 0
# Label smoothing for the discriminator: soften the 1/0 targets to 0.9/0.1.
real_target = 0.9
fake_target = 0.1

for epoch in range(epochs):
    progress = tqdm(loader, desc=f'Epoch {epoch + 1}/{epochs}')
    for data in progress:
        batch_size = data.shape[0]

        ## Training discriminator

        # Real images -> smoothed "real" target.
        label = ones(batch_size, device=device) * real_target
        output = disc(data).reshape(-1)
        lossD_real = criterion(output, label)

        # Generated images -> smoothed "fake" target. detach() keeps this
        # discriminator loss from back-propagating into the generator.
        label = ones(batch_size, device=device) * fake_target
        rand_noise = randn((batch_size, RANDOM_NOISE_SHAPE, 1, 1), device=device)
        fake_image = gen(rand_noise)
        output = disc(fake_image.detach()).reshape(-1)
        lossD_fake = criterion(output, label)

        # Back-propagate the combined discriminator loss and update D.
        # zero_grad also clears gradients left in D by the previous G step.
        disc.zero_grad()
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimD.step()

        ## Training generator

        # The generator wants D to call its fakes real, so use the hard real label.
        label = ones(batch_size, device=device) * real_label
        output = disc(fake_image).reshape(-1)
        lossG = criterion(output, label)

        # Only optimG steps here. Gradients that also land in D are cleared by
        # disc.zero_grad() before the next D update.
        gen.zero_grad()
        lossG.backward()
        optimG.step()

        progress.set_postfix(lossD=f'{lossD.item():.4f}', lossG=f'{lossG.item():.4f}')