Source: https://hardikbansal.github.io/CycleGANBlog/

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import os
from PIL import Image
from torchvision import transforms
%matplotlib inline

## HYPER-PARAMETERS AND DATA PATHS

In [10]:
# Training-image folders for the two domains (A and B) of the horse2zebra
# dataset. NOTE: these are machine-specific absolute paths; adjust them to
# your own download location before running.
a_folder, b_folder = (
    "/Users/tfolkman/Downloads/horse2zebra/trainA/",
    "/Users/tfolkman/Downloads/horse2zebra/trainB/",
)

In [53]:
class GANDataLoader(Dataset):
    """Map-style dataset yielding one image from each of two folders per index.

    Despite its name, this is a ``torch.utils.data.Dataset`` rather than a
    ``DataLoader``; wrap it in a ``DataLoader`` for batching and shuffling.

    Item ``idx`` pairs the ``idx``-th file of ``a_folder`` with the ``idx``-th
    file of ``b_folder``, both in sorted filename order. The pairing is
    therefore positional only.

    Args:
        a_folder: Directory containing the domain-A images.
        b_folder: Directory containing the domain-B images.
        transform: Optional callable applied to each loaded PIL image
            (for example a ``torchvision.transforms.Compose``).
    """

    def __init__(self, a_folder, b_folder, transform=None):
        self.a_folder = a_folder
        self.b_folder = b_folder
        # Sorted so every index maps to the same file on every run and
        # platform; raw ``os.listdir`` order is arbitrary.
        self.a_images = self.__list_images(a_folder)
        self.b_images = self.__list_images(b_folder)
        self.transform = transform

    @staticmethod
    def __list_images(folder):
        """Return the sorted names of the regular, non-hidden files in *folder*."""
        # Hidden entries (e.g. macOS ``.DS_Store``) and sub-directories are
        # not images and would make ``Image.open`` fail at load time.
        return sorted(
            name
            for name in os.listdir(folder)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(folder, name))
        )

    def __len__(self):
        # Only indices present in both folders are valid, so the shorter
        # folder bounds the dataset length.
        return min(len(self.a_images), len(self.b_images))

    def __read_image(self, path):
        """Load the image at *path* as RGB and apply ``self.transform`` if set."""
        # ``Image.open`` is lazy and keeps the file handle open. ``convert``
        # loads the pixels into a new in-memory image, so the ``with`` block
        # can close the file. Converting to RGB also gives every item the
        # same channel count (grayscale, RGBA and palette files otherwise
        # yield tensors that cannot be stacked into one batch).
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return the ``(a_image, b_image)`` pair at position *idx*."""
        a_img = self.__read_image(os.path.join(self.a_folder, self.a_images[idx]))
        b_img = self.__read_image(os.path.join(self.b_folder, self.b_images[idx]))
        return a_img, b_img

In [59]:
# NOTE: despite its name, ``data_loader`` is a ``GANDataLoader`` Dataset
# instance, not a ``torch.utils.data.DataLoader``. Each item is an (a, b)
# pair of images resized to 256x256 and converted to tensors.
data_loader = GANDataLoader(
    a_folder,
    b_folder,
    transform=transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ]
    ),
)

In [7]:
class Encoder(nn.Module):
    """Convolutional down-sampling encoder mapping 3 -> 256 channels.

    Stages (all convolutions are unpadded):
      * 7x7 conv, stride 1, 3 -> 64 channels, then BatchNorm + ReLU
      * 3x3 conv, stride 2, 64 -> 128 channels, then BatchNorm + ReLU
      * 3x3 conv, stride 2, 128 -> 256 channels, returned as-is

    Because no padding is used, the spatial size shrinks at every stage;
    for example, a 256x256 input yields a 61x61 feature map.

    Args:
        input_size: Currently unused; kept for interface compatibility.
        hidden_size: Currently unused; kept for interface compatibility.
    """

    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 2)
        self.conv3 = nn.Conv2d(128, 256, 3, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, input_image):
        """Encode a batch of images.

        Args:
            input_image: Tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, 256, H', W')``, where H' and W' are the
            input sizes reduced by the three unpadded convolutions.
        """
        # Chain the stages: each conv consumes the previous stage's
        # activations. conv2 expects 64 input channels and conv3 expects 128.
        output = F.relu(self.bn1(self.conv1(input_image)))
        output = F.relu(self.bn2(self.conv2(output)))
        # NOTE(review): the last stage has no BatchNorm/ReLU; confirm this
        # is intended.
        return self.conv3(output)