In [2]:
print("Importing Library...")
# Imports used by the rest of the notebook.
import torch
import torchvision
import torch.utils.data as putils  # NOTE(review): the `putils` alias is not used in the visible cells

from torch import nn, optim
from torchvision import datasets,transforms  # NOTE(review): `datasets` is unused; cells use torchvision.datasets directly
from torch.autograd import Variable  # NOTE(review): Variable is deprecated since PyTorch 0.4 (Variable(t) just returns a Tensor)
from PIL import Image  # NOTE(review): not used in the visible cells
import numpy
import numpy as np  # NOTE(review): duplicates `import numpy`; the `np` alias is not used in the visible cells
import math

print("Importing Library Success")


Importing Library...
Importing Library Success


In [3]:
print("Defining Class...")


class ComCNN(nn.Module):
    """Encoder that produces the compact representation of an image.

    Pipeline: 3x3 conv -> ReLU -> 3x3 stride-2 conv -> BatchNorm -> ReLU ->
    3x3 conv. The output has ``channel`` channels again and its spatial size
    is ceil(H/2) x ceil(W/2).

    Args:
        channel: number of input (and output) channels.
    """

    def __init__(self, channel):
        super().__init__()
        # Attribute names and creation order are unchanged so parameter
        # initialisation and saved state_dict keys stay the same.
        self.conv1 = nn.Conv2d(channel, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64, affine=False)
        self.conv3 = nn.Conv2d(64, channel, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.relu(self.conv1(x))
        downsampled = self.bn1(self.conv2(features))
        return self.conv3(self.relu(downsampled))

class MinMaxCNN(nn.Module):
    """Decoder that predicts two maps (min in channel 0, max in channel 1).

    Pipeline: conv -> ReLU -> conv -> BatchNorm -> ReLU -> transposed conv
    (2 output channels) -> 3x3 max-pool (stride 1) -> sigmoid. Every stage
    preserves the spatial size, so the output is ``(B, 2, H, W)`` for an
    ``(B, channel, H, W)`` input.

    Args:
        channel: number of input channels.
        interpolate_size: scale factor used by :meth:`interpolate` only.
        mode: interpolation mode used by :meth:`interpolate` only.
        deep: accepted for signature compatibility but not used.
    """

    def __init__(self, channel, interpolate_size=2, mode='bicubic', deep=3):
        super().__init__()
        # Same attribute names and creation order as before, so parameter
        # initialisation and state_dict keys are unchanged.
        self.deconv1 = nn.Conv2d(channel, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64, affine=False)  # NOTE(review): never used in forward()
        self.deconv_n = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn_n = nn.BatchNorm2d(64, affine=False)
        self.deconv3 = nn.ConvTranspose2d(64, 2, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.interpolate_size = interpolate_size
        self.mode = mode

    def forward(self, x):
        # No upsampling here: the input is used at its own resolution.
        hidden = self.relu(self.deconv1(x))
        hidden = self.relu(self.bn_n(self.deconv_n(hidden)))
        pooled = self.maxpool(self.deconv3(hidden))
        return self.sigmoid(pooled)

    def interpolate(self, x):
        """Rescale ``x`` by ``interpolate_size`` with ``mode`` (unused by ``forward``)."""
        return nn.functional.interpolate(
            input=x,
            scale_factor=self.interpolate_size,
            mode=self.mode,
            align_corners=False,
        )

class BitmapRecCNN(nn.Module):
    """Decoder that upsamples the compact representation and predicts a bitmap.

    Pipeline: interpolate (``interpolate_size``, ``mode``) -> conv -> ReLU ->
    ``deep`` x (conv -> BatchNorm -> ReLU, the *same* layer reused every time)
    -> transposed conv -> sigmoid -> round. Every output value is therefore
    0 or 1.

    Args:
        channel: number of input and output channels.
        interpolate_size: spatial scale factor applied before the convolutions.
        mode: interpolation mode passed to ``nn.functional.interpolate``.
        deep: how many times the shared middle stage is applied. Defaults to 5,
            the depth that was previously hard-coded (the argument used to be
            ignored).
    """

    def __init__(self, channel, interpolate_size=2, mode='bicubic', deep=5):
        super(BitmapRecCNN, self).__init__()
        self.deconv1 = nn.Conv2d(channel, out_channels=64, kernel_size=3, padding=1)
        # NOTE(review): bn2 is never used in forward(); it is kept so existing
        # state_dicts (which contain its running statistics) still load.
        self.bn2 = nn.BatchNorm2d(64, affine=False)
        self.deconv_n = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn_n = nn.BatchNorm2d(64, affine=False)
        self.deconv3 = nn.ConvTranspose2d(64, channel, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.interpolate_size = interpolate_size
        self.mode = mode
        # Honour the constructor argument (was hard-coded to 5).
        self.deep = deep

    def forward(self, x):
        x = self.interpolate(x)
        out = self.relu(self.deconv1(x))
        # The same conv/BN pair is applied repeatedly (shared weights).
        for _ in range(self.deep):
            out = self.relu(self.bn_n(self.deconv_n(out)))
        out = self.sigmoid(self.deconv3(out))
        return self.rounding_layer(out)

    def interpolate(self, x):
        """Rescale ``x`` by ``interpolate_size`` using ``mode``."""
        return nn.functional.interpolate(input=x, scale_factor=self.interpolate_size, mode=self.mode,
                                         align_corners=False)

    def rounding_layer(self, batch_image):
        """Round sigmoid outputs to {0, 1}.

        NOTE(review): torch.round has a zero gradient, so no gradient reaches
        this network's parameters through its output. Training it would need
        e.g. a straight-through estimator.
        """
        return batch_image.round()
    
    
class Network(nn.Module):
    """Block-truncation-coding style autoencoder built from three sub-networks.

    The input is encoded by ``comCNN``. The compact code is decoded by
    ``bitmapCNN`` into a bitmap and by ``minmaxCNN`` into per-block min/max
    maps, and :meth:`btc` combines the two into the reconstruction.

    Args:
        comCNN: encoder module (stored as ``self.first``).
        bitmapCNN: bitmap decoder (stored as ``self.second``).
        minmaxCNN: min/max decoder (stored as ``self.third``). Channel 0 of its
            output is used as the block minimum and channel 1 as the maximum.
    """

    def __init__(self, comCNN, bitmapCNN, minmaxCNN):
        super(Network, self).__init__()
        self.first = comCNN
        self.second = bitmapCNN
        self.third = minmaxCNN

    def forward(self, x):
        """Run the full encode/decode pipeline.

        Returns:
            tuple ``(recon, minmax_repre, bitmap_repre, compact_repre)``.
        """
        # Follow the model's device instead of hard-coding ``.cuda()``, which
        # crashed on CPU-only machines. ``x`` is still moved onto the model's
        # device, so callers passing a CPU tensor to a GPU model keep working.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            x = x.to(first_param.device)
        compact_repre = self.first(x)
        bitmap_repre = self.second(compact_repre)
        minmax_repre = self.third(compact_repre)
        recon = self.btc(bitmap_repre, minmax_repre)
        return recon, minmax_repre, bitmap_repre, compact_repre

    def btc(self, bitmap, minmax):
        """Rebuild channel 0 of ``bitmap`` from per-block min/max values.

        ``bitmap`` is assumed to be ``(B, C, H, W)`` and ``minmax`` to be
        ``(B, 2, h, w)``, with ``H``/``W`` multiples of ``h``/``w``. Pixel
        ``(y, x)`` belongs to block ``(y // (H // h), x // (W // w))``.

        Pixels equal to 1 become their block's max (``minmax[:, 1]``), and
        pixels equal to 0 become its min (``minmax[:, 0]``). Any other value,
        and every channel other than 0, is copied from ``bitmap`` unchanged.
        The inputs are not modified.

        Raises:
            ValueError: if the bitmap size is not a multiple of the min/max
                map size.
        """
        height, width = bitmap.shape[-2:]
        rows, cols = minmax.shape[-2:]
        if height % rows or width % cols:
            raise ValueError(
                "bitmap size {}x{} is not a multiple of min/max map size {}x{}".format(
                    height, width, rows, cols))
        # Separate block sizes per axis (the old code reused the height ratio
        # for the width, which only worked for square images).
        block_h, block_w = height // rows, width // cols

        def upsample(plane):
            # Nearest-neighbour expansion: each min/max value covers its block.
            return plane.repeat_interleave(block_h, dim=-2).repeat_interleave(block_w, dim=-1)

        low = upsample(minmax[:, 0:1])
        high = upsample(minmax[:, 1:2])
        plane = bitmap[:, 0:1]
        # Both masks come from the untouched bitmap. Unlike the old sequential
        # masked writes, a block max of exactly 0 is not then replaced by the
        # min. Whole-batch ops replace a Python loop of small ops per block.
        rebuilt = torch.where(plane == 1, high, torch.where(plane == 0, low, plane))
        result = bitmap.clone()
        result[:, 0:1] = rebuilt
        return result



def loss(original_image, reconstructed_image):
    """Summed squared error between the reconstruction and the original.

    Equivalent to the deprecated ``MSELoss(size_average=False)``. It uses
    ``reduction='sum'``, so no deprecation warning is emitted on each call.

    Args:
        original_image: target tensor.
        reconstructed_image: prediction tensor of the same shape.

    Returns:
        0-dim tensor holding the sum of squared differences.
    """
    return torch.nn.functional.mse_loss(reconstructed_image, original_image, reduction='sum')

def psnr(img1, img2, pixel_max=1):
    """Peak signal-to-noise ratio, in dB, between two images.

    Args:
        img1, img2: array-likes of the same shape, anything ``numpy.asarray``
            accepts (e.g. numpy arrays or detached CPU tensors).
        pixel_max: largest possible pixel value. Use 1 for images scaled to
            [0, 1] (the default) and 255 for 8-bit images.

    Returns:
        PSNR in dB, or 100 when the images are identical (MSE of 0).
    """
    # Promote to float64 so unsigned-integer images (e.g. uint8) do not wrap
    # around on subtraction/squaring.
    diff = numpy.asarray(img1, dtype=numpy.float64) - numpy.asarray(img2, dtype=numpy.float64)
    mse = numpy.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))


print("Defining Class Succes")
    

Defining Class...
Defining Class Succes


In [4]:
print("Load Image Dataset")
IMAGE_SIZE = (32, 32)
BATCH_SIZE = 32
NUM_WORKERS = 2
train_path = '../dataset/train/'
test_path = '../dataset/test/'


def _make_image_transform(size=IMAGE_SIZE):
    """Preprocessing shared by both splits: 1-channel grayscale, centre crop, tensor."""
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.CenterCrop(size=size),
        transforms.ToTensor(),
    ])


