In [1]:
import numpy
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.utils.data import Dataset

In [2]:
%load_ext autoreload
%autoreload 2


In [3]:
pwd

'/home/pytorch/projects/SR_regular_GAN'

In [4]:
class Params:
    """Hyper-parameters and runtime options shared by the cells of this notebook.

    Each option gets the default assigned in ``__init__``. Any of them can be
    replaced at construction time with a keyword argument of the same name,
    e.g. ``Params(batchSize=16, dataroot='~/data/lsun', cuda=False)``.

    Raises:
        TypeError: if a keyword argument does not name an existing option,
            so that a misspelt option is not silently ignored.
    """

    def __init__(self, **overrides):
        # Batch size handed to the DataLoader.
        self.batchSize = 64
        # Sizes used by the Resize/CenterCrop steps of transformA / transformB.
        # NOTE(review): the notebook comments talk about 64x64 -> 128x128, but
        # both sizes are 64 here -- confirm which one is intended.
        self.Stage1imageSize = 64
        self.Stage2imageSize = 64
        # LAMBDA, lr, nz, ngf, restart and beta1 are not read by any cell of
        # this notebook; presumably they are consumed by train2.run_trainer2,
        # which receives the whole object -- TODO confirm.
        self.LAMBDA = 10
        self.lr = 0.0002
        # nc is passed to both D_Stage2 and get_unet_generator.
        self.nc = 3
        self.nz = 100
        self.ngf = 64
        # ndf is passed to D_Stage2.
        self.ndf = 64
        # for unet: passed to get_unet_generator together with nc.
        self.nc_out = 3
        self.num_downsample = 4
        # Directory scanned recursively for image files.
        self.dataroot = '/home/pytorch/projects/lsun'
        # Number of DataLoader worker processes.
        self.workers = 1
        self.restart = ''
        # When true, the models are moved to the GPU with .cuda().
        self.cuda = True
        self.beta1 = 0.5

        # Apply caller overrides last so they win over the defaults above.
        for name, value in overrides.items():
            if name not in vars(self):
                raise TypeError('Params got an unexpected option %r' % (name,))
            setattr(self, name, value)


opt = Params()

In [5]:
# Image super-resolution [64x64 -> 128x128].
# We train the generator with supervision to produce 128x128 images from 64x64 inputs,
# and add an adversarial loss term on top of the supervised one.
# This is arguably a roundabout re-implementation of pix2pix.

In [6]:
# Adapted from https://github.com/pytorch/vision/tree/master/torchvision/datasets to create
# a new dataset class.
# We want a dataloader that yields both the 64x64 and the 128x128 version of each image
# when iterated (e.g. with `enumerate`).
# The generator then takes the 64x64 image as input and is trained to produce the 128x128 one.

In [7]:
def make_dataset(dir):
    """Recursively collect the path of every file below a directory.

    No extension filtering is applied: every file found is returned, so the
    directory is expected to contain only files the image loader can open.

    Args:
        dir (str): Directory to scan. A leading ``~`` is expanded to the
            user's home directory.

    Returns:
        list[str]: File paths, ordered by containing directory and then by
        file name. Empty when the directory does not exist (a message is
        printed in that case).
    """
    import os

    root_dir = os.path.expanduser(dir)

    # Test the *expanded* path. Checking the raw argument reported '~/...'
    # directories as missing even though the walk below found their files.
    if not os.path.exists(root_dir):
        print('path does not exist')
        return []

    images = []
    for root, _, fnames in sorted(os.walk(root_dir)):
        for fname in sorted(fnames):
            images.append(os.path.join(root, fname))
    return images

In [8]:
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    Args:
        path (str): Location of the image file.

    Returns:
        The result of ``Image.open(...).convert('RGB')`` for that file.
    """
    from PIL import Image

    # Hand PIL an explicitly opened file object so the handle is closed when
    # the ``with`` block exits; this avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')

In [9]:
class ImageFolder(Dataset):
    """Dataset that yields two differently transformed views of each image.

    Every file found recursively under ``root`` is treated as an image; there
    are no class labels. Indexing returns the pair
    ``(transformA(image), transformB(image))`` computed from the same loaded
    image, so a DataLoader built on top emits both versions together.

    Args:
        root (string): Root directory path, scanned with ``make_dataset``.
        transformA (callable, optional): Applied to the loaded image to
            produce the first element of the pair. When ``None`` the loaded
            image is returned unchanged.
        transformB (callable, optional): Applied to the loaded image to
            produce the second element of the pair. When ``None`` the loaded
            image is returned unchanged.
        loader (callable, optional): A function to load an image given its
            path. Defaults to ``pil_loader``.

    Attributes:
        samples (list): Paths of the image files.
        imgs (list): Alias of ``samples``.
    """

    def __init__(self, root, transformA=None, transformB=None, loader=None):
        # Scan the directory the caller asked for instead of reaching for the
        # notebook-global ``opt.dataroot``.
        self.samples = make_dataset(root)
        self.imgs = self.samples
        self.transformA = transformA
        self.transformB = transformB
        # Resolved here rather than in the signature so the class can be
        # defined before ``pil_loader`` exists.
        self.loader = pil_loader if loader is None else loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into ``samples``.
        Returns:
            tuple: (sampleA, sampleB), the image at ``index`` after
            ``transformA`` and ``transformB`` respectively.
        """
        sample = self.loader(self.samples[index])

        # Fall back to the untransformed image when a transform is missing;
        # previously this path raised UnboundLocalError.
        sampleA = sample if self.transformA is None else self.transformA(sample)
        sampleB = sample if self.transformB is None else self.transformB(sample)

        return sampleA, sampleB

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.samples)


