In [1]:
print("Importing Library...")
import torch
import torchvision
import torch.utils.data as putils

from torch import nn, optim
from torchvision import datasets,transforms
from torch.autograd import Variable
from PIL import Image
import numpy
import numpy as np
import math

print("Importing Library Success")


Importing Library...
Importing Library Success


In [2]:
print("Defining Class...")
class ComCNN(nn.Module):
    def __init__(self, channel):
        super(ComCNN, self).__init__()
        self.conv1 = nn.Conv2d(channel, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64, affine=False)
        self.conv3 = nn.Conv2d(64, channel, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.bn1(self.conv2(out)))
        return self.conv3(out)

class MinMaxCNN(nn.Module):
    def __init__(self, channel,interpolate_size=2,mode='bicubic',deep=3):
        super(MinMaxCNN, self).__init__()
        self.deconv1 = nn.Conv2d(channel, out_channels=64,stride=4, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64, affine=False)
        self.deconv_n = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn_n = nn.BatchNorm2d(64, affine=False)
        self.deconv3 = nn.ConvTranspose2d(64, 2, kernel_size=3,padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.deconv1(x))
        out = self.relu(self.bn_n(self.deconv_n(out)))
        out = self.deconv3(out)
        return self.relu(out)

class BitmapRecCNN(nn.Module):
    def __init__(self, channel,interpolate_size=2,mode='bicubic',deep=5):
        super(BitmapRecCNN, self).__init__()
        self.deconv1 = nn.Conv2d(channel, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64, affine=False)
        self.deconv_n = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn_n = nn.BatchNorm2d(64, affine=False)
        self.deconv3 = nn.ConvTranspose2d(64, channel, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.interpolate_size = interpolate_size
        self.mode = mode
        self.deep = 5

    def forward(self, x):
        x = self.interpolate(x)
        out = self.relu(self.deconv1(x))
        for _ in range(self.deep):
            out = self.relu(self.bn_n(self.deconv_n(out)))
        out = self.sigmoid(self.deconv3(out))
        rounding = self.rounding_layer(out)
        return rounding

    def interpolate(self,x):
        return nn.functional.interpolate(input=x, scale_factor=self.interpolate_size, mode=self.mode,
                                         align_corners=False)

    def rounding_layer(self,batch_image):
        return batch_image.round()
    
    
class Network(nn.Module):
    def __init__(self,comCNN,bitmapCNN,minmaxCNN):
        super(Network,self).__init__()
        self.first = comCNN
        self.second = bitmapCNN
        self.third = minmaxCNN
        
    def forward(self, x):
        compact_repre = self.first(x.cuda())
        
        bitmap_repre = self.second(compact_repre.cuda())
        
        minmax_repre = self.third(compact_repre.cuda())
        
        recon = self.btc(bitmap_repre,minmax_repre)
        
        return recon,minmax_repre,bitmap_repre,compact_repre
    
    def btc(self,bitmap,minmax):
        for i in range(len(minmax)):
            counterx = 0
            for j in range(0,len(bitmap[0][0]),4):
                countery = 0
                for k in range(0,len(bitmap[0][0][0]),4):
                    if bitmap[i][0][j][k] == 1:
                        bitmap[i][0][j][k] = minmax[i][1][counterx][countery]
                    else:
                        bitmap[i][0][j][k] = minmax[i][0][counterx][countery]
                countery = countery + 1
            counterx = counterx + 1
                
        return bitmap

def loss(original_image,reconstructed_image):
    return torch.nn.MSELoss(size_average=False)(reconstructed_image,original_image)

def psnr(img1, img2):
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 1
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

print("Defining Class Succes")
    

Defining Class...
Defining Class Succes


In [7]:
print("Load Image Dataset")
train_image_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop(size=(32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])

test_image_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop(size=(32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])


train_path = '../dataset/train/'
test_path = '../dataset/test/'
train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=train_image_transform
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=24,
                                          shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.ImageFolder(
    root=test_path,
    transform=test_image_transform
)

test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=24,
                                          shuffle=True,num_workers=2)

print("Load Image Dataset Success")



Load Image Dataset
Load Image Dataset Success


In [10]:
print("Initialize Model....")
CUDA = torch.cuda.is_available()

if CUDA:
    comCNN = ComCNN(1).cuda()
    bitmapRecCNN = BitmapRecCNN(1).cuda()
    minMaxCNN = MinMaxCNN(1).cuda()
    network = Network(bitmapCNN=bitmapRecCNN,comCNN=comCNN,minmaxCNN=minMaxCNN).cuda()
    print("Cuda is available, using gpu instead")
