In [3]:
import torch
import torchvision

from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
from torch import nn, optim
from torchvision import datasets,transforms
from torch.autograd import Variable
from PIL import Image


In [4]:
import math
irange = range


def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images.
    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The Final grid size is (B / nrow, nrow). Default is 8.
        padding (int, optional): amount of padding. Default is 2.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by subtracting the minimum and dividing by the maximum pixel value.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If True, scale each image in the batch of
            images separately rather than the (min, max) over all images.
        pad_value (float, optional): Value for the padded pixels.
    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.view(1, tensor.size(0), tensor.size(1), tensor.size(2))

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, min, max):
            img.clamp_(min=min, max=max)
            img.add_(-min).div_(max - min + 1e-5)

        def norm_range(t, range):
            if range is not None:
                norm_ip(t, range[0], range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        return tensor.squeeze()

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value)
    k = 0
    for y in irange(ymaps):
        for x in irange(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y * height + padding, height - padding)\
                .narrow(2, x * width + padding, width - padding)\
                .copy_(tensor[k])
            k = k + 1
    return grid


def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    """Save a given Tensor into an image file.
    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    from PIL import Image
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)


In [5]:
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.24703223,  0.24348513 , 0.26158784))
])

train_set = datasets.CIFAR10(root='./data',train=True,download=True,transform=image_transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=16,
                                          shuffle=True, num_workers=2)

test_set = datasets.CIFAR10(root='./data',train=False,download=True,transform=image_transform)

test_loader = torch.utils.data.DataLoader(test_set,batch_size=16,
                                          shuffle=True,num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [6]:
HEIGHT = 32
WIDTH = 32
EPOCH = 10
LOG_INTERVAL = 500


In [7]:
class Codex(nn.Module):
    def __init__(self,size,mode):
        super(Codex,self).__init__()
        self.size = size
        self.mode = mode
        self.encode_decode = nn.functional.interpolate
        
    def forward(self,x):
        x = self.encode_decode(x,size=self.size,mode=self.mode)
        return x

class EndToEnd(nn.Module):
    def __init__(self,channel,height,mode):
        super(EndToEnd,self).__init__()
        
        self.conv1 = nn.Conv2d(channel,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(64,64,kernel_size=3,stride=2,padding=0)
        self.bn1 = nn.BatchNorm2d(64,affine=False)
        self.conv3 = nn.Conv2d(64, channel, kernel_size=3, stride=1, padding=1)
        
        self.interpolate = Codex(size=height,mode=mode)
        self.deconv1 = nn.Conv2d(channel,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(64,affine=False)
        
        self.deconv_n = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
        self.bn_n = nn.BatchNorm2d(64,affine=False)
        
        
        self.deconv3 = nn.ConvTranspose2d(64,channel,kernel_size=3,stride=1,padding=1)
        
        self.relu = nn.ReLU()
    
    def reparameterize(self, mu, logvar):
        pass
    
    def forward_comcnn(self,x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.bn1(self.conv2(out)))
        return self.conv3(out)
    
    def forward_reccnn(self,z):
        decoded_image = self.interpolate(z)
        out = self.relu(self.deconv1(decoded_image))
        for _ in range(18):
            out = self.relu(self.bn_n(self.deconv_n(out)))
        out = self.deconv3(out)
        final = out + decoded_image
        return final,out,decoded_image
    
    def forward(self, x):
        com_img = self.forward_comcnn(x)
        final,out,upscaled = self.forward_reccnn(com_img)
        return final,out,upscaled,com_img,x

In [8]:
CUDA = torch.cuda.is_available()

if CUDA:
    print("Cuda is avaliable,using cuda instead of cpu")
    model = EndToEnd(3,32,'bilinear').cuda()
else:
    print("Cuda is not available, using cpu")
    model = EndToEnd(3,32,'bilinear')
    
optimizer = optim.Adam(model.parameters(),lr=1e-3)

Cuda is avaliable,using cuda instead of cpu


In [9]:
def loss_function(final_image,residual_image,upscaled_image,com_image,original_image):
    com_loss = nn.MSELoss(size_average=False)(original_image,final_image)
    rec_loss = nn.MSELoss(size_average=False)(residual_image,original_image-upscaled_image)
    return com_loss+rec_loss


In [10]:
def train(epoch):
    model.train()
    train_loss  =0
    for batch_idx,(data,_) in enumerate(train_loader):
        data = Variable(data)
        optimizer.zero_grad()
        final, residual_img, upscaled_image, com_img, orig_im = model(data.cuda())
        loss = loss_function(final, residual_img, upscaled_image, com_img, orig_im)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    
    

In [11]:
def test(epoch):
  
  model.eval()
  test_loss = 0
  for i, (data, _) in enumerate(test_loader):
        data = Variable(data, volatile=True)
        final, residual_img, upscaled_image, com_img, orig_im = model(data.cuda())
        test_loss += loss_function(final, residual_img, upscaled_image, com_img, orig_im).item()
        if epoch == EPOCH and i == 0:
#             save_image(final.data[0],'reconstruction_final',nrow=8)
#             save_image(com_img.data[0],'com_img',nrow=8)
            n = min(data.size(0), 6)
            print("saving the image "+str(n))
            comparison = torch.cat([data[:n],
              final[:n].cpu()])
            comparison = comparison.cpu()
#             print(comparison.data)
            save_image(com_img[:n].data,
                       'compressed_' + str(epoch) +'.png', nrow=n)
            save_image(comparison.data,
                       'reconstruction_' + str(epoch) +'.png', nrow=n)

  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))

In [52]:
import time
start = time.time()
for epoch in range(1,25+1):
    temp = time.time()
    train(epoch)
    test(epoch)
    if epoch == EPOCH:
        pass
    temp1 = time.time()
    print("DONE---> total time ",temp1-temp)

end = time.time()
torch.save(model.state_dict(),'./net.pth')
 

====> Epoch: 1 Average loss: 629.0249
====> Test set loss: 562.1685
DONE---> total time  176.8578429222107
====> Epoch: 2 Average loss: 564.1851
====> Test set loss: 555.5524
DONE---> total time  176.28355932235718
====> Epoch: 3 Average loss: 328.2454
====> Test set loss: 651.9151
DONE---> total time  176.36132955551147
====> Epoch: 4 Average loss: 238.0292
====> Test set loss: 612.5775
DONE---> total time  176.5442967414856
====> Epoch: 5 Average loss: 216.9310
====> Test set loss: 616.6415
DONE---> total time  175.5963749885559
====> Epoch: 6 Average loss: 205.5701
====> Test set loss: 653.0567
DONE---> total time  179.42620420455933
====> Epoch: 7 Average loss: 199.2257
====> Test set loss: 673.7499
DONE---> total time  182.05132675170898
====> Epoch: 8 Average loss: 199.4300
====> Test set loss: 662.4085
DONE---> total time  175.74607920646667
====> Epoch: 9 Average loss: 179.5272
====> Test set loss: 617.9073
DONE---> total time  176.7314534187317
====> Epoch: 10 Average loss: 17

  


In [53]:
print("Total Time Execution ",end-start)

Total Time Execution  4368.165201663971


In [13]:

model.load_state_dict(torch.load('net.pth'))

def save_images():
  epoch = EPOCH
  model.eval()
  test_loss = 0
  for i, (data, _) in enumerate(test_loader):
        data = Variable(data, volatile=True)
        final, residual_img, upscaled_image, com_img, orig_im = model(data.cuda())
        test_loss += loss_function(final, residual_img, upscaled_image, com_img, orig_im).item()
        if i == 3:
#             save_image(final.data[0],'reconstruction_final',nrow=8)
#             save_image(com_img.data[0],'com_img',nrow=8)
            n = min(data.size(0), 6)
            print("saving the image "+str(n))
            comparison = torch.cat([data[:n],
              final[:n].cpu()])
            comparison = comparison.cpu()
#             print(comparison.data)
            save_image(com_img[:1].data,
                         'compressed_' + str(i) +'.png', nrow=n)
            save_image(final[:1].data,
                        'final_' + str(epoch) +'.png', nrow=n)
            save_image(orig_im[:1].data,
                        'original_' + str(epoch) +'.png', nrow=n)


  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))

save_images()

  
  "See the documentation of nn.Upsample for details.".format(mode))


saving the image 6
====> Test set loss: 691.6640


In [15]:
import numpy 
import math
import cv2
original = cv2.imread("original_10.png")
contrast = cv2.imread("final_10.png",1)
def psnr(img1, img2):
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

d=psnr(original,contrast)
print(d)

32.90015044758086
