In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import rasterio
import numpy as np
import torch.nn as nn
import torch.optim as optim
from skimage.transform import resize
from tqdm import tqdm
import pandas as pd
import torchvision.transforms as transforms

In [5]:
def train_model(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: ``nn.Module`` to optimise; it is put into training mode here.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
            ``inputs.size(0)`` is used as the batch size.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor. The epoch average weights each batch loss by its
            batch size, which assumes the criterion averages over the batch
            -- TODO confirm if a non-mean reduction is ever used.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        float: ``sum(loss * batch_size) / samples_seen`` for the epoch.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for inputs, targets in tqdm(dataloader):
        inputs = inputs.float().to(device)
        targets = targets.float().to(device)

        # Replace NaNs with 0 out-of-place: `.float().to(device)` returns the
        # very same tensor when dtype/device already match, so an in-place
        # write would silently overwrite the caller's data.
        inputs = inputs.masked_fill(torch.isnan(inputs), 0)
        targets = targets.masked_fill(torch.isnan(targets), 0)

        optimizer.zero_grad()
        outputs_up = model(inputs)

        # The loss is already on the device of its operands.
        loss = criterion(outputs_up, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Normalise by the samples actually processed rather than
    # len(dataloader.dataset): the two differ when a sampler or drop_last
    # skips part of the dataset, and plain iterables have no `.dataset`.
    return running_loss / samples_seen

def validate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` and return the mean per-sample loss.

    Args:
        model: ``nn.Module`` to evaluate; it is put into eval mode here and is
            left in eval mode on return.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
            ``inputs.size(0)`` is used as the batch size.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor. Batch losses are weighted by batch size, which
            assumes a mean reduction -- TODO confirm for other reductions.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        float: ``sum(loss * batch_size) / samples_seen`` over the loader.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    samples_seen = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, targets in tqdm(dataloader):
            inputs = inputs.float().to(device)
            targets = targets.float().to(device)

            # Out-of-place NaN -> 0 so the caller's tensors are never
            # modified when the dtype/device conversion above is a no-op.
            inputs = inputs.masked_fill(torch.isnan(inputs), 0)
            targets = targets.masked_fill(torch.isnan(targets), 0)

            outputs = model(inputs)

            loss = criterion(outputs, targets)
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    # Divide by the samples actually evaluated; len(dataloader.dataset) is
    # wrong when a sampler/drop_last skips data and absent on plain iterables.
    return running_loss / samples_seen


In [6]:
class FSRCNN(nn.Module):
    """Single-channel convolutional upscaling network.

    Five convolutions operate at the input resolution, then two transposed
    convolutions enlarge the feature map: the first by a factor of 5
    (stride 5) and the second by a factor of 2 (stride 2). With the kernel,
    padding and output_padding values used here, an ``(N, 1, H, W)`` input
    therefore produces an ``(N, 1, 10*H, 10*W)`` output.
    """

    def __init__(self):
        super().__init__()

        # 5x5 convolution lifting the single input channel to 56 feature maps.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=56, kernel_size=5, padding=2)
        self.relu1 = nn.ReLU()

        # 1x1 convolution reducing 56 channels to 12.
        self.conv2 = nn.Conv2d(in_channels=56, out_channels=12, kernel_size=1)
        self.relu2 = nn.ReLU()

        # Two 1x1 convolutions working in the 12-channel space.
        self.conv21 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=1)
        self.relu21 = nn.ReLU()
        self.conv22 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=1)
        self.relu22 = nn.ReLU()

        # 1x1 convolution restoring 56 channels.
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=56, kernel_size=1)
        self.relu3 = nn.ReLU()

        # Transposed convolutions: x5 upsampling, then x2 down to one channel.
        self.deconv = nn.ConvTranspose2d(
            in_channels=56,
            out_channels=56,
            kernel_size=9,
            stride=5,
            padding=4,
            output_padding=4,
        )
        self.relu4 = nn.ReLU()
        self.deconv1 = nn.ConvTranspose2d(
            in_channels=56,
            out_channels=1,
            kernel_size=9,
            stride=2,
            padding=4,
            output_padding=1,
        )

    def forward(self, x):
        """Apply every layer in sequence; the last layer has no activation."""
        activated_stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv21, self.relu21),
            (self.conv22, self.relu22),
            (self.conv3, self.relu3),
            (self.deconv, self.relu4),
        )
        features = x
        for layer, activation in activated_stages:
            features = activation(layer(features))
        return self.deconv1(features)

In [14]:
# Training driver: builds the network, loss and optimizer, then runs the
# train/validate cycle and checkpoints whenever validation loss improves.
# NOTE(review): `device`, `dataloader_train` and `dataloader_val` are not
# defined in the cells visible here -- they must come from earlier cells.
model = FSRCNN().to(device)

learning_rate = 0.0001
# L1 (mean absolute error) loss between the network output and the target.
criterion = nn.L1Loss().to(device)
# Adam over all model parameters, with weight decay of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=1e-4)

num_epochs = 100
# Lowest validation loss seen so far; starts at +inf so the first epoch always saves.
best_loss = float('inf')


for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    # Reports the device of the first model parameter (printed every epoch).
    print(f'Model is training on: {next(model.parameters()).device}')
    train_loss = train_model(model, dataloader_train, criterion, optimizer, device)
    val_loss = validate_model(model, dataloader_val, criterion, device)
    print(f'Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}')
    
    # Keep only the best weights: each improvement overwrites the same file.
    # NOTE(review): 'path' looks like a placeholder filename -- confirm the
    # intended checkpoint location.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'path')
        print('Model saved!')


Epoch 1/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.37it/s]
100%|██████████| 23/23 [00:05<00:00,  4.17it/s]


Train Loss: 264.5793 | Validation Loss: 261.7564
Model saved!
Epoch 2/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.19it/s]
100%|██████████| 23/23 [00:05<00:00,  4.03it/s]


Train Loss: 254.1192 | Validation Loss: 241.6119
Model saved!
Epoch 3/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.13it/s]
100%|██████████| 23/23 [00:05<00:00,  4.24it/s]


Train Loss: 214.9369 | Validation Loss: 172.1268
Model saved!
Epoch 4/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.11it/s]
100%|██████████| 23/23 [00:05<00:00,  4.04it/s]


Train Loss: 102.1652 | Validation Loss: 48.9492
Model saved!
Epoch 5/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.18it/s]
100%|██████████| 23/23 [00:05<00:00,  4.17it/s]


Train Loss: 38.1238 | Validation Loss: 27.1708
Model saved!
Epoch 6/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.19it/s]
100%|██████████| 23/23 [00:05<00:00,  4.22it/s]


Train Loss: 23.6758 | Validation Loss: 20.3016
Model saved!
Epoch 7/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.12it/s]
100%|██████████| 23/23 [00:05<00:00,  4.01it/s]


Train Loss: 18.6244 | Validation Loss: 17.1023
Model saved!
Epoch 8/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.30it/s]
100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Train Loss: 16.0735 | Validation Loss: 15.0859
Model saved!
Epoch 9/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.30it/s]
100%|██████████| 23/23 [00:05<00:00,  4.12it/s]


Train Loss: 14.6218 | Validation Loss: 14.1232
Model saved!
Epoch 10/100
Model is training on: cuda:0


100%|██████████| 23/23 [00:05<00:00,  4.24it/s]
100%|██████████| 23/23 [00:05<00:00,  4.06it/s]


Train Loss: 13.7655 | Validation Loss: 13.3681
Model saved!
Epoch 11/100
Model is training on: cuda:0


 52%|█████▏    | 12/23 [00:03<00:02,  3.87it/s]


KeyboardInterrupt: 