In [62]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn as nn 
import os
from skimage import io, transform
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
from PIL import Image
import torch.nn.functional as F

In [63]:
class ImageNetDataset(Dataset):
    
    def __init__(self, root_dir, data_folder, transform=None):
        self.data_path = os.path.join(root_dir,data_folder)
        self.data = os.listdir(self.data_path)
        self.root = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_name = self.data[index]
        img_path = os.path.join(self.data_path,img_name)
        image = io.imread(img_path)
        if self.transform:
            image = self.transform(image)
        
        return image,img_name

In [64]:
class addGaussian:
    
    def __init__(self, std_range):
        self.std_range = std_range

    def __call__(self,img):
        
        if len(np.shape(img))==2:
            img = np.stack((img,)*3, axis=-1)
    
        minval,maxval = self.std_range
        noise_img = img.astype(np.float)
        stddev = np.random.uniform(minval, maxval)
        noise = np.random.randn(*img.shape) * stddev
        noise_img += noise
        noise_img = np.clip(noise_img, 0, 255).astype(np.uint8)
        return noise_img

In [65]:
class Rescale(object):

    def __call__(self, image):

        if len(image.shape) == 2:
            image = np.stack((image,)*3, axis=-1)
        resize = transforms.Resize((256,256))

        tens = transforms.ToTensor()
        
        img = tens(resize(image))
        print(type(img))
        return img


In [80]:
num_epochs = 200
batch_size = 4 
learning_rate = 1e-3

In [67]:
train_dataset = ImageNetDataset('/home/turing/Documents/BE/','data/', 
                               transform=transforms.Compose([
                                   addGaussian((0,50)),
                                   transforms.ToPILImage(),
                                   transforms.Resize((256,256)),
                                   transforms.ToTensor(),

                               ]))

In [68]:
test_dataset = ImageNetDataset('/home/turing/Documents/BE/','data/', 
                               transform=transforms.Compose([
                                   addGaussian((0,50)),
                                   transforms.ToPILImage(),
                                   transforms.Resize((256,256)),
                                   transforms.ToTensor(),
                               ]))

In [69]:
i=Image.open('/home/turing/Documents/BE/data/ILSVRC2012_val_00005899.JPEG')
print(i.size)
i = np.stack((i,)*3, axis=-1)
print(i.shape)

(500, 500)
(500, 500, 3)


In [70]:
def getTrainLoader(batch_size):
    train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True, num_workers=2)
    return(train_loader)
    

In [71]:
train_dataset[33]

