In [1]:
import numpy as np
# import matplotlib.pyplot as plt
# import imageio
import cv2
import os
# import PIL

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset
# from torchvision import datasets
# from torchvision import transforms
from tqdm import tqdm 

In [2]:
X = []
y = []
def load_images_from_directory(directory):
    images = []
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        img = cv2.imread(filepath)
        if img is not None:
            images.append(img)
    return images

X = load_images_from_directory('./Train/low')
y = load_images_from_directory('./Train/high')

In [3]:
X = np.array(X)
y = np.array(y)

In [4]:
X = X/255
y = y/255

X.shape
y.shape

(485, 400, 600, 3)

In [5]:
X

array([[[[0.01568627, 0.01960784, 0.02352941],
         [0.01568627, 0.01960784, 0.03921569],
         [0.01960784, 0.02745098, 0.01960784],
         ...,
         [0.04705882, 0.05490196, 0.05490196],
         [0.05098039, 0.05098039, 0.05490196],
         [0.05490196, 0.05490196, 0.05490196]],

        [[0.00392157, 0.01960784, 0.01176471],
         [0.00784314, 0.01960784, 0.01960784],
         [0.00784314, 0.01568627, 0.00784314],
         ...,
         [0.05098039, 0.05098039, 0.05882353],
         [0.05490196, 0.05882353, 0.04705882],
         [0.04313725, 0.05882353, 0.05490196]],

        [[0.01176471, 0.01176471, 0.01568627],
         [0.02352941, 0.01960784, 0.01568627],
         [0.01960784, 0.02352941, 0.01176471],
         ...,
         [0.0627451 , 0.05882353, 0.05882353],
         [0.05490196, 0.05098039, 0.03921569],
         [0.05098039, 0.0627451 , 0.07058824]],

        ...,

        [[0.01568627, 0.02352941, 0.01960784],
         [0.00784314, 0.01568627, 0.00784314]

In [6]:
X = X.transpose(0, 3, 1, 2)
X = torch.tensor(X, dtype=torch.float32).cuda()
y = y.transpose(0, 3, 1, 2)
y = torch.tensor(y, dtype=torch.float32).cuda()

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)#, random_state=42)

In [8]:
batch_size = 16
combined_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

In [9]:
class CNNBlock(nn.Module):
    def __init__(self, filters):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(3, filters, kernel_size=3, padding=1)
        self.sp = nn.Softplus()

    def forward(self, x):
        out = self.conv(x)
        out = self.sp(out)
        return out

class ResidualBlock(nn.Module):
    def __init__(self, filters):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
        self.sp1 = nn.Softplus()
        self.sp2 = nn.Softplus()
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.sp1(out)
        out = self.conv2(out)
        out += residual  
        out = self.sp2(out)
        return out

class CAM(nn.Module):
    def __init__(self, channels):
        super(CAM, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // 16, 1, bias=False),
            nn.Softplus(),
            nn.Conv2d(channels // 16, channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.global_avg_pool(x)
        y = self.fc(y)
        return x * y

class PyramidPooling(nn.Module):
    def __init__(self, pool_size, output_size):
        super(PyramidPooling, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(pool_size)
        self.output_size = output_size

    def forward(self, x):
        pooled = self.pool(x)
        upsampled = F.interpolate(pooled, size=self.output_size, mode='bilinear', align_corners=True)
        return upsampled

class KSM(nn.Module):
    def __init__(self):
        super(KSM, self).__init__()
        self.conv = nn.Conv2d(36, 3, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.conv(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_block1 = CNNBlock(filters=32)
        self.residual_block = ResidualBlock(filters=32)
        self.cam = CAM(channels=32)
        self.conv2 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        self.pyramid1 = PyramidPooling(pool_size=1, output_size=(400, 600))
        self.pyramid2 = PyramidPooling(pool_size=2, output_size=(400, 600))
        self.pyramid3 = PyramidPooling(pool_size=4, output_size=(400, 600))
        self.pyramid4 = PyramidPooling(pool_size=8, output_size=(400, 600))
        self.pyramid5 = PyramidPooling(pool_size=16, output_size=(400, 600))
        self.ksm = KSM()
        self.final_conv = nn.Conv2d(3, 3, kernel_size=1, padding=0)
        self.sp = nn.Softplus()

    def forward(self, x):
        C1 = self.conv_block1(x)
        C1 = self.residual_block(C1) 
        cam = self.cam(C1)
        C2 = self.sp(self.conv2(cam))
        concat1 = torch.cat([C2, x], dim=1)  

        p1 = self.pyramid1(concat1)
        p2 = self.pyramid2(concat1)
        p3 = self.pyramid3(concat1)
        p4 = self.pyramid4(concat1)
        p5 = self.pyramid5(concat1)

        concat2 = torch.cat([p1, p2, p3, p4, p5, concat1], dim=1) 

        ksm = self.ksm(concat2)
        out = self.final_conv(ksm)
        return out

    
model = Net().cuda()  


In [10]:
model.load_state_dict(torch.load('./nnweights3.pth'))
model.eval()

Net(
  (conv_block1): CNNBlock(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sp): Softplus(beta=1, threshold=20)
  )
  (residual_block): ResidualBlock(
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sp1): Softplus(beta=1, threshold=20)
    (sp2): Softplus(beta=1, threshold=20)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (cam): CAM(
    (global_avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): Softplus(beta=1, threshold=20)
      (2): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (3): Sigmoid()
    )
  )
  (conv2): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pyramid1): PyramidPooling(
    (pool): AdaptiveAvgPool2d(output_size=1)
  )
  (pyramid2): PyramidPooling(
    (pool): AdaptiveAvgPool2d(output_size=2)
  )
  (pyramid3): PyramidP

In [None]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=5e-4)

In [None]:
num_epochs = 20
for epoch in tqdm(range(num_epochs)):
    for batch_idx, (images, y_batched) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        y_batched = y_batched.cuda()
        recon_images = model(images)
        loss = criterion(recon_images, y_batched)
        loss.backward()
        optimizer.step()
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item()}")

In [11]:
batch_size = 16
combined_dataset_test = TensorDataset(X_test, y_test)
testing_loader = DataLoader(combined_dataset_test, batch_size=batch_size, shuffle=False)

In [None]:
total_psnr = 0.0
num_batches = len(testing_loader)

with torch.no_grad():
    for images, labels in testing_loader:
        reconstructed_images = model(images)
        mse = nn.MSELoss(reduction='mean')(reconstructed_images, labels)
        psnr = 10 * torch.log10(1.0 / mse)
        total_psnr += psnr.item()

average_psnr = total_psnr / num_batches
print(f"Average PSNR on testing data: {average_psnr:.2f} dB")

In [44]:
reconstructed_images

tensor([[[[-0.0058,  0.0960,  0.1520,  ...,  0.1294,  0.0022, -0.5081],
          [-0.1771, -0.3156, -0.4596,  ..., -0.4916, -0.4431, -0.8113],
          [-0.2581, -0.4313, -0.4828,  ..., -0.4156, -0.2678, -0.6241],
          ...,
          [-0.2992, -0.5388, -0.6537,  ..., -0.5485, -0.3424, -0.6676],
          [-0.3849, -0.6233, -0.8326,  ..., -0.7141, -0.4129, -0.7493],
          [-0.8647, -1.0609, -1.1479,  ..., -1.1159, -0.9239, -1.0833]],

         [[ 0.8457,  1.0505,  1.1235,  ...,  1.1058,  0.9501,  0.3180],
          [ 0.7120,  0.5600,  0.3822,  ...,  0.3571,  0.4119, -0.0416],
          [ 0.6137,  0.4117,  0.3499,  ...,  0.4494,  0.6263,  0.1927],
          ...,
          [ 0.5550,  0.2739,  0.1279,  ...,  0.2568,  0.5116,  0.1103],
          [ 0.4392,  0.1695, -0.0946,  ...,  0.0482,  0.4176,  0.0082],
          [-0.1599, -0.3893, -0.4876,  ..., -0.4480, -0.2136, -0.4106]],

         [[ 1.7732,  2.1121,  2.2089,  ...,  2.1945,  1.9983,  1.1914],
          [ 1.6806,  1.5047,  