In [1]:
import numpy
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.utils.data import Dataset

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [3]:
class Params:
    """Container for the notebook's hyper-parameters and runtime settings.

    Every setting is stored as a plain instance attribute. The defaults below
    can be overridden per instance with keyword arguments, e.g.
    ``Params(batchSize=16, dataroot='~/data/lsun')``.

    Raises:
        TypeError: if a keyword argument does not name an existing setting
            (guards against silently ignored typos).
    """

    def __init__(self, **overrides):
        self.batchSize = 64           # batch_size handed to the DataLoader
        self.Stage1imageSize = 64     # Resize/CenterCrop size used by transformA
        self.Stage2imageSize = 128    # Resize/CenterCrop size used by transformB
        self.LAMBDA = 10              # not read in this notebook; presumably a loss weight - TODO confirm in train2
        self.lr = 0.0002              # not read in this notebook; presumably the optimizer learning rate
        self.nc = 3                   # first constructor argument of every G/D network
        self.nz = 100                 # passed to G_Stage1 only
        self.ngf = 64                 # passed to the generators (G_Stage1, G_Stage2)
        self.ndf = 64                 # passed to the discriminators (D_Stage1, D_Stage2, D_Stage2_4x4)
        self.dataroot = '/home/pytorch/projects/lsun'  # directory scanned for image files
        self.workers = 1              # num_workers handed to the DataLoader
        self.restart = ''             # not read in this notebook; presumably a checkpoint to resume from
        self.cuda = True              # not read in this notebook; presumably toggles GPU use
        self.beta1 = 0.5              # not read in this notebook; presumably Adam's beta1

        # Apply caller overrides only for settings that already exist.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError('unknown Params setting(s): ' + ', '.join(unknown))
        vars(self).update(overrides)
opt = Params()

In [4]:
# Image super-resolution [64x64 -> 128x128]
# We want to use supervised training to produce 128x128 images from 64x64 inputs.
# In addition, we add adversarial terms. This may be a rather roundabout way of reimplementing pix2pix.

In [None]:
# Adapt code from https://github.com/pytorch/vision/tree/master/torchvision/datasets to create
# a new dataset class.
# We want a dataloader that emits both the 64x64 and the 128x128 version of each image at the same time when iterated with 'enumerate'.
# The generator then takes the 64x64 images as input and is trained to produce the 128x128 images.

In [5]:
def make_dataset(dir):
    """Collect the paths of all files found recursively under ``dir``.

    Args:
        dir (str): Directory to scan. A leading ``~`` is expanded to the
            user's home directory.

    Returns:
        list: File paths, ordered by directory and then by file name within
        each directory. No filtering by extension is done, so every file is
        returned. Empty when the directory does not exist (a message is
        printed in that case).
    """
    import os

    d = os.path.expanduser(dir)

    # Test the expanded path: os.path.exists does not resolve '~' itself, so
    # checking the raw argument would report existing '~/...' paths as missing.
    if not os.path.exists(d):
        print('path does not exist')

    # os.walk yields nothing for a missing directory, so the result is then [].
    return [
        os.path.join(root, fname)
        for root, _, fnames in sorted(os.walk(d))
        for fname in sorted(fnames)
    ]

In [6]:
def pil_loader(path):
    """Read the image file at ``path`` and return it as an RGB PIL image."""
    from PIL import Image
    # Manage the file handle ourselves instead of letting PIL open the path,
    # which avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        # convert() runs while the handle is still open.
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image

In [8]:
class ImageFolder(Dataset):
    """Dataset that yields two transformed views of every file under ``root``.

    Every file found recursively under ``root`` is treated as an image; the
    directory layout carries no meaning and there are no class labels.
    Each item is a pair built from the same source image, one element per
    transform.

    Args:
        root (string): Root directory path that is scanned for image files.
        transformA (callable, optional): Takes the loaded RGB PIL image and
            returns the first element of the pair. When ``None`` the
            untransformed image is returned instead.
        transformB (callable, optional): Same as ``transformA`` for the
            second element of the pair.

    Attributes:
        samples (list): Image file paths as returned by ``make_dataset``.
        imgs (list): Alias of ``samples``.
    """

    def __init__(self, root, transformA=None, transformB=None):
        # Honor the caller-supplied root so the dataset does not depend on
        # notebook-level globals.
        self.samples = make_dataset(root)
        self.imgs = self.samples
        self.transformA = transformA
        self.transformB = transformB

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into ``samples``.
        Returns:
            tuple: (sampleA, sampleB), the image at ``index`` passed through
            ``transformA`` and ``transformB`` respectively.
        """
        sample = pil_loader(self.samples[index])

        # Fall back to the raw image when a transform is not configured, so
        # both names are always bound. (Raw PIL images cannot be batched by
        # the default DataLoader collate function, so pass a transform ending
        # in ToTensor when using a DataLoader.)
        sampleA = sample if self.transformA is None else self.transformA(sample)
        sampleB = sample if self.transformB is None else self.transformB(sample)

        return sampleA, sampleB

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.samples)


In [9]:
def _resize_crop_normalize(side):
    """Build the preprocessing pipeline producing ``side`` x ``side`` tensors.

    Steps: resize, center-crop to a square, convert to a tensor, then
    normalize each of the three channels with mean 0.5 and std 0.5.
    """
    steps = [
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)


# Low-resolution view, sized by opt.Stage1imageSize.
transformA = _resize_crop_normalize(opt.Stage1imageSize)

# High-resolution view, sized by opt.Stage2imageSize.
transformB = _resize_crop_normalize(opt.Stage2imageSize)

In [10]:
# Paired dataset: each item is (transformA(image), transformB(image)).
dataset = ImageFolder(
    root=opt.dataroot,
    transformA=transformA,
    transformB=transformB,
)

# Shuffled loader whose batches carry both views of the same images, so
# iterating it with 'enumerate' gives the 64x64 and 128x128 tensors together.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

In [None]:
#dataloaderStage1, dataloaderStage2 = get_data_loaders(opt)

In [None]:
%aimport model
#from model import G_Stage1
#from model import D_Stage1
from model import G_Stage2
from model import D_Stage2
from model import D_Stage2_4x4

In [None]:
# Display opt.nc, the first constructor argument of every network built below
# (presumably the number of image channels), as a quick check.
opt.nc

In [None]:
def _place(net):
    """Move ``net`` to the GPU when ``opt.cuda`` is set; otherwise return it unchanged."""
    return net.cuda() if opt.cuda else net


# Instantiate the stage-1 and stage-2 generators/discriminators. Device
# placement follows the opt.cuda flag instead of always calling .cuda().
G1 = _place(G_Stage1(opt.nc, opt.nz, opt.ngf))
D1 = _place(D_Stage1(opt.nc, opt.ndf))
G2 = _place(G_Stage2(opt.nc, opt.ngf))
D2 = _place(D_Stage2(opt.nc, opt.ndf))
D2_4x4 = _place(D_Stage2_4x4(opt.nc, opt.ndf))


In [None]:
pwd

In [None]:
# Import `train2` and register it with the autoreload extension, then bring
# the training entry point into the notebook namespace.
%aimport train2
from train2 import run_trainer2

In [None]:
# Start stage-2 training with the paired loader built in this notebook
# (`dataloader`); no `dataloaderStage2` is ever defined here.
# NOTE(review): assumes run_trainer2 consumes (transformA, transformB) batch
# pairs from this loader - confirm against train2.
run_trainer2(dataloader, G1, G2, D2, D2_4x4, opt)