In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import ee
ee.Authenticate()

In [3]:
ee.Initialize(project="gee-sr-project")

In [4]:
import geemap
import ee
import numpy as np
import cv2
import random
import torch
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


device = "cuda" if torch.cuda.is_available() else "cpu"


In [5]:
collection = (
    ee.ImageCollection("COPERNICUS/S2_SR")
    .filterDate("2022-01-01", "2023-12-31")
    .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
)

def fetch_tile(lon, lat, size=0.005):
    region = ee.Geometry.Rectangle([
        lon-size, lat-size,
        lon+size, lat+size
    ])

    imgs = collection.filterBounds(region)

    if imgs.size().getInfo() == 0:
        return None

    img = imgs.first().select(["B4", "B3", "B2"])

    tile = geemap.ee_to_numpy(img, region=region, scale=10)

    tile = np.clip(tile / 3000 * 255, 0, 255).astype(np.uint8)
    return tile



Attention required for COPERNICUS/S2_SR! You are using a deprecated asset.
To make sure your code keeps working, please update it.
Learn more: https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR



In [6]:
# City centers
CITIES = [
    (80.3319, 26.4499),   # Kanpur
    (77.1025, 28.7041),   # Delhi
    (72.5714, 23.0225),   # Ahmedabad
    (72.8777, 19.0760)    # Mumbai
]

tiles = []
TARGET_TILES = 300

while len(tiles) < TARGET_TILES:

    # pick random city
    lon, lat = random.choice(CITIES)

    # random offset inside city area
    dx = np.random.uniform(-0.18, 0.18)
    dy = np.random.uniform(-0.18, 0.18)

    tile = fetch_tile(lon + dx, lat + dy)

    if tile is not None:
        tiles.append(tile)

print("Collected tiles:", len(tiles))


Collected tiles: 300


In [7]:
def make_lr_hr(img):
    hr = img.copy()
    lr = cv2.resize(hr, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
    return lr, hr

def to_tensor(img):
    return torch.tensor(img/255.0).permute(2,0,1).float().unsqueeze(0).to(device)


In [8]:
class DenseBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.act = nn.LeakyReLU(0.2, inplace=True)

        self.c1 = nn.Conv2d(c, c, 3, 1, 1)
        self.c2 = nn.Conv2d(c*2, c, 3, 1, 1)
        self.c3 = nn.Conv2d(c*3, c, 3, 1, 1)
        self.c4 = nn.Conv2d(c*4, c, 3, 1, 1)
        self.c5 = nn.Conv2d(c*5, c, 3, 1, 1)

    def forward(self, x):
        x1 = self.act(self.c1(x))
        x2 = self.act(self.c2(torch.cat([x, x1], 1)))
        x3 = self.act(self.c3(torch.cat([x, x1, x2], 1)))
        x4 = self.act(self.c4(torch.cat([x, x1, x2, x3], 1)))
        x5 = self.c5(torch.cat([x, x1, x2, x3, x4], 1))
        return x + 0.2 * x5


class RRDB(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.d1 = DenseBlock(c)
        self.d2 = DenseBlock(c)
        self.d3 = DenseBlock(c)

    def forward(self, x):
        return x + 0.2 * self.d3(self.d2(self.d1(x)))


class ESRGAN_Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.start = nn.Conv2d(3, 64, 3, 1, 1)

        self.body = nn.Sequential(*[RRDB(64) for _ in range(16)])

        self.refine = nn.Conv2d(64, 64, 3, 1, 1)

        self.up = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2)
        )

        self.end = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        f = self.start(x)
        r = self.refine(self.body(f))
        f = f + r
        return self.end(self.up(f))

In [9]:
def psnr(sr, hr):
    mse = torch.mean((sr - hr) ** 2)
    if mse == 0:
        return 100
    return 20 * torch.log10(1.0 / torch.sqrt(mse))


