In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import glob
import random
import rasterio
import numpy as np
import classic_algos.bicubic_interpolation as bicubic

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
class SatelliteSRDataset(Dataset):
    """Super-resolution training pairs cut from GeoTIFF tiles.

    Every ``*.tif`` file found recursively under ``root_dir`` is one sample.
    ``__getitem__`` reads all bands of the file with rasterio, takes a random
    square HR crop of ``hr_patch_size`` pixels, optionally applies random
    flips / 90-degree rotations, and builds the LR input with
    ``classic_algos.bicubic_interpolation.SR_bicubic``.

    Args:
        root_dir: Directory searched recursively for ``.tif`` files. Files are
            sorted so that ``idx`` maps to the same file on every run.
        scale_factor: Ratio between the HR patch side and the LR patch side.
        hr_patch_size: Side of the square HR crop. Must be a positive multiple
            of ``scale_factor`` so that ``lr_size * scale_factor`` reproduces
            the HR size exactly.
        augment: If True, apply random horizontal/vertical flips and a random
            multiple-of-90-degree rotation to the HR crop.
        max_value: Pixel values are divided by this to normalise them. The
            default (255) assumes 8-bit imagery; for 16-bit tiles pass the
            appropriate maximum (e.g. 65535 or the sensor's real range).

    Each item is ``(lr_tensor, hr_tensor)``. ``hr_tensor`` is float32 with
    shape (bands, hr_patch_size, hr_patch_size); ``lr_tensor`` is float32 and
    its shape is whatever ``SR_bicubic`` returns (presumably
    (bands, lr_size, lr_size) — TODO confirm).

    NOTE(review): the number of bands in the tiles must match the channel
    count expected by the model (FSRCNN in this notebook uses 3) — confirm.

    Raises:
        ValueError: from ``__init__`` if ``hr_patch_size`` is not a positive
            multiple of ``scale_factor``; from ``__getitem__`` if a tile is
            smaller than ``hr_patch_size`` in either spatial dimension.
    """

    def __init__(self, root_dir, scale_factor=3, hr_patch_size=144, augment=True,
                 max_value=255.0):
        super().__init__()
        # A non-divisible patch size yields LR/HR pairs whose sizes cannot match
        # a x`scale_factor` network output; fail early instead of in the loss.
        if scale_factor <= 0 or hr_patch_size <= 0 or hr_patch_size % scale_factor != 0:
            raise ValueError(
                f"hr_patch_size ({hr_patch_size}) must be a positive multiple of "
                f"scale_factor ({scale_factor})"
            )
        self.root_dir = root_dir
        self.scale_factor = scale_factor
        self.hr_patch_size = hr_patch_size
        self.augment = augment
        # Python float keeps the float32 division result float32.
        self.max_value = float(max_value)

        # glob order is filesystem-dependent; sort for reproducible indexing.
        self.file_paths = sorted(
            glob.glob(os.path.join(root_dir, '**', '*.tif'), recursive=True)
        )

    def __len__(self):
        return len(self.file_paths)

    def _random_crop(self, image, img_path="<array>"):
        """Return a random ``hr_patch_size`` x ``hr_patch_size`` crop of a (C, H, W) array.

        Raises:
            ValueError: if the image is smaller than the patch in H or W.
        """
        _, height, width = image.shape
        size = self.hr_patch_size
        if height < size or width < size:
            raise ValueError(
                f"{img_path}: image is {height}x{width}, smaller than "
                f"hr_patch_size={size}"
            )
        top = random.randint(0, height - size)
        left = random.randint(0, width - size)
        return image[:, top:top + size, left:left + size]

    @staticmethod
    def _augment(patch):
        """Randomly flip (each axis with p=0.5) and rotate by k*90 degrees a (C, H, W) patch.

        The rotation keeps the shape only because the patch is square.
        """
        if random.random() < 0.5:
            patch = np.flip(patch, axis=2)  # horizontal flip
        if random.random() < 0.5:
            patch = np.flip(patch, axis=1)  # vertical flip
        k = random.choice([0, 1, 2, 3])
        if k > 0:
            patch = np.rot90(patch, k, axes=(1, 2))
        return patch

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        with rasterio.open(img_path) as src:
            image = src.read()  # (bands, rows, cols)

        # Crop before the float conversion so only the patch is copied and
        # normalised, not the whole (potentially large) tile.
        hr_patch_np = self._random_crop(image, img_path)
        hr_patch_np = hr_patch_np.astype(np.float32) / self.max_value

        if self.augment:
            hr_patch_np = self._augment(hr_patch_np)

        # flip/rot90 return negative-stride views, which torch.from_numpy rejects.
        hr_patch_np = np.ascontiguousarray(hr_patch_np)
        hr_tensor = torch.from_numpy(hr_patch_np)

        lr_size = self.hr_patch_size // self.scale_factor
        # NOTE(review): assumes SR_bicubic resizes the (C, H, W) patch to
        # (C, lr_size, lr_size) and does not modify its input in place
        # (hr_tensor shares memory with hr_patch_np) — confirm.
        lr_numpy = bicubic.SR_bicubic(hr_patch_np, lr_size, lr_size)
        # Match hr_tensor's dtype (and the model's float32 weights) and ensure
        # a memory layout torch.from_numpy accepts; no copy if already so.
        lr_tensor = torch.from_numpy(np.ascontiguousarray(lr_numpy, dtype=np.float32))

        return lr_tensor, hr_tensor

