In [14]:
%load_ext autoreload
%autoreload 2

import torch
from torchvision import datasets
from torchvision import transforms as T
import ruamel.yaml as yaml
from pathlib import Path

import maxent_gan.models
from maxent_gan.models import MMCSNDiscriminator, MMCSNGenerator
from maxent_gan.models.utils import load_gan
from maxent_gan.utils.general_utils import DotConfig, CONFIGS_DIR, DATA_DIR

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
# Device for the pretrained GAN. GPU index 1 was the original target; fall back to
# the highest available GPU index (or the CPU) so this cell also runs on
# single-GPU or CPU-only machines instead of failing later on a missing cuda:1.
GAN_GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device('cuda', min(GAN_GPU_INDEX, torch.cuda.device_count() - 1))
else:
    device = torch.device('cpu')

In [16]:
# Load the GAN config with ruamel's round-trip loader. The file is opened in a
# `with` block so its handle is closed, and `YAML(typ='rt').load` replaces the
# deprecated old-style `round_trip_load` function.
config_path = Path(CONFIGS_DIR, 'gan_configs', 'celeba-sngan-mmc.yml')
with config_path.open('r') as config_file:
    raw_config = yaml.YAML(typ='rt').load(config_file)
config = DotConfig(raw_config['gan_config'])

# NOTE(review): the third positional argument of `load_gan` is an unnamed boolean
# flag — confirm its meaning in maxent_gan.models.utils and pass it by keyword.
gen, dis = load_gan(config, device, False)

In [18]:
import os
import zipfile 
import gdown
import torch
# from natsort import natsorted
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

## Setup
# Number of gpus available
ngpu = 1
# NOTE(review): this rebinds `device`, which an earlier cell used to load
# `gen`/`dis`. Anything using `device` from here on targets cuda:0 (or the CPU),
# which may differ from the device the GAN lives on — confirm this is intended.
device = torch.device('cuda:0' if (
    torch.cuda.is_available() and ngpu > 0) else 'cpu')

## Fetch data from Google Drive
# Root directory for the dataset
data_root = '../data/celeba'
# Path to folder with the dataset
dataset_folder = f'{data_root}/img_align_celeba'
# URL for the CelebA dataset
url = 'https://drive.google.com/uc?id=1cNIac61PSA_LqDFYFUeyaQYekYPc75NH'
# Path to download the dataset to
download_path = f'{data_root}/img_align_celeba.zip'
# Set to True to fetch and extract the archive. Off by default so re-running
# the notebook does not touch the network.
DOWNLOAD_CELEBA = False

# Create required directories. A single call with exist_ok=True creates both
# `data_root` and `dataset_folder`, including the case where `data_root`
# already exists but `dataset_folder` does not.
os.makedirs(dataset_folder, exist_ok=True)

if DOWNLOAD_CELEBA:
    # Download the dataset from google drive, unless the zip is already present
    if not os.path.exists(download_path):
        gdown.download(url, download_path, quiet=False)
    # Unzip the downloaded file
    with zipfile.ZipFile(download_path, 'r') as ziphandler:
        ziphandler.extractall(dataset_folder)

In [20]:
class CelebADataset(Dataset):
    """Dataset over the image files stored directly inside one directory.

    Each item is a single image converted to RGB and, if ``transform`` is given,
    passed through it. No labels are returned.
    """

    def __init__(self, root_dir, transform=None, extensions=None):
        """
        Args:
          root_dir (string): Directory with all the images
          transform (callable, optional): transform to be applied to each image sample
          extensions (iterable of str, optional): if given, keep only files whose
            names end with one of these suffixes (case-insensitive), e.g. ('.jpg',).
            By default every regular file in ``root_dir`` is used.
        """
        self.root_dir = root_dir
        self.transform = transform
        suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)
        # Skip sub-directories (they cannot be opened as images) and sort the names:
        # directory listing order is arbitrary, so sorting makes the
        # index -> file mapping reproducible across machines and runs.
        with os.scandir(root_dir) as entries:
            self.image_names = sorted(
                entry.name for entry in entries
                if entry.is_file()
                and (suffixes is None or entry.name.lower().endswith(suffixes))
            )

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``self.transform`` if one is set."""
        img_path = os.path.join(self.root_dir, self.image_names[idx])
        # The context manager guarantees the file handle is released; convert()
        # reads the pixel data before the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

## Load the dataset
# Path to directory with all the images. The archive is extracted into
# `dataset_folder`; the repeated `img_align_celeba` presumably comes from the
# zip's own top-level folder — verify against the archive layout.
img_folder = f'{dataset_folder}/img_align_celeba'
# Spatial size of training images, images are resized to this size.
image_size = 64
# Transformations to be applied to each individual image sample.
# Normalize computes (x - 0.5) / 0.5 per channel, mapping [0, 1] inputs to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

# Fail early with an actionable message instead of an opaque error later on.
if not os.path.isdir(img_folder):
    raise FileNotFoundError(
        f'CelebA images not found at {img_folder!r}; '
        'download and extract the archive first.')

# Load the dataset from file and apply transformations
celeba_dataset = CelebADataset(img_folder, transform)
assert len(celeba_dataset) > 0, f'no images found in {img_folder!r}'
