# image reconstruction

## Import libraries

In [1]:
import torch
import torchvision
from torch.utils.data import Dataset
from os import listdir
from os.path import join
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from math import log10
from tqdm import tqdm
import os
import random
import copy

## Load data

In [2]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab Notebooks/MachineLearning2022-2/MachineLearning2022-2

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks/MachineLearning2022-2/MachineLearning2022-2


In [3]:
directory_data  = './'
filename_data   = 'assignment_11_data.npz'
data            = np.load(os.path.join(directory_data, filename_data))

clean_image_train = data['label_train']
input_image_test  = data['input_test'] 
clean_image_test  = data['label_test']

In [4]:
print(clean_image_train.shape)
print(input_image_test.shape)
print(clean_image_test.shape)

(800, 128, 128)
(400, 128, 128)
(400, 128, 128)


input_train 데이터셋 만들기

In [5]:
def idx_list(input_image_test, idx):
  #temp_input_test에서 0로 채워진 row 찾기
  col_idx_list = []
  row_idx_list = []
  #col
  col_idx_list.append([i for i, value in enumerate(input_image_test[idx][0,:]) if value == 0])
  #row
  row_idx_list.append([i for i, value in enumerate(input_image_test[idx][:,0]) if value == 0])
  
  idx_list = {'idx':idx, 'row': row_idx_list, 'col': col_idx_list}

  return idx_list

In [6]:
input_image_test[19]

array([[0.        , 0.        , 0.        , ..., 0.4920467 , 0.38449692,
        0.38937042],
       [0.        , 0.        , 0.        , ..., 0.44362608, 0.41021436,
        0.3614794 ],
       [0.        , 0.        , 0.        , ..., 0.53790532, 0.492324  ,
        0.50694449],
       ...,
       [0.        , 0.        , 0.        , ..., 0.21525404, 0.4119235 ,
        0.4495157 ],
       [0.        , 0.        , 0.        , ..., 0.25680498, 0.43713409,
        0.54537785],
       [0.        , 0.        , 0.        , ..., 0.66106588, 0.94790915,
        0.33974897]])

In [7]:
def generate_input_train(clean_image_train, idx_list, offset):
  #col과 row를 0으로 변환
  black = False
  idx = idx_list['idx']+offset
  row_idx_list = idx_list['row']
  col_idx_list = idx_list['col']

  if len(row_idx_list) > 50 or len(col_idx_list)>50:
    row_idx_list = [1,3,5,7,9]
    col_idx_list = [2,4,6,8,10]
    black = True

  input_train_single = copy.deepcopy(clean_image_train[idx])
  #row
  for i in row_idx_list:
    input_train_single[i] = 0
  #col
  for i in col_idx_list:
    input_train_single[:,i] = 0


  return input_train_single, black


In [8]:
test_blurr = []
for i in range(len(input_image_test)):
  test_blurr.append(idx_list(input_image_test, i))
  #print(test_blurr[i]['idx'])

In [9]:
test_blurr[19]

