In [21]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import pyiqa
from sklearn.model_selection import train_test_split

In [22]:

# ── USER CONFIG ────────────────────────────────────
# Paths are joined with os.path.join so an override without a trailing slash
# still yields valid paths (results are identical for the defaults below).
dataset_dir   = "Datasets/"
# Default keeps the original location; set CNN_BASIC_OUTPUT_DIR to run this
# notebook on another machine without editing the cell.
output_dir    = os.environ.get("CNN_BASIC_OUTPUT_DIR", "G:/MaestriaInformatica/Tesis/Outputs/CNN_Basic/")

LR_DIR        = os.path.join(dataset_dir, "SatImages/perusat_v5_rgb/LR_BICUBIC/X2/")  # LR 256×256
HR_DIR        = os.path.join(dataset_dir, "SatImages/perusat_v5_rgb/HR/")             # HR 512×512
SUFFIX        = "_x2"  # LR filename = HR basename + SUFFIX + extension
HR_HEIGHT     = 512
HR_WIDTH      = 512
BATCH_SIZE    = 16
MAX_IMAGES    = 0      # 0 = use every valid LR/HR pair (split into train/val afterwards)
VAL_FRAC      = 0.2    # fraction of pairs held out for validation
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EPOCHS        = 84
LEARNING_RATE = 5e-4

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible on Restart & Run All (the train/val split itself is fixed by
# random_state=42 inside load_dataset). cuDNN kernels may still add some
# run-to-run nondeterminism on GPU.
SEED          = 42
torch.manual_seed(SEED)
# ─────────────────────────────────────────────────────────────────────────────

# ── EXAMPLES CONFIG ──────────────────────────────────────────────────────────
# HR filenames whose LR counterparts are super-resolved and saved after training.
EXAMPLE_NAMES = [
    "IMG_PER1_20161203152919_ORT_MS_000670_1024-3072.png",
    "IMG_PER1_20161203152919_ORT_MS_000670_2048-512.png",
    "IMG_PER1_20170422154946_ORT_MS_000041_3072-4608.png",
    "IMG_PER1_20170422154946_ORT_MS_000041_1536-2560.png",
    "IMG_PER1_20170422154946_ORT_MS_000041_512-3584.png",
    "IMG_PER1_20161203152919_ORT_MS_000670_4608-4608.png",
    "IMG_PER1_20170422154946_ORT_MS_000659_1536-4608.png",
    "IMG_PER1_20170422154946_ORT_MS_000659_5120-3584.png",
    "IMG_PER1_20170422154946_ORT_MS_001277_3584-4608.png",
    "IMG_PER1_20190320154045_ORT_MS_000041_4608-3072.png",
]
EXAMPLES_DIR  = os.path.join(output_dir, "samples/")  # where the 10 SR example images are saved
# ─────────────────────────────────────────────────────────────────────────────

In [23]:


def load_dataset(hr_dir, lr_dir, suffix, hr_size, max_images=100, val_frac=0.1, scale=2):
    """Load paired LR/HR images into in-memory train/validation TensorDatasets.

    HR files in ``hr_dir`` (``.png``/``.jpg``/``.jpeg``, sorted by filename) are
    kept only if their size equals ``hr_size``. For each kept HR file
    ``<base><ext>`` the LR partner ``<base><suffix><ext>`` must exist in
    ``lr_dir`` and measure exactly ``hr_size // scale``.

    Args:
        hr_dir: Directory with the high-resolution images.
        lr_dir: Directory with the low-resolution images.
        suffix: Text inserted before the extension to form the LR filename
            (e.g. ``"_x2"``).
        hr_size: ``(height, width)`` that HR images must have.
        max_images: Number of valid pairs to use (first N by filename);
            ``0`` means every valid pair.
        val_frac: Fraction of pairs held out for validation (``test_size`` of
            ``train_test_split``, with ``random_state=42``).
        scale: Integer factor between HR and LR sizes (default 2).

    Returns:
        ``(train_ds, val_ds)``: ``TensorDataset``s yielding ``(lr, hr)`` tensors
        produced by ``transforms.ToTensor`` on RGB-converted images.

    Raises:
        ValueError: if ``max_images`` is negative or ``scale`` is not positive.
        RuntimeError: if no HR image has the requested size, fewer than
            ``max_images`` exist, or an LR partner is missing / wrongly sized.
    """
    if max_images < 0:
        raise ValueError(f"max_images debe ser >= 0, se recibió {max_images}.")
    if scale <= 0:
        raise ValueError(f"scale debe ser > 0, se recibió {scale}.")

    # PIL reports sizes as (width, height), hence the swapped indices.
    hr_pil_size = (hr_size[1], hr_size[0])
    lr_pil_size = (hr_size[1] // scale, hr_size[0] // scale)

    all_hr = [f for f in sorted(os.listdir(hr_dir))
              if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    valid_hr = []
    for fn in all_hr:
        # Only .size is read here; the context manager releases the file handle.
        with Image.open(os.path.join(hr_dir, fn)) as im:
            if im.size == hr_pil_size:
                valid_hr.append(fn)

    total_hr = len(valid_hr)
    if max_images > 0 and total_hr < max_images:
        raise RuntimeError(f"Sólo hay {total_hr} imágenes HR de tamaño {hr_size}, pero pediste max_images={max_images}.")
    if max_images == 0:
        max_images = total_hr

    selected_hr = valid_hr[:max_images]
    if not selected_hr:
        # Without this, an empty selection surfaces later as an obscure
        # train_test_split / torch.stack error.
        raise RuntimeError(f"No se encontraron imágenes HR de tamaño {hr_size} en '{hr_dir}'.")

    lr_names = []
    for hr_fn in selected_hr:
        base, ext = os.path.splitext(hr_fn)
        lr_fn = base + suffix + ext
        lr_path = os.path.join(lr_dir, lr_fn)
        if not os.path.exists(lr_path):
            raise RuntimeError(f"No existe LR para '{hr_fn}' → buscado '{lr_fn}'.")
        with Image.open(lr_path) as im:
            expected = lr_pil_size
            if im.size != expected:
                raise RuntimeError(f"LR '{lr_fn}' tiene tamaño {im.size}, pero esperaba {expected}.")
        lr_names.append(lr_fn)

    train_hr, val_hr, train_lr, val_lr = train_test_split(
        selected_hr, lr_names, test_size=val_frac, random_state=42
    )

    to_tensor = transforms.ToTensor()

    def load_rgb(path):
        # Decode inside a context manager so each file handle is closed promptly.
        with Image.open(path) as im:
            return to_tensor(im.convert('RGB'))

    def make_dataset(hr_list, lr_list):
        # Returns (lr_stack, hr_stack) in the same order as the input lists.
        hr_tensors = [load_rgb(os.path.join(hr_dir, fn)) for fn in hr_list]
        lr_tensors = [load_rgb(os.path.join(lr_dir, fn)) for fn in lr_list]
        return torch.stack(lr_tensors), torch.stack(hr_tensors)

    lr_train, hr_train = make_dataset(train_hr, train_lr)
    lr_val,   hr_val   = make_dataset(val_hr,   val_lr)

    return TensorDataset(lr_train, hr_train), TensorDataset(lr_val, hr_val)

def evaluate_dataset(generator, data_loader, device):
    """Return the mean PSNR and colour SSIM of ``generator`` over ``data_loader``.

    Each ``(lr, hr)`` batch is moved to ``device``, super-resolved with
    ``generator`` under ``torch.no_grad()``, and scored against ``hr`` with the
    pyiqa ``'psnr'`` and ``'ssimc'`` metrics. Scores are averaged per image over
    the whole loader, so every image contributes equally.

    Args:
        generator: Model mapping an LR batch to an SR batch; it is expected to
            produce images with the HR batch's shape. It is switched to eval
            mode and left there.
        data_loader: Iterable of ``(lr_imgs, hr_imgs)`` tensor batches.
        device: Device for the model inputs and the metric networks.

    Returns:
        ``(avg_psnr, avg_ssim)``; ``(0, 0)`` when the loader yields no images.
    """
    generator.eval()

    ssimc_metric = pyiqa.create_metric('ssimc', device=device)
    psnr_metric  = pyiqa.create_metric('psnr',  device=device)

    total_psnr = total_ssim = 0.0
    count = 0
    with torch.no_grad():
        for lr_imgs, hr_imgs in data_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            sr_imgs = generator(lr_imgs)
            batch = hr_imgs.size(0)
            # pyiqa full-reference metrics take (distorted, reference) and return
            # one score per image. Average the whole score tensor (not just
            # element [0]) and re-weight by batch size so the running total is
            # the exact per-image sum, even for a short final batch.
            ssim_val = ssimc_metric(sr_imgs, hr_imgs).float().mean().item()
            psnr_val = psnr_metric(sr_imgs, hr_imgs).float().mean().item()

            total_psnr += psnr_val * batch
            total_ssim += ssim_val * batch
            count += batch

    if count == 0:
        return 0, 0
    return total_psnr / count, total_ssim / count

class SuperResCNN(nn.Module):
    """Residual super-resolution CNN built on top of bicubic upsampling.

    The input is upsampled by ``scale`` with bicubic interpolation. Four 3x3
    convolutions (ReLU after the first three) predict a residual that is added
    back to the upsampled image, and the result is clamped to ``[0, 1]``.

    Args:
        scale: Upsampling factor passed to ``F.interpolate``.
        channels: Width of the hidden feature maps.
    """

    def __init__(self, scale=2, channels=64):
        super().__init__()
        self.scale = scale

        def conv3x3(in_ch, out_ch):
            # Same-size 3x3 convolution (padding=1 keeps H and W).
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        # Attribute names fix the state_dict keys; creation order fixes how the
        # default initialisers consume the RNG.
        self.conv1 = conv3x3(3, channels)
        self.conv2 = conv3x3(channels, channels)
        self.conv3 = conv3x3(channels, channels)
        self.conv4 = conv3x3(channels, 3)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample ``x`` (a 4-D, 3-channel batch) by ``self.scale`` and refine it."""
        base = F.interpolate(x, scale_factor=self.scale, mode="bicubic", align_corners=False)
        feats = base
        for layer in (self.conv1, self.conv2, self.conv3):
            # Each conv yields a fresh tensor, so the in-place ReLU never touches `base`.
            feats = self.relu(layer(feats))
        residual = self.conv4(feats)
        return (residual + base).clamp(0.0, 1.0)

In [24]:
if __name__ == "__main__":
    # 1) Load the in-memory train/validation datasets.
    train_ds, val_ds = load_dataset(
        hr_dir     = HR_DIR,
        lr_dir     = LR_DIR,
        suffix     = SUFFIX,
        hr_size    = (HR_HEIGHT, HR_WIDTH),
        max_images = MAX_IMAGES,
        val_frac   = VAL_FRAC
    )
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)

    # 2) Model, loss and optimiser.
    model = SuperResCNN(scale=2).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 3) Training: MSE regression of the SR output against the HR target.
    for epoch in range(1, EPOCHS + 1):
        model.train()
        running_loss = 0.0
        for lr_imgs, hr_imgs in train_loader:
            lr_imgs, hr_imgs = lr_imgs.to(DEVICE), hr_imgs.to(DEVICE)
            optimizer.zero_grad()
            sr_imgs = model(lr_imgs)
            loss = criterion(sr_imgs, hr_imgs)
            loss.backward()
            optimizer.step()
            # MSELoss returns the batch mean; re-weight by batch size so the epoch
            # figure is a per-sample mean even when the last batch is short.
            running_loss += loss.item() * lr_imgs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch:02d}/{EPOCHS} — Loss: {epoch_loss:.6f}")

    # 4) PSNR/SSIM on the validation split.
    avg_psnr, avg_ssim = evaluate_dataset(model, val_loader, DEVICE)
    print(f"→ PSNR promedio: {avg_psnr:.2f} dB")
    print(f"→ SSIM promedio: {avg_ssim:.4f}")

    # 5) Persist the trained weights so the long training run is not lost when
    #    the kernel stops, then write the CNN-SR examples.
    os.makedirs(EXAMPLES_DIR, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    ckpt_path = os.path.join(output_dir, "cnn_basic.pth")
    torch.save(model.state_dict(), ckpt_path)
    print(f"Pesos guardados: {ckpt_path}")

    # Explicit, rather than relying on evaluate_dataset having switched modes.
    model.eval()
    to_tensor = transforms.ToTensor()
    for hr_fn in EXAMPLE_NAMES:
        base, ext = os.path.splitext(hr_fn)
        lr_fn = base + SUFFIX + ext
        lr_path = os.path.join(LR_DIR, lr_fn)
        if not os.path.exists(lr_path):
            # Skip rather than crash: training and evaluation are already done.
            print(f"Aviso: no existe LR '{lr_path}', se omite.")
            continue

        with Image.open(lr_path) as lr_img:
            lr_tensor = to_tensor(lr_img.convert("RGB")).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            sr_tensor = model(lr_tensor)

        sr_img   = TF.to_pil_image(sr_tensor.squeeze(0).cpu().clamp(0,1))
        save_fn  = f"{base}_cnn{ext}"
        sr_img.save(os.path.join(EXAMPLES_DIR, save_fn))
        print(f"Guardado ejemplo: {save_fn} → {EXAMPLES_DIR}")

Epoch 01/84 — Loss: 0.000138
Epoch 02/84 — Loss: 0.000097
Epoch 03/84 — Loss: 0.000091
Epoch 04/84 — Loss: 0.000087
Epoch 05/84 — Loss: 0.000086
Epoch 06/84 — Loss: 0.000085
Epoch 07/84 — Loss: 0.000084
Epoch 08/84 — Loss: 0.000083
Epoch 09/84 — Loss: 0.000083
Epoch 10/84 — Loss: 0.000083
Epoch 11/84 — Loss: 0.000082
Epoch 12/84 — Loss: 0.000082
Epoch 13/84 — Loss: 0.000081
Epoch 14/84 — Loss: 0.000081
Epoch 15/84 — Loss: 0.000081
Epoch 16/84 — Loss: 0.000081
Epoch 17/84 — Loss: 0.000080
Epoch 18/84 — Loss: 0.000080
Epoch 19/84 — Loss: 0.000080
Epoch 20/84 — Loss: 0.000080
Epoch 21/84 — Loss: 0.000080
Epoch 22/84 — Loss: 0.000079
Epoch 23/84 — Loss: 0.000079
Epoch 24/84 — Loss: 0.000079
Epoch 25/84 — Loss: 0.000079
Epoch 26/84 — Loss: 0.000079
Epoch 27/84 — Loss: 0.000078
Epoch 28/84 — Loss: 0.000078
Epoch 29/84 — Loss: 0.000078
Epoch 30/84 — Loss: 0.000078
Epoch 31/84 — Loss: 0.000078
Epoch 32/84 — Loss: 0.000078
Epoch 33/84 — Loss: 0.000078
Epoch 34/84 — Loss: 0.000078
Epoch 35/84 — 