In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import glob
import random
import rasterio
import numpy as np
import classic_algos.bicubic_interpolation as bicubic
import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device )

cuda


In [9]:
class SatelliteSRDataset(Dataset):
    def __init__(self, root_dir, scale_factor=3, hr_patch_size=144, augment=True):
        super().__init__()
        self.root_dir = root_dir
        self.scale_factor = scale_factor
        self.hr_patch_size = hr_patch_size
        self.augment = augment

        self.file_paths = glob.glob(os.path.join(root_dir, '**', '*.tif'), recursive=True)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        with rasterio.open(img_path) as src:
            image = src.read()

        image_norm = image.astype(np.float32) / 255.0

        #random crop
        c, h, w = image_norm.shape
        top = random.randint(0, h - self.hr_patch_size)
        left = random.randint(0, w - self.hr_patch_size)
        hr_patch_np = image_norm[:, top:top+self.hr_patch_size, left:left+self.hr_patch_size]
        if self.augment:
            # Flip Horizontal
            if random.random() < 0.5:
                hr_patch_np = np.flip(hr_patch_np, axis=2)

            # Flip Vertical
            if random.random() < 0.5:
                hr_patch_np = np.flip(hr_patch_np, axis=1)

            # rotations
            k = random.choice([0, 1, 2, 3])
            if k > 0:
                hr_patch_np = np.rot90(hr_patch_np, k, axes=(1, 2))

        hr_patch_np = np.ascontiguousarray(hr_patch_np)
        hr_tensor = torch.from_numpy(hr_patch_np)

        hr_patch_hwc = image[:, top:top+self.hr_patch_size, left:left+self.hr_patch_size].transpose(1, 2, 0)
        lr_size = self.hr_patch_size // self.scale_factor
        lr_numpy_hwc = bicubic.SR_bicubic(hr_patch_hwc, lr_size, lr_size)
        lr_numpy = np.transpose(lr_numpy_hwc, (2, 0, 1)).astype(np.float32) / 255.0
        lr_tensor = torch.from_numpy(lr_numpy)

        return lr_tensor, hr_tensor

In [10]:
class FSRCNN(nn.Module):
    def __init__(self):
        """
        d: filters dimension
        s: shrinking dimension
        m: mapping layers number
        """
        d, s, m = 56, 12, 4
        scale_factor = 3
        channels = 3
        super(FSRCNN, self).__init__()

        self.feature_extraction = nn.Sequential(
            nn.Conv2d(channels, d, kernel_size=5, padding=2),
            nn.PReLU(d)
        )

        self.shrink = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=1),
            nn.PReLU(s)
        )

        map_layers = []
        for _ in range(m):
            map_layers.extend([nn.Conv2d(s, s, kernel_size=3, padding=1),
                               nn.PReLU(s)
            ])
        self.mapping = nn.Sequential(*map_layers)

        self.expand = nn.Sequential(
            nn.Conv2d(s, d, kernel_size=1),
            nn.PReLU(d)
        )

        self.deconv = nn.ConvTranspose2d(d, channels,
                                         kernel_size=9,
                                         stride=3,
                                         padding=3,
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.shrink(x)
        x = self.mapping(x)
        x = self.expand(x)
        x = self.deconv(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=0.001)


train

In [11]:
def PSNR(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    return 10 * torch.log10(1 / (mse + 1e-20))

In [12]:
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import math
from pathlib import Path
BATCH_SIZE = 64
EPOCHS = 100

TRAIN_DIR = Path.home() / '.data' / 'UCMerced_LandUse_Split' / 'train'
VAL_DIR = Path.home() / '.data' / 'UCMerced_LandUse_Split' / 'val'
MODEL_DIR = Path.home() / '.data' / 'fsrcnn_models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)


In [13]:
train_dataset = SatelliteSRDataset(TRAIN_DIR)
val_dataset = SatelliteSRDataset(VAL_DIR, augment=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [14]:
model = FSRCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

In [15]:
loss = []
psnr_metric = []

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for lr_imgs, hr_imgs in tqdm(train_loader):
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        optimizer.zero_grad()
        sr_imgs = model(lr_imgs)
        batch_loss = criterion(sr_imgs, hr_imgs)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    avg_loss = epoch_loss / len(train_loader)
    loss.append(avg_loss)

    if epoch % 5 != 0:
        continue

    model.eval()
    epoch_psnr = 0
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            sr_imgs = model(lr_imgs)
            epoch_psnr += PSNR(sr_imgs, hr_imgs).item()
    avg_psnr = epoch_psnr / len(val_loader)
    psnr_metric.append(avg_psnr)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB")
    torch.save(model.state_dict(), MODEL_DIR / f'fsrcnn_epoch_{epoch}.pth')

print("Training complete.")

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [None]:
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(loss, label='Train Loss (MSE)')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(psnr_metric, label='Val PSNR', color='orange')
plt.title('Метрика качества (PSNR)')
plt.xlabel('Epochs')
plt.ylabel('dB')
plt.legend()
plt.grid(True)

plt.show()