{'idx': 19,
 'row': [[0,
   1,
   2,
   3,
   4,
   5,
   6,
   7,
   8,
   9,
   10,
   11,
   12,
   13,
   14,
   15,
   16,
   17,
   18,
   19,
   20,
   21,
   22,
   23,
   24,
   25,
   26,
   27,
   28,
   29,
   30,
   31,
   32,
   33,
   34,
   35,
   36,
   37,
   38,
   39,
   40,
   41,
   42,
   43,
   44,
   45,
   46,
   47,
   48,
   49,
   50,
   51,
   52,
   53,
   54,
   55,
   56,
   57,
   58,
   59,
   60,
   61,
   62,
   63,
   64,
   65,
   66,
   67,
   68,
   69,
   70,
   71,
   72,
   73,
   74,
   75,
   76,
   77,
   78,
   79,
   80,
   81,
   82,
   83,
   84,
   85,
   86,
   87,
   88,
   89,
   90,
   91,
   92,
   93,
   94,
   95,
   96,
   97,
   98,
   99,
   100,
   101,
   102,
   103,
   104,
   105,
   106,
   107,
   108,
   109,
   110,
   111,
   112,
   113,
   114,
   115,
   116,
   117,
   118,
   119,
   120,
   121,
   122,
   123,
   124,
   125,
   126,
   127]],
 'col': [[0,
   1,
   2,
   3,
   4,
   5,
   6,
   17,
   18,
  

In [10]:
input_image_train=[]
black_list=[]
for i in test_blurr:
  (input, black) = generate_input_train(clean_image_train, i,0)
  if black == True:
    black_list.append(i)
  input_image_train.append(input)

for i in test_blurr:
  (input, black) = generate_input_train(clean_image_train, i,400)
  if black == True:
    black_list.append(i)
  input_image_train.append(input)

black_list

[]

에러 수정

## plot data

In [11]:
def plot_image(title, image,k):
    
    nRow = 10
    nCol = 10
    size = 3
    
    fig, axes = plt.subplots(nRow, nCol, figsize=(size * nCol, size * nRow))
    fig.suptitle(title, fontsize=16)

    #k=105
    for r in range(nRow):
        for c in range(nCol):
            #k = r * nCol * 10 + c * 4 + 10
            axes[r, c].imshow(image[k], cmap='gray', vmin=0, vmax=1)
            k+=1
    plt.tight_layout()
    plt.show()

In [12]:
def plot_image_temp(title, imag, k):
    
    nRow = 4
    nCol = 5
    size = 3
    
    fig, axes = plt.subplots(nRow, nCol, figsize=(size * nCol, size * nRow))
    fig.suptitle(title, fontsize=16)

    #k=105
    for r in range(nRow):
        for c in range(nCol):
            #k = r * nCol * 10 + c * 4 + 10
            axes[r, c].imshow(image[k], cmap='gray', vmin=0, vmax=1)
            k+=1
    plt.tight_layout()
    plt.show()

In [13]:
input_image_train[19]

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [None]:
plot_image('input image (testing)', input_image_train, 0)

In [None]:
plot_image_temp('input image (testing)', input_image_test)

TypeError: ignored

In [None]:
plot_image('label image (training)', clean_image_train)

In [None]:
plot_image('clean image (testing)', clean_image_test)

## custom data loader for the training data

In [None]:
class dataset(Dataset):
    
    def __init__(self, clean_image, input_image):
        
        self.clean_image = clean_image
        self.input_image = input_image

    def __getitem__(self, index):
        # ==================================================
        # modify the codes for training data
        #        
        clean_image = self.clean_image[index]
        clean_image = torch.FloatTensor(clean_image).unsqueeze(dim=0)

        input_image = self.input_image[index]
        input_image = torch.FloatTensor(input_image).unsqueeze(dim=0)


        return (clean_image, input_image) # trainingdata도 반환하게 하기, 영상 16:31
        #    
        # ==================================================

    def __len__(self):

        number_image = self.clean_image.shape[0]

        return number_image    

## setting device

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
print(device)

## custom data loader for the testing data

In [None]:
input_image_test = torch.FloatTensor(input_image_test).unsqueeze(dim=1)
clean_image_test = torch.FloatTensor(clean_image_test).unsqueeze(dim=1)

input_image_test = input_image_test.to(device)
clean_image_test = clean_image_test.to(device)

## construct datasets and dataloaders for training and testing

In [None]:
# ==================================================
# determine the mini-batch size
#
size_minibatch      = 10
learning_rate = 0.01
number_epoch=500
#
# ==================================================

dataset_train       = dataset(clean_image_train, input_image_train)
dataloader_train    = torch.utils.data.DataLoader(dataset_train, batch_size=size_minibatch, shuffle=True, drop_last=True)

## construct a neural network 

In [None]:
#0~1 sigmoid, sigmoid 말고 다른 것 써도 됨

In [None]:
class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()

        # -------------------------------------------------
        # Encoder
        # -------------------------------------------------
        self.e_layer1 = nn.Sequential(
                        nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True),  
                        nn.MaxPool2d(2,2),
                        nn.ReLU(),
                        nn.BatchNorm2d(2),
        )
        
        self.e_layer2 = nn.Sequential(
                        nn.Conv2d(in_channels=2, out_channels=4, kernel_size=3, stride=1, padding=1, bias=True),
                        nn.MaxPool2d(2,2),
                        nn.ReLU(),
                        nn.BatchNorm2d(4),
        )

        # -------------------------------------------------
        # Decoder
        # -------------------------------------------------
        self.d_layer1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True),
                        nn.ReLU(),
                        nn.BatchNorm2d(2),
        )
        
        self.d_layer2 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1, bias=True),
                        nn.Sigmoid(),
        )

        # -------------------------------------------------
        # Network
        # -------------------------------------------------
        self.network = nn.Sequential(
                        self.e_layer1,
                        self.e_layer2,
                        self.d_layer1, 
                        self.d_layer2,
        )

        self.initialize_weight()

    def forward(self,x):
    
        out = self.network(x)
      
        return out

    # ======================================================================
    # initialize weights
    # ======================================================================
    def initialize_weight(self):
            
        for m in self.network.modules():
            
            if isinstance(m, nn.Conv2d):

                nn.init.xavier_uniform_(m.weight) 
                if m.bias is not None:

                    nn.init.constant_(m.bias, 1)
                    pass
                    
            elif isinstance(m, nn.BatchNorm2d):
                
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 1)
                
            elif isinstance(m, nn.Linear):

                nn.init.xavier_uniform_(m.weight)

                if m.bias is not None:
                    
                    nn.init.constant_(m.bias, 1)
                    pass