# Two separate (identical) pipelines so either split can be tuned later.
train_image_transform = _make_image_transform()
test_image_transform = _make_image_transform()

train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=train_image_transform
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                           shuffle=True, num_workers=NUM_WORKERS)

test_dataset = torchvision.datasets.ImageFolder(
    root=test_path,
    transform=test_image_transform
)
# Evaluation does not benefit from shuffling; a fixed order keeps test runs reproducible.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=NUM_WORKERS)

print("Load Image Dataset Success")



Load Image Dataset
Load Image Dataset Success


In [7]:
print("Initialize Model....")
CUDA = torch.cuda.is_available()
# One device for every module. The old CPU branch still called .cuda() on
# MinMaxCNN and crashed on machines without a GPU.
device = torch.device("cuda" if CUDA else "cpu")

# Same construction order as before, so random initialisation is unchanged.
comCNN = ComCNN(1).to(device)
bitmapRecCNN = BitmapRecCNN(1).to(device)
minMaxCNN = MinMaxCNN(1).to(device)
network = Network(bitmapCNN=bitmapRecCNN, comCNN=comCNN, minmaxCNN=minMaxCNN).to(device)
if CUDA:
    print("Cuda is available, using gpu instead")
else:
    print("Cuda is not available, using cpu instead")

optimizer = optim.Adam(network.parameters(), lr=1e-3)