def ssim(sr, hr):
    C1 = 0.01**2
    C2 = 0.03**2

    mu_x = sr.mean()
    mu_y = hr.mean()

    sigma_x = ((sr - mu_x)**2).mean()
    sigma_y = ((hr - mu_y)**2).mean()
    sigma_xy = ((sr - mu_x)*(hr - mu_y)).mean()

    return ((2*mu_x*mu_y + C1)*(2*sigma_xy + C2)) / (
        (mu_x**2 + mu_y**2 + C1)*(sigma_x + sigma_y + C2)
    )


In [10]:
def edge_loss(sr, hr):
    sobel = torch.tensor([[1,0,-1],[2,0,-2],[1,0,-1]],
                         dtype=torch.float32).to(device)
    sobel = sobel.view(1,1,3,3)

    sr_g = F.conv2d(sr[:,0:1], sobel, padding=1)
    hr_g = F.conv2d(hr[:,0:1], sobel, padding=1)

    return torch.mean((sr_g - hr_g)**2)


In [11]:
def random_crop(lr, hr, size=64):
    h, w, _ = lr.shape
    crop = min(size, h, w)

    if crop < 16:
        return None, None

    x = random.randint(0, h-crop)
    y = random.randint(0, w-crop)

    lr_c = lr[x:x+crop, y:y+crop]
    hr_c = hr[x*4:(x+crop)*4, y*4:(y+crop)*4]

    return lr_c, hr_c


In [12]:
def match_size(sr, hr):
    h = min(sr.shape[2], hr.shape[2])
    w = min(sr.shape[3], hr.shape[3])
    return sr[:, :, :h, :w], hr[:, :, :h, :w]


In [None]:
model = ESRGAN_Generator().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.L1Loss()

EPOCHS = 30

for epoch in range(EPOCHS):

    total_psnr = 0
    total_ssim = 0
    count = 0

    for img in tiles:
        lr_img, hr_img = make_lr_hr(img)

        for _ in range(15):
            lr_p, hr_p = random_crop(lr_img, hr_img)
            if lr_p is None:
                continue

            lr = to_tensor(lr_p)
            hr = to_tensor(hr_p)

            sr = model(lr)

            sr, hr = match_size(sr, hr)

            loss = loss_fn(sr, hr) + 0.05 * edge_loss(sr, hr)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                total_psnr += psnr(sr, hr).item()
                total_ssim += ssim(sr, hr).item()
                count += 1

    print(f"Epoch {epoch+1} | PSNR: {total_psnr/count:.2f} dB | SSIM: {total_ssim/count:.4f}")



Epoch 1 | PSNR: 30.00 dB | SSIM: 0.8728
Epoch 2 | PSNR: 31.20 dB | SSIM: 0.9000
Epoch 3 | PSNR: 31.19 dB | SSIM: 0.9007
Epoch 4 | PSNR: 31.37 dB | SSIM: 0.9036
Epoch 5 | PSNR: 31.54 dB | SSIM: 0.9079
Epoch 6 | PSNR: 31.87 dB | SSIM: 0.9143
Epoch 7 | PSNR: 32.21 dB | SSIM: 0.9207
Epoch 8 | PSNR: 32.53 dB | SSIM: 0.9274
Epoch 9 | PSNR: 32.86 dB | SSIM: 0.9336


In [None]:
test_tile = fetch_tile(KANPUR_LON, KANPUR_LAT)
lr_img, hr_img = make_lr_hr(test_tile)

lr = to_tensor(lr_img)

model.eval()
with torch.no_grad():
    sr = model(lr)

sr_img = sr.squeeze().permute(1,2,0).cpu().numpy()
sr_img = (sr_img * 255).clip(0,255).astype(np.uint8)

plt.figure(figsize=(15,5))

plt.subplot(1,3,1)
plt.title("Low Resolution")
plt.imshow(lr_img)
plt.axis("off")

plt.subplot(1,3,2)
plt.title("ESRGAN Output")
plt.imshow(sr_img)
plt.axis("off")

plt.subplot(1,3,3)
plt.title("High Resolution GT")
plt.imshow(hr_img)
plt.axis("off")

plt.show()