## build the network

In [None]:
model = Network().to(device)

# ==================================================
# determine the optimiser and its associated hyper-parameters
#

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

#
# ==================================================

## compute the prediction

In [None]:
def compute_prediction(model, input):

    prediction = model(input)

    return prediction

compute the loss

In [None]:
def compute_loss(prediction, label):
    # ==================================================
    # fill up the blank
    #
    criterion   = nn.MSELoss()
    # 
    # ==================================================
    loss        = criterion(prediction, label)

    return loss

compute the accurcy

In [None]:
def compute_accuracy(prediction, label):

    prediction  = prediction.squeeze(axis=1)
    label       = label.squeeze(axis=1)
    mse_loss    = torch.mean((prediction - label) ** 2)

    if mse_loss == 0.0:
        psnr = 100
    else:
        psnr = 10 * torch.log10(1 / mse_loss)

    psnr = psnr.item()
    
    return psnr

## compute the PSNR metric

- data1 : mini-batch-size x channel x height x width (torch tensor)
- data2 : mini-batch-size x channel x height x width (torch tensor)

In [None]:
def compute_psnr(data1, data2):

    mse         = nn.MSELoss()(data1, data2)
    mse_value   = mse.item()
    psnr        = 10 * np.log10(1 / mse_value)

    return psnr

## Variable for the learning curves

In [None]:
psnr_test = np.zeros(number_epoch)

## train

In [None]:
def train(model, optimizer, dataloader):

    model.train()

    # ==================================================
    # fill up the blank
    #
    loss_epoch      = []
    accuracy_epoch  = []

    model.train()

    for index_batch, (original, blur) in enumerate(dataloader):

        original    = original.to(device)
        blur        = blur.to(device)
        
        prediction  = compute_prediction(model, blur)
        loss        = compute_loss(prediction, original)
        accuracy    = compute_accuracy(prediction, original)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch.append(loss.item())
        accuracy_epoch.append(accuracy)

    loss_mean_epoch     = np.mean(loss_epoch)
    loss_std_epoch      = np.std(loss_epoch)

    accuracy_mean_epoch = np.mean(accuracy_epoch)
    accuracy_std_epoch  = np.std(accuracy_epoch)

    loss        = {'mean' : loss_mean_epoch, 'std' : loss_std_epoch}
    accuracy    = {'mean' : accuracy_mean_epoch, 'std' : accuracy_std_epoch}

    return (loss, accuracy)  



    #
    # ==================================================