print("Initialize Model Success")



Initialize Model....
Cuda is available, using gpu instead
Initialize Model Success


In [8]:
import time

print("Begin Training...")
epoch = 5
LOG_INTERVAL = 10  # print a progress line every LOG_INTERVAL batches
# Train on whichever device the model lives on rather than hard-coding CUDA.
device = next(network.parameters()).device
network.train()
for i in range(epoch):
    loss_temp = 0  # sum over batches of the per-sample loss
    psnr_avg = 0  # sum over batches of the batch PSNR
    start = time.time()
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        # Move the batch once and reuse it for the forward pass and the loss
        # (it used to be transferred twice).
        data = data.to(device)
        recon, minmax, bitmap, compact = network(data)

        loss_val = loss(data, recon)
        loss_val.backward()
        optimizer.step()

        batch_loss = loss_val.item() / len(data)
        loss_temp += batch_loss
        # Previously never updated, so "Avg psnr" always printed 0.
        psnr_avg += psnr(data.cpu().numpy(), recon.detach().cpu().numpy())

        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                i, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss))
    end = time.time()
    print("====>Epoch {}\nLoss Average : {}\nTime     : {}\nAvg psnr    : {}"
          .format(i,
                  (loss_temp / len(train_loader)),
                  (end - start),
                  (psnr_avg / len(train_loader))
                  ))
print("Training Success")

Begin Training...


KeyboardInterrupt: 