In [None]:
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network (Dong et al., 2016).

    Pipeline: feature extraction (5x5) -> shrinking (1x1) -> ``m`` mapping
    layers (3x3) -> expanding (1x1) -> 9x9 transposed-convolution upsampling,
    with PReLU after every convolution except the last. For an input of shape
    (N, channels, H, W) the output has shape
    (N, channels, H * scale_factor, W * scale_factor).

    Args:
        scale_factor: Upscaling factor (stride of the final deconvolution).
        channels: Number of input and output image channels.
        d: Number of feature-extraction / expanding filters.
        s: Shrinking dimension (channels inside the mapping stage).
        m: Number of 3x3 mapping layers.

    The defaults reproduce the previous hard-coded configuration (scale 3,
    3 channels, d=56, s=12, m=4) with identical submodule/parameter names, so
    existing ``state_dict``s still load.

    Raises:
        ValueError: if ``scale_factor`` is smaller than 1.
    """

    def __init__(self, scale_factor=3, channels=3, d=56, s=12, m=4):
        super().__init__()
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        self.scale_factor = scale_factor

        self.feature_extraction = nn.Sequential(
            nn.Conv2d(channels, d, kernel_size=5, padding=2),
            nn.PReLU(d)
        )

        self.shrink = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=1),
            nn.PReLU(s)
        )

        map_layers = []
        for _ in range(m):
            map_layers.extend([nn.Conv2d(s, s, kernel_size=3, padding=1),
                               nn.PReLU(s)])
        self.mapping = nn.Sequential(*map_layers)

        self.expand = nn.Sequential(
            nn.Conv2d(s, d, kernel_size=1),
            nn.PReLU(d)
        )

        # ConvTranspose2d output size: (H-1)*stride - 2*padding + 9 + output_padding.
        # These choices make it exactly H*scale_factor with 0 <= output_padding < stride;
        # for scale_factor=3 they give padding=3, output_padding=0 (the original setting).
        padding = max(0, (10 - scale_factor) // 2)
        output_padding = scale_factor + 2 * padding - 9
        self.deconv = nn.ConvTranspose2d(d, channels,
                                         kernel_size=9,
                                         stride=scale_factor,
                                         padding=padding,
                                         output_padding=output_padding)

        self._initialize_weights()

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.shrink(x)
        x = self.mapping(x)
        x = self.expand(x)
        x = self.deconv(x)
        return x

    def _initialize_weights(self):
        """Initialise weights as in the FSRCNN paper.

        Convolutions use He (Kaiming) normal init for PReLU (a=0.25, PReLU's
        default initial slope); the final deconvolution uses N(0, 0.001).
        Biases are zeroed. Initialising every layer with std 0.001 (the
        previous behaviour) shrinks activations and gradients by orders of
        magnitude per layer.
        """
        for module in self.modules():
            # ConvTranspose2d is not a Conv2d subclass; check it first for clarity.
            if isinstance(module, nn.ConvTranspose2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.001)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, a=0.25, mode="fan_in",
                                        nonlinearity="leaky_relu")
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)