## test

In [None]:
def test(model, input_image, clean_image):

    model.eval()

    num_steps   = 40
    steps       = np.linspace(0, input_image.shape[0], num_steps+1).astype(int)
    psnr_steps  = np.zeros(num_steps)
    
    for i in range(num_steps):
        input           = input_image[steps[i]:steps[i+1], :, :, :]
        clean           = clean_image[steps[i]:steps[i+1], :, :, :]
        prediction      = compute_prediction(model, input)
        psnr_steps[i]   = compute_psnr(clean, prediction)
    
    psnr = psnr_steps.mean()
    
    return psnr

## train and test

In [None]:
# ================================================================================
# 
# iterations for epochs
#
# ================================================================================
for i in tqdm(range(number_epoch)):
    
    # ================================================================================
    # 
    # training
    #
    # ================================================================================
    train(model, optimizer, dataloader_train)

    # ================================================================================
    # 
    # testing
    #
    # ================================================================================
    psnr            = test(model, input_image_test, clean_image_test)
    psnr_test[i]    = psnr

---

## functions for presenting the results

---

In [None]:
def function_result_01():
    
    title           = 'psnr (testing)'
    label_axis_x    = 'epoch' 
    label_axis_y    = 'psnr'
    
    plt.figure(figsize=(8, 6))
    plt.title(title)

    plt.plot(psnr_test, '-')
    
    plt.xlabel(label_axis_x)
    plt.ylabel(label_axis_y)

    plt.tight_layout()
    plt.show()

In [None]:
def function_result_02():
    
    nRow = 9
    nCol = 4
    size = 3
    
    title = 'testing results'
    fig, axes = plt.subplots(nRow, nCol, figsize=(size * nCol, size * nRow))
    fig.suptitle(title, fontsize=16)

    model.eval()

    prediction  = compute_prediction(model, input_image_test)

    input_image = input_image_test.detach().cpu().squeeze(axis=1)
    clean_image = clean_image_test.detach().cpu().squeeze(axis=1)
    prediction  = prediction.detach().cpu().squeeze(axis=1)

    nStep = 3

    for r in range(3):
        for c in range(nCol):
            k = r * nCol * 10 + c * 4 + 10
            axes[0 + r * nStep, c].imshow(input_image[k], cmap='gray')
            axes[1 + r * nStep, c].imshow(clean_image[k], cmap='gray', vmin=0, vmax=1)
            axes[2 + r * nStep, c].imshow(prediction[k], cmap='gray', vmin=0, vmax=1)

            axes[0 + r * nStep, c].xaxis.set_visible(False)
            axes[1 + r * nStep, c].xaxis.set_visible(False)
            axes[2 + r * nStep, c].xaxis.set_visible(False)
            
            axes[0 + r * nStep, c].yaxis.set_visible(False)
            axes[1 + r * nStep, c].yaxis.set_visible(False)
            axes[2 + r * nStep, c].yaxis.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
def function_result_03():
    
    print('final testing psnr = %9.8f' % (psnr_test[-1]))

---

## results 

---

In [None]:
number_result = 3 

for i in range(number_result):

    title           = '# RESULT # {:02d}'.format(i+1)
    name_function   = 'function_result_{:02d}()'.format(i+1)

    print('') 
    print('################################################################################')
    print('#') 
    print(title)
    print('#') 
    print('################################################################################')
    print('') 

    eval(name_function)