In [None]:
!pip install torchviz
!pip install graphviz
!pip install torchview

!pip install torchinfo



In [None]:

import os
import math
import random

import cv2
import numpy as np
import time

import matplotlib.pyplot as plt

import os
from PIL import Image

import torch
from torch.nn import init
from torch import nn, optim
from torchinfo import summary
from torch.nn import functional as F
from torchvision.io import read_image
from torch.nn.utils import spectral_norm
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchviz import make_dot
from torchview import draw_graph


import warnings
warnings.filterwarnings("ignore")

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Device: {device}')

Device: cuda


In [None]:
class SRDataset(Dataset):
    """Image dataset for multi-image super-resolution.

    Every low-resolution (LR) image is expanded into four views that are
    stacked along the channel axis:
      1. the original,
      2. a copy shifted up by one intensity level (clamped at 255),
      3. a copy rotated by +2 degrees,
      4. a copy rotated by -1 degree.

    Args:
        path (str): Root directory. In ``'train'`` mode it must contain the
            sub-folders ``hr_crop/`` and ``lr_crop/``. In ``'test'`` mode it
            must contain ``lr/``.
        mode (str): ``'train'`` or ``'test'``. Default: ``'train'``.

    ``__getitem__`` returns:
        * ``'train'``: ``(hr_image, lr_image_augmented)``. HR/LR pairs are
          matched by position in sorted filename order. Random 90/180/270
          rotation and RGB channel permutation are applied identically to
          both images.
        * ``'test'``: ``lr_image_augmented`` only.

    Raises:
        ValueError: if ``mode`` is unknown, or if the HR and LR folders hold a
            different number of images.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, path, mode='train'):
        if mode == 'train':
            self.hr_path = os.path.join(path, 'hr_crop')
            self.lr_path = os.path.join(path, 'lr_crop')
        elif mode == 'test':
            # Test data has no ground truth; only the LR inputs are read.
            self.hr_path = None
            self.lr_path = os.path.join(path, 'lr')
        else:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

        self.lr_images = self._list_images(self.lr_path)
        self.hr_images = self._list_images(self.hr_path) if self.hr_path is not None else []
        # Pairs are matched by index, so a count mismatch would silently misalign them.
        if self.hr_path is not None and len(self.hr_images) != len(self.lr_images):
            raise ValueError(
                f'{self.hr_path} has {len(self.hr_images)} images but '
                f'{self.lr_path} has {len(self.lr_images)}'
            )

        self.mode = mode

    @classmethod
    def _list_images(cls, directory):
        """Return the image filenames in ``directory``, sorted, filtered by extension."""
        return [name for name in sorted(os.listdir(directory))
                if name.endswith(cls.IMAGE_EXTENSIONS)]

    @staticmethod
    def _load_image(image_path):
        """Read ``image_path`` as a 3-channel float tensor scaled to [0, 1]."""
        from torchvision.io import ImageReadMode
        # Force RGB so grayscale or RGBA files still produce 3 channels.
        # The 3-way channel permutation below and the 4x3-channel model
        # input both depend on that.
        return read_image(image_path, mode=ImageReadMode.RGB) / 255.

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, index):
        lr_image = self._load_image(os.path.join(self.lr_path, self.lr_images[index]))

        if self.mode == 'train':
            hr_image = self._load_image(os.path.join(self.hr_path, self.hr_images[index]))

            # Random rotation by a multiple of 90 degrees, the same for HR and LR.
            # TF.rotate keeps the input size, so this assumes square crops.
            # NOTE(review): confirm the crops are square.
            if random.random() > 0.5:
                angle = random.choice([90, 180, 270])
                hr_image = TF.rotate(hr_image, angle)
                lr_image = TF.rotate(lr_image, angle)

            # Random RGB channel permutation, the same for HR and LR.
            if random.random() > 0.5:
                channels = torch.randperm(3)
                hr_image = hr_image[channels]
                lr_image = lr_image[channels]

        # Extra LR views:
        #   - +1 intensity level (clamped at 255)
        #   - rotated by +2 degrees
        #   - rotated by -1 degree
        lr_added_1 = torch.clamp((lr_image * 255 + 1), max=255) / 255.
        lr_rotated_2 = TF.rotate(lr_image, 2)
        lr_rotated_neg_1 = TF.rotate(lr_image, -1)

        # Stack all four views along the channel dimension (4 x 3 = 12 channels).
        lr_image_augmented = torch.cat((lr_image, lr_added_1, lr_rotated_2, lr_rotated_neg_1), dim=0)

        if self.mode == 'train':
            return hr_image, lr_image_augmented
        return lr_image_augmented

In [None]:
# Mount Google Drive (Colab-only API). The dataset, checkpoint and output
# paths used in later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#%cd drive/My\ Drive/Colab Notebooks/MISR_CNN
#!ls

In [None]:
# NOTE(review): batch_size is not used by val_loader below.
batch_size = 8

val_path = '/content/drive/MyDrive//Colab Notebooks/MISR_CNN/Allimages/LRtest_imgs'
val_dataset = SRDataset(val_path, mode='test')

# Use one image per batch. The inference loop names each saved PNG after the
# batch index, so every iteration must correspond to exactly one image.
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [None]:
def make_layer(basic_block, num_basic_block, **kwargs):
    """Return an ``nn.Sequential`` of ``num_basic_block`` separate ``basic_block(**kwargs)`` instances.

    Each block is constructed independently, so no parameters are shared
    between them.
    """
    blocks = (basic_block(**kwargs) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)

class DenseBlock(nn.Module):
    """Three-conv block with additive dense connections and a scaled residual.

    The input is first LayerNorm-ed over its channels at each pixel. Each
    conv then sees the normalized input plus the outputs of all earlier convs
    in the block; the connections are summed, not concatenated. The last
    conv's output is multiplied by ``res_scale`` and added to the
    *un-normalized* input.

    Args:
        embed_dim (int): Number of input and output channels. Default: 64
        res_scale (float): Factor applied to the block output before the
            residual add. Default: 0.3 (the previously hard-coded value).
    """
    def __init__(self, embed_dim=64, res_scale=0.3):
        super(DenseBlock, self).__init__()
        self.res_scale = res_scale
        self.norm = nn.LayerNorm(embed_dim)
        self.conv1 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.conv3 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        # inplace=True is safe: it only ever receives fresh conv outputs.
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # x is only read below; `x = ...` just rebinds the name. Keeping a
        # reference is enough for the skip connection, so no clone is needed.
        res = x
        # nn.LayerNorm normalizes the last dim, so move channels last
        # (NCHW -> NHWC) and back.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(x + x1))
        x3 = self.lrelu(self.conv3(x + x1 + x2))
        return x3 * self.res_scale + res

class RDB(nn.Module):
    """Residual group of three chained DenseBlocks.

    The input passes through rdb1 -> rdb2 -> rdb3. The result is scaled by
    0.3 and added back onto the original input.

    Args:
        embed_dim (int): Channel count forwarded to every DenseBlock.
    """
    def __init__(self, embed_dim):
        super().__init__()
        # These attribute names define the state_dict keys that
        # model.load_state_dict matches against, so keep them.
        self.rdb1 = DenseBlock(embed_dim)
        self.rdb2 = DenseBlock(embed_dim)
        self.rdb3 = DenseBlock(embed_dim)

    def forward(self, x):
        h = x
        for block in (self.rdb1, self.rdb2, self.rdb3):
            h = block(h)
        return h * 0.3 + x

class Upsample(nn.Sequential):
    """Sub-pixel upsampling head built from (Conv2d + PixelShuffle) stages.

    Args:
        scale (int): Upscale factor.
            - A positive power of two (1, 2, 4, 8, ...) becomes log2(scale)
              x2 stages; scale 1 yields an empty (identity) Sequential.
            - 3 becomes a single x3 stage.
        num_feat (int): Channels entering and leaving every stage.

    Raises:
        ValueError: if ``scale`` is not a positive power of two or 3.
    """
    def __init__(self, scale, num_feat):
        m = []
        # Without the scale > 0 guard, 0 passes the bit test (0 & -1 == 0)
        # and then crashes inside the log computation.
        if scale > 0 and (scale & (scale - 1)) == 0:  # positive power of two
            # bit_length() - 1 is exactly log2(scale) for powers of two,
            # so no float rounding is involved.
            for _ in range(scale.bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class CNN(nn.Module):
    """Residual CNN for (multi-image) super-resolution.

    Pipeline:
        conv_first -> ``num_block`` RDB groups -> conv_body
        (+ global skip from conv_first) -> conv + LeakyReLU
        -> Upsample(``upscale``) -> conv_last.

    The 3x3 / stride 1 / padding 1 convolutions preserve spatial size, so the
    output spatial dims are the input's multiplied by ``upscale``.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        upscale (int): Upscale factor passed to ``Upsample`` (2^n or 3).
            Default: 4
        embed_dim (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Number of RDB groups in the trunk network.
            Default: 16
        num_final_feat (int): Channel number of the reconstruction head.
            Default: 64
    """

    def __init__(self, num_in_ch, num_out_ch, upscale=4, embed_dim=64, num_block=16, num_final_feat=64):
        super(CNN, self).__init__()
        self.upscale = upscale

        # ------------------------- 1, shallow feature extraction ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2, deep feature extraction ------------------------- #
        self.body = make_layer(RDB, num_block, embed_dim=embed_dim)
        self.conv_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)

        # NOTE(review): not used in forward().
        # Kept so the module's attributes stay unchanged.
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # ------------------------- 3, high quality image reconstruction ------------------------- #
        # This LeakyReLU uses PyTorch's default negative_slope (0.01), not the
        # 0.2 used elsewhere; trained checkpoints depend on it.
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(embed_dim, num_final_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
        self.upsample = Upsample(upscale, num_final_feat)
        self.conv_last = nn.Conv2d(num_final_feat, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Set LayerNorm affine params to weight=1, bias=0; other modules keep default init."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        feat = self.conv_first(x)

        # Global residual: the trunk learns a correction to the shallow features.
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat

        # upsample
        feat = self.conv_before_upsample(feat)
        out = self.conv_last(self.upsample(feat))
        return out

# Input channels: SRDataset stacks 4 LR views per sample, presumably 3 (RGB)
# channels each, giving 4 * 3 = 12. Output is a 3-channel image.
model = CNN(num_in_ch=4 * 3, num_out_ch=3)
model = model.to(device)


In [None]:
import glob

output_dir = '/content/drive/MyDrive//Colab Notebooks/MISR_CNN/Output/sr_prediction'
# Replace '*' with the filename of the desired .pth file. A wildcard is
# resolved below; if several files match, the newest is used.
model_path = '/content/drive/MyDrive//Colab Notebooks/MISR_CNN/models/*.pth'

# exist_ok replaces the separate exists() check (no check-then-create race).
os.makedirs(output_dir, exist_ok=True)

# torch.load does not expand wildcards, so resolve the placeholder here.
# A literal path without wildcards simply matches itself if it exists.
checkpoint_candidates = sorted(glob.glob(model_path), key=os.path.getmtime)
if not checkpoint_candidates:
    raise FileNotFoundError(f'No checkpoint matches {model_path!r}')
checkpoint_path = checkpoint_candidates[-1]
if len(checkpoint_candidates) > 1:
    print(f'{len(checkpoint_candidates)} checkpoints match {model_path!r}; '
          f'using the most recently modified: {checkpoint_path}')

# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on recent PyTorch).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

model.eval()  # Set model to evaluation mode
with torch.no_grad():  # No gradients needed
    for i, low_res in enumerate(val_loader):
        low_res = low_res.to(device)
        super_res = model(low_res)  # Generate high-resolution output from model

        # Save each pair. The first 3 channels of low_res are the
        # unmodified LR view.
        save_image(low_res[:, :3].clamp(0, 1), f'{output_dir}/planet_low_res_sample_{i+1}.png', normalize=False)
        save_image(super_res.clamp(0, 1), f'{output_dir}/planet_super_res_sample_{i+1}.png', normalize=False)