else:
    comCNN = ComCNN(1)
    bitmapRecCNN = BitmapRecCNN(1)
    minMaxCNN = MinMaxCNN(1).cuda()
    network = Network(bitmapCNN=bitmapRecCNN,comCNN=comCNN,minmaxCNN=minMaxCNN)
    print("Cuda is not available, using cpu instead")

optimizer = optim.Adam(network.parameters(),lr=5)


print("Initialize Model Success")



Initialize Model....
Cuda is available, using gpu instead
Initialize Model Success


In [11]:
import time
print("Begin Training...")
epoch = 10
network.train()
for i in range(epoch):
    loss_temp = 0
    psnr_avg = 0
    start = time.time()
    for batch_idx,(data,_) in enumerate(train_loader):
        optimizer.zero_grad()
        data = Variable(data)
        recon,minmax,bitmap,compact = network(data.cuda())
        
        loss_val = loss(data.cuda(),recon.cuda())
        loss_temp += loss_val.item()
        psnr_avg += 10 * math.log10(1 / loss_val.item())
        loss_val.backward()
        
        optimizer.step()
    end = time.time()
    print("====>Epoch {}\nLoss Average : {}\nTime     : {}\nAvg psnr    : {}"
          .format(i,
                  (loss_temp/len(train_loader)),
                  (end-start),
                  (psnr_avg/len(train_loader))
                  ))
print("Training Success")

Begin Training...
====>Epoch 0
Loss Average : 27073.606387867647
Time     : 22.47698998451233
Avg psnr    : -44.28853121446804
====>Epoch 1
Loss Average : 26881.010857077206
Time     : 22.405093669891357
Avg psnr    : -44.246248858347776
====>Epoch 2
Loss Average : 26921.963350183825
Time     : 25.60654592514038
Avg psnr    : -44.25904742069003
====>Epoch 3
Loss Average : 26838.340073529413
Time     : 22.032244443893433
Avg psnr    : -44.25715865755512
====>Epoch 4
Loss Average : 26783.391888786766
Time     : 22.204740285873413
Avg psnr    : -44.22858368322733
====>Epoch 5
Loss Average : 26999.466681985294
Time     : 21.931516647338867
Avg psnr    : -44.268763857148144


KeyboardInterrupt: 

In [8]:
temp = 0
for batch_idx,(data,_) in enumerate(train_loader):
    temp = temp + np.mean(data.numpy())
print(temp/len(train_loader.dataset))

temp = 0
for batch_idx,(data,_) in enumerate(test_loader):
    temp = temp + np.mean(data.numpy())
print(temp/len(test_loader.dataset))


    

0.01835903763771057
0.03779703378677368


In [9]:
for batch,(data,_) in enumerate(train_loader):
    print(data)
    break;

tensor([[[[0.1294, 0.1176, 0.1137,  ..., 0.2902, 0.2824, 0.2745],
          [0.1216, 0.1137, 0.1137,  ..., 0.2941, 0.2902, 0.2784],
          [0.1216, 0.1137, 0.1176,  ..., 0.2980, 0.2941, 0.2863],
          ...,
          [0.1216, 0.0863, 0.1255,  ..., 0.2667, 0.2588, 0.2510],
          [0.1451, 0.0784, 0.2157,  ..., 0.2706, 0.2588, 0.2549],
          [0.1647, 0.0745, 0.1765,  ..., 0.2745, 0.2667, 0.2627]]],


        [[[0.2627, 0.3882, 0.8471,  ..., 0.9490, 0.9490, 0.9451],
          [0.4275, 0.9059, 0.8980,  ..., 0.9529, 0.9529, 0.9490],
          [0.9255, 0.9216, 0.8275,  ..., 0.9412, 0.9412, 0.9490],
          ...,
          [0.0314, 0.0078, 0.0000,  ..., 0.4941, 0.4471, 0.5333],
          [0.0196, 0.0039, 0.0000,  ..., 0.5176, 0.4863, 0.5529],
          [0.0118, 0.0039, 0.0000,  ..., 0.5216, 0.4980, 0.5412]]],


        [[[0.3333, 0.3804, 0.2745,  ..., 0.1804, 0.1725, 0.1686],
          [0.3882, 0.2471, 0.2667,  ..., 0.1961, 0.1922, 0.2039],
          [0.3412, 0.3137, 0.2039,  ..