In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import rasterio
import numpy as np
import torch.nn as nn
import torch.optim as optim
from skimage.transform import resize
from tqdm import tqdm
import pandas as pd
import torchvision.transforms as transforms

In [5]:
def train_model(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    NaN values in both the inputs and the targets are replaced with 0 before
    the forward pass.

    Args:
        model: ``nn.Module`` to optimise; it is switched to train mode.
        dataloader: iterable of ``(inputs, targets)`` batches whose first
            dimension is the batch size.
        criterion: loss callable; assumed to return a batch-mean scalar
            (e.g. ``nn.L1Loss()`` with the default ``reduction='mean'``).
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        float: sum of ``batch_loss * batch_size`` divided by the number of
        samples actually seen this epoch, or ``nan`` if the loader yielded
        no samples.
    """
    model.train()
    running_loss = 0.0
    n_samples = 0
    for inputs, targets in tqdm(dataloader):
        inputs = inputs.float().to(device)
        targets = targets.float().to(device)

        # Out-of-place NaN -> 0 (infinities are left untouched). `.float()` /
        # `.to()` can return the very tensor the loader yielded, so an
        # in-place write would mutate it.
        # NOTE(review): zero-filling NaN targets trains the model to predict 0
        # in no-data regions; a masked loss may be preferable — confirm intent.
        inputs = inputs.masked_fill(torch.isnan(inputs), 0.0)
        targets = targets.masked_fill(torch.isnan(targets), 0.0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # Weight the batch-mean loss by batch size so a short last batch
        # contributes proportionally.
        running_loss += loss.item() * batch_size
        n_samples += batch_size

    # Divide by the samples actually seen, not len(dataloader.dataset): the two
    # differ with drop_last=True or a subsampling sampler.
    if n_samples == 0:
        return float("nan")
    return running_loss / n_samples

def validate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` and return the mean per-sample loss.

    Runs in eval mode under ``torch.no_grad()``. NaN values in inputs and
    targets are replaced with 0, matching ``train_model``.

    Args:
        model: ``nn.Module`` to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches whose first
            dimension is the batch size.
        criterion: loss callable; assumed to return a batch-mean scalar.
        device: device every batch is moved to.

    Returns:
        float: sum of ``batch_loss * batch_size`` divided by the number of
        samples actually seen, or ``nan`` if the loader yielded no samples.
    """
    model.eval()
    running_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader):
            inputs = inputs.float().to(device)
            targets = targets.float().to(device)
            # Out-of-place NaN -> 0 so the loader's tensors are never mutated.
            inputs = inputs.masked_fill(torch.isnan(inputs), 0.0)
            targets = targets.masked_fill(torch.isnan(targets), 0.0)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

    # Divide by the samples actually seen (robust to drop_last / samplers);
    # NaN for an empty loader so a caller's `val_loss < best_loss` is False.
    if n_samples == 0:
        return float("nan")
    return running_loss / n_samples


In [6]:
import torch
import torch.nn as nn