(tensor([[[0.8157, 0.8627, 0.8314,  ..., 0.8235, 0.8667, 0.8392],
          [0.8471, 0.8275, 0.8235,  ..., 0.8392, 0.8588, 0.8510],
          [0.8745, 0.8824, 0.8902,  ..., 0.8706, 0.8275, 0.8392],
          ...,
          [0.4941, 0.4431, 0.4039,  ..., 0.3843, 0.3451, 0.4235],
          [0.4275, 0.4353, 0.3843,  ..., 0.3765, 0.3608, 0.3843],
          [0.3843, 0.4000, 0.3882,  ..., 0.3490, 0.3765, 0.3490]],
 
         [[0.8667, 0.8824, 0.8039,  ..., 0.8510, 0.8510, 0.8745],
          [0.8549, 0.9098, 0.8902,  ..., 0.8196, 0.8471, 0.8588],
          [0.8627, 0.8667, 0.8902,  ..., 0.8824, 0.8235, 0.7765],
          ...,
          [0.3725, 0.3725, 0.3059,  ..., 0.3373, 0.3569, 0.3412],
          [0.3686, 0.3843, 0.3451,  ..., 0.3569, 0.3216, 0.3020],
          [0.3333, 0.4157, 0.4000,  ..., 0.2745, 0.2941, 0.3569]],
 
         [[0.8118, 0.8392, 0.8784,  ..., 0.8588, 0.8745, 0.8275],
          [0.9137, 0.8784, 0.8627,  ..., 0.8784, 0.9137, 0.8588],
          [0.8902, 0.8824, 0.8667,  ...,

In [72]:
def getTestLoader(batch_size):
    test_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True, num_workers=2)
    
    return(test_loader)
    

In [73]:
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,48,3,padding=1)

        self.conv2 = nn.Conv2d(48,48,3,padding=1)

        self.pool = nn.MaxPool2d(2,2)

        self.conv3 = nn.Conv2d(48,48,3,padding=1)

        self.conv4 = nn.Conv2d(48,48,3,padding=1)


        self.conv5 = nn.Conv2d(48,48,3,padding=1)

        self.conv6 = nn.Conv2d(48,48,3,padding=1)


        self.upscale = nn.Upsample(scale_factor=2)

        self.decov1 = nn.Conv2d(96,96,3,padding=1)


        self.decov2 = nn.Conv2d(96,96,3,padding=1)

        self.decov3 = nn.Conv2d(144,96,3,padding=1)


        self.decov4 = nn.Conv2d(96,96,3,padding=1)


        self.decov5 = nn.Conv2d(144,96,3,padding=1)

        self.decov6 = nn.Conv2d(96,96,3,padding=1)


        self.decov7 = nn.Conv2d(144,96,3,padding=1)

        self.decov8 = nn.Conv2d(96,96,3,padding=1)


        self.decov9 = nn.Conv2d(99,64,3,padding=1)

        self.decov10 = nn.Conv2d(64,32,3,padding=1)

        self.conv7 = nn.Conv2d(32,3,3,padding=1)
        
    def forward(self, x):
        
        skips = [x]
        
        # Enc. Conv0
        x = F.leaky_relu(self.conv1(x))
        
        # Enc. Conv1
        x = F.leaky_relu(self.conv2(x))
        
        #pool1
        x = self.pool(x)
        skips.append(x)
        
        # Enc. Conv2 and p00l2
        x = self.pool(F.leaky_relu(self.conv3(x)))
        
        skips.append(x)
        
        # Enc. Conv3 and p00l3
        x = self.pool(F.leaky_relu(self.conv4(x)))
        
        skips.append(x)
        
        # Enc. Conv4 and pool4
        x = self.pool(F.leaky_relu(self.conv5(x)))
        
        skips.append(x)
        
        # Enc. Conv5 and pool5
        x = self.pool(F.leaky_relu(self.conv6(x)))
        
        # Enc. Conv6
        x = F.leaky_relu(self.conv6(x))
        
        #---------#
        #upsample 5
        x = self.upscale(x)
        
        # concat 5
        x = torch.cat([x, skips.pop()],1)        
        
        x = F.leaky_relu(self.decov1(x))
        
        x = self.upscale(F.leaky_relu_(self.decov2(x)))

        x = torch.cat([x, skips.pop()],1)
        
        x = F.leaky_relu(self.decov3(x))
        x = self.upscale(F.leaky_relu(self.decov4(x)))
        
        # concat 3
        x = torch.cat([x, skips.pop()],1)
        x = F.leaky_relu(self.decov5(x))
        
        x = self.upscale(F.leaky_relu(self.decov6(x)))
        # concat2
        x = torch.cat([x, skips.pop()],1)
        x = F.leaky_relu(self.decov7(x))
        x = self.upscale(F.leaky_relu(self.decov8(x)))
        
        #concat 1
        x = torch.cat([x, skips.pop()],1)
        x = F.leaky_relu(self.decov9(x))
        x = F.leaky_relu(self.decov10(x))
        
        
        x = self.conv7(x)
        
        return x

In [74]:
def createLossAndOptimizer(net, learning_rate=0.001):
    
    #Loss function
    loss = torch.nn.MSELoss()
    
    #Optimizer
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    
    return(loss, optimizer)

In [75]:
import time

In [76]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [77]:
print(device)

cuda:0


In [84]:
def trainNet(net, batch_size, n_epochs, learning_rate):
    
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", batch_size)
    print("epochs=", n_epochs)
    print("learning_rate=", learning_rate)
    print("=" * 30)
    
    train_loader = getTrainLoader(batch_size)
    
    test_loader = getTestLoader(batch_size)
    
    n_batches = len(train_loader)
    
    loss, optimizer = createLossAndOptimizer(net, learning_rate)
    
    training_start_time = time.time()
    
    for epoch in range(n_epochs):
        
        running_loss = 0.0
        print_every = n_batches // 10
        start_time = time.time()
        total_train_loss = 0
        
        for i, data in enumerate(zip(train_loader, test_loader)):
            
            
            
            optimizer.zero_grad()
            inputs, labels = data
#             print("input - ",inputs[0].shape)
            inputs = inputs[0].float()
            labels = labels[0].float()
            
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = net(inputs)
#             print("output = ", outputs.shape)
            
            loss_size = loss(outputs, labels)
            loss_size.backward()
            optimizer.step()
            
            running_loss += loss_size.data[0]
            
            total_train_loss += loss_size.data[0]
            
            print(i,end='\r')
            
            if (i + 1) % (print_every + 1) == 0:
                print("Epoch {}, {:d}% \t train_loss: {:.2f} took: {:.2f}s".format(
                        epoch+1, int(100 * (i+1) / n_batches), running_loss / print_every, time.time() - start_time))
                #Reset running loss and time
                running_loss = 0.0
                start_time = time.time()

In [85]:
net = Net().to(device)
# net = Net()
trainNet(net, batch_size=batch_size, n_epochs=num_epochs, learning_rate=learning_rate)

===== HYPERPARAMETERS =====
batch_size= 4
epochs= 200
learning_rate= 0.001




Epoch 1, 10% 	 train_loss: 27740510.00 took: 657.43s
Epoch 1, 20% 	 train_loss: 76.77 took: 637.86s
3260

KeyboardInterrupt: 

Process Process-16:
Process Process-15:
Process Process-14:
Process Process-13:
Traceback (most recent call last):
  File "/home/turing/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/turing/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/turing/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/turing/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-63-88826ae5edc9>", line 15, in __getitem__
    image = io.imread(img_path)
  File "/home/turing/anaconda3/envs/pytorch/lib/python3.6/site-packages/skimage/io/_io.py", line 62, in imread
    img = call_pl