In [10]:
def _resize_crop_normalize(size):
    """Compose the preprocessing pipeline for one target size.

    Steps, in order: ``Resize(size)``, ``CenterCrop(size)``, ``ToTensor()``
    and a per-channel ``Normalize`` with mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# The two pipelines are identical apart from the size they are built with.
transformA = _resize_crop_normalize(opt.Stage1imageSize)
transformB = _resize_crop_normalize(opt.Stage2imageSize)

In [11]:
# Paired dataset: each item is the same image passed through transformA and transformB.
dataset = ImageFolder(
    root=opt.dataroot,
    transformA=transformA,
    transformB=transformB,
)

# Shuffled loader over the paired dataset; iterating it (e.g. with 'enumerate')
# yields both versions of each image together, sized by opt.Stage1imageSize
# and opt.Stage2imageSize respectively.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

In [12]:
#dataloaderStage1, dataloaderStage2 = get_data_loaders(opt)

In [13]:
%aimport model
#from model import G_Stage1
#from model import D_Stage1
#from model import G_Stage2
from model import D_Stage2
from model import get_unet_generator
#from model import UnetGenerator
#from model import D_Stage2_4x4

In [14]:
opt.nc

3

In [15]:
# Build the discriminator and the U-Net generator from the shared options.
D2 = D_Stage2(opt.nc, opt.ndf)
G2 = get_unet_generator(opt.nc, opt.nc_out, opt.num_downsample)

# Move both networks to the GPU when requested.
if opt.cuda:
    D2, G2 = D2.cuda(), G2.cuda()

In [16]:
print(G2)

UnetGenerator(
  (model): UnetSkipConnectionBlock(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetSkipConnectionBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetSkipConnectionBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace)
                  (1): Conv2d(256, 512, kernel_size=(4

In [25]:
%aimport train2
from train2 import run_trainer2

In [26]:
# Start training generator G2 and discriminator D2 on the paired dataloader.
# run_trainer2 is imported from train2.py, which is not shown in this notebook;
# judging by the log below it reports an L1 and an adversarial generator loss
# plus real/fake discriminator losses -- TODO confirm against train2.py.
run_trainer2(dataloader, G2, D2, opt)

saving images for batch 0
0 [0/47392] G Loss [L1/GAdv] [0.4266/0.0222] Loss D (real/fake) [0.2195/0.2528]
saving images for batch 100
0 [100/47392] G Loss [L1/GAdv] [0.0887/0.0011] Loss D (real/fake) [0.2319/0.2038]
saving images for batch 200
0 [200/47392] G Loss [L1/GAdv] [0.0736/0.0011] Loss D (real/fake) [0.1736/0.2192]
saving images for batch 300
0 [300/47392] G Loss [L1/GAdv] [0.0672/0.0007] Loss D (real/fake) [0.2626/0.1663]
saving images for batch 400
0 [400/47392] G Loss [L1/GAdv] [0.0595/0.0005] Loss D (real/fake) [0.3125/0.1313]
saving images for batch 500
0 [500/47392] G Loss [L1/GAdv] [0.0565/0.0011] Loss D (real/fake) [0.1244/0.3023]
saving images for batch 600
0 [600/47392] G Loss [L1/GAdv] [0.0549/0.0014] Loss D (real/fake) [0.0798/0.4299]
saving images for batch 700
0 [700/47392] G Loss [L1/GAdv] [0.0501/0.0015] Loss D (real/fake) [0.1501/0.2797]
saving images for batch 800
0 [800/47392] G Loss [L1/GAdv] [0.0440/0.0006] Loss D (real/fake) [0.1667/0.3038]
saving images 

saving images for batch 7400
0 [7400/47392] G Loss [L1/GAdv] [0.0155/0.0001] Loss D (real/fake) [0.2040/0.2664]
saving images for batch 7500
0 [7500/47392] G Loss [L1/GAdv] [0.0150/0.0001] Loss D (real/fake) [0.2077/0.2504]
saving images for batch 7600
0 [7600/47392] G Loss [L1/GAdv] [0.0151/0.0001] Loss D (real/fake) [0.2232/0.2402]
saving images for batch 7700
0 [7700/47392] G Loss [L1/GAdv] [0.0148/0.0001] Loss D (real/fake) [0.2150/0.2613]
saving images for batch 7800
0 [7800/47392] G Loss [L1/GAdv] [0.0155/0.0002] Loss D (real/fake) [0.2439/0.1981]
saving images for batch 7900
0 [7900/47392] G Loss [L1/GAdv] [0.0161/0.0001] Loss D (real/fake) [0.2894/0.1858]
saving images for batch 8000
0 [8000/47392] G Loss [L1/GAdv] [0.0148/0.0001] Loss D (real/fake) [0.2145/0.2404]
saving images for batch 8100
0 [8100/47392] G Loss [L1/GAdv] [0.0159/0.0001] Loss D (real/fake) [0.3311/0.1501]
saving images for batch 8200
0 [8200/47392] G Loss [L1/GAdv] [0.0145/0.0001] Loss D (real/fake) [0.1574/

Process Process-1:
Traceback (most recent call last):
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/pytorch/anaconda3/envs/pytorch37/lib/python3.7/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/pytorch/anaconda3/envs/pytorch37/

KeyboardInterrupt: 