sr_t = torch.tensor(sr_img/255.0).permute(2,0,1).unsqueeze(0).to(device)
hr_t = torch.tensor(hr_img/255.0).permute(2,0,1).unsqueeze(0).to(device)

sr_t, hr_t = match_size(sr_t, hr_t)

print("Final PSNR:", psnr(sr_t, hr_t).item())
print("Final SSIM:", ssim(sr_t, hr_t).item())

In [None]:
def bicubic_upscale(lr):
    return cv2.resize(lr, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)


In [None]:
DELHI_LON = 77.1025
DELHI_LAT = 28.7041


In [None]:
def evaluate_model_delhi(model, num_tests=150):   # try 100â€“200
    model.eval()

    psnr_sr, ssim_sr = [], []
    psnr_bi, ssim_bi = [], []

    tested = 0

    while tested < num_tests:
        dx = np.random.uniform(-0.25, 0.25)   # wider city coverage
        dy = np.random.uniform(-0.25, 0.25)

        tile = fetch_tile(DELHI_LON + dx, DELHI_LAT + dy)
        if tile is None:
            continue

        lr_img, hr_img = make_lr_hr(tile)

        # ESRGAN
        lr = to_tensor(lr_img)
        with torch.no_grad():
            sr = model(lr)

        sr_img = sr.squeeze().permute(1,2,0).cpu().numpy()
        sr_img = (sr_img * 255).clip(0,255).astype(np.uint8)

        # Bicubic baseline
        bi_img = cv2.resize(lr_img, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)

        # tensors
        sr_t = torch.tensor(sr_img/255.).permute(2,0,1).unsqueeze(0).to(device)
        bi_t = torch.tensor(bi_img/255.).permute(2,0,1).unsqueeze(0).to(device)
        hr_t = torch.tensor(hr_img/255.).permute(2,0,1).unsqueeze(0).to(device)

        sr_t, hr_t = match_size(sr_t, hr_t)
        bi_t, hr_t = match_size(bi_t, hr_t)

        psnr_sr.append(psnr(sr_t, hr_t).item())
        ssim_sr.append(ssim(sr_t, hr_t).item())

        psnr_bi.append(psnr(bi_t, hr_t).item())
        ssim_bi.append(ssim(bi_t, hr_t).item())

        tested += 1

    print("\n===== DELHI TEST RESULTS =====")
    print("Samples tested:", tested)

    print("\nBicubic:")
    print("PSNR:", np.mean(psnr_bi))
    print("SSIM:", np.mean(ssim_bi))

    print("\nESRGAN:")
    print("PSNR:", np.mean(psnr_sr))
    print("SSIM:", np.mean(ssim_sr))

    print("\nImprovement:")
    print("PSNR gain:", np.mean(psnr_sr) - np.mean(psnr_bi))
    print("SSIM gain:", np.mean(ssim_sr) - np.mean(ssim_bi))


In [None]:
evaluate_model_delhi(model, num_tests=100)



In [None]:
# Pick a fresh unseen tile
dx = np.random.uniform(-0.3, 0.5)
dy = np.random.uniform(-0.3, 0.5)

test_tile = fetch_tile(KANPUR_LON + dx, KANPUR_LAT + dy)

lr_img, hr_img = make_lr_hr(test_tile)

# ESRGAN output
lr = to_tensor(lr_img)
with torch.no_grad():
    sr = model(lr)

sr_img = sr.squeeze().permute(1,2,0).cpu().numpy()
sr_img = (sr_img * 255).clip(0,255).astype(np.uint8)

# Bicubic baseline
bi_img = cv2.resize(lr_img, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)

# Plot comparison
plt.figure(figsize=(12,4))

plt.subplot(1,3,1)
plt.imshow(bi_img)
plt.title("Bicubic")
plt.axis("off")

plt.subplot(1,3,2)
plt.imshow(sr_img)
plt.title("ESRGAN")
plt.axis("off")

plt.subplot(1,3,3)
plt.imshow(hr_img)
plt.title("Ground Truth")
plt.axis("off")

plt.show()


In [None]:
torch.save({
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}, "esrgan_satellite_2.pth")

print("Model saved successfully!")