class IEEB(nn.Module):
    """Information Extraction and Enhancement Block.

    Five 3x3 convolutions (padding 1, so height/width are preserved). A single
    shared ``nn.PReLU`` follows each of the first four; the fifth is linear and
    its output is added to a skip connection from the block input.

    When ``in_channels == out_channels`` the skip is a parameter-free identity,
    so the ``state_dict`` keys (``conv1``..``conv5``, ``act``) and weight
    initialisation order are unchanged. Otherwise the skip is a 1x1 convolution
    that projects the input to ``out_channels``; without it the residual
    addition would fail on mismatched channel counts.
    """

    def __init__(self, in_channels, out_channels):
        super(IEEB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act = nn.PReLU()
        # Registered last so the convs/PReLU above are initialised exactly as before.
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        out1 = self.act(self.conv1(x))
        out2 = self.act(self.conv2(out1))
        out3 = self.act(self.conv3(out2))
        out4 = self.act(self.conv4(out3))
        out5 = self.conv5(out4)
        return out5 + self.skip(x)

class RB(nn.Module):
    """Reconstruction Block: conv -> PReLU -> conv, plus an identity skip.

    Both convolutions are 3x3 with padding 1 and map ``channels`` to
    ``channels``, so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_cfg)
        self.conv2 = nn.Conv2d(channels, channels, **conv_cfg)
        self.act = nn.PReLU()

    def forward(self, x):
        residual = x
        hidden = self.act(self.conv1(x))
        refined = self.conv2(hidden)
        return refined + residual

class IRB(nn.Module):
    """Information Refinement Block: three 3x3 convolutions with an identity skip.

    A single shared ``nn.PReLU`` follows the first two convolutions; the third
    is linear before the residual addition. Padding 1 keeps the spatial size,
    and every conv maps ``channels`` to ``channels``.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.conv1 = conv3x3()
        self.conv2 = conv3x3()
        self.conv3 = conv3x3()
        self.act = nn.PReLU()

    def forward(self, x):
        h = x
        for conv in (self.conv1, self.conv2):
            h = self.act(conv(h))
        return self.conv3(h) + x

class LESRCNN(nn.Module):
    """Lightweight single-image super-resolution network.

    Pipeline: 3x3 conv + PReLU feature extraction -> IEEB -> RB -> IRB (all at
    input resolution) -> sub-pixel upsampling (conv to
    ``num_channels * scale_factor**2`` channels, ``PixelShuffle``, PReLU) ->
    final 3x3 conv to ``out_channels``. ``PixelShuffle(scale_factor)``
    multiplies height and width by ``scale_factor``.
    """

    def __init__(self, in_channels=3, out_channels=3, num_channels=64, scale_factor=2):
        super().__init__()
        # NOTE: submodule names and registration order are part of the
        # state_dict / init contract — keep them stable.

        # Initial feature extraction.
        self.init_feature = nn.Sequential(
            nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1),
            nn.PReLU(),
        )
        # Information extraction and enhancement block.
        self.ieeb = IEEB(num_channels, num_channels)
        # Reconstruction block.
        self.rb = RB(num_channels)
        # Information refinement block.
        self.irb = IRB(num_channels)
        # Upsampling block: expand channels by scale_factor**2, then pixel-shuffle.
        expanded_channels = num_channels * scale_factor ** 2
        self.upsampler = nn.Sequential(
            nn.Conv2d(num_channels, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
        )
        # Final reconstruction to the requested number of output channels.
        self.final_conv = nn.Conv2d(num_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        features = self.init_feature(x)
        # Low-resolution residual stages, applied in order.
        for stage in (self.ieeb, self.rb, self.irb):
            features = stage(features)
        # Upsample, then project to out_channels.
        upsampled = self.upsampler(features)
        return self.final_conv(upsampled)

Input shape: torch.Size([1, 1, 40, 52])
Output shape: torch.Size([1, 1, 400, 520])
Number of parameters: 4063302


In [None]:
# Training driver: build LESRCNN (1-channel in/out, x10 upscaling), train with
# L1 loss + Adam, and save the weights whenever validation loss improves.
# NOTE(review): `dataloader_train` and `dataloader_val` are not created in any
# visible cell; this cell relies on them already existing in the kernel.
SEED = 42
CHECKPOINT_PATH = 'pathpath'  # TODO: placeholder — replace with a real checkpoint file path

# Seed torch's global RNG so weight initialisation is reproducible across runs.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = LESRCNN(in_channels=1, out_channels=1, num_channels=64, scale_factor=10).to(device)
# The device does not change between epochs, so report it once.
print(f'Model is training on: {next(model.parameters()).device}')

learning_rate = 0.0001
criterion = nn.L1Loss()  # parameter-free; nothing to move to `device`
# NOTE(review): Adam's `weight_decay` is coupled L2 regularisation; use
# optim.AdamW if decoupled weight decay was intended — confirm.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

num_epochs = 100
best_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train_loss = train_model(model, dataloader_train, criterion, optimizer, device)
    val_loss = validate_model(model, dataloader_val, criterion, device)
    print(f'Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}')

    # Strict `<` also means a NaN validation loss never overwrites the checkpoint.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print('Model saved!')
