In [2]:
import pandas as pd
import loguru
from data.dataset import Sentinel2Dataset
from data.loader import define_loaders
from utils.utils import load_config
import torch
import torch.nn as nn
from tqdm import tqdm

from model_zoo.models import define_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config = load_config(config_path="cfg/config.yaml")
BASE_DIR = config["DATASET"]["base_dir"]
VERSION = config['DATASET']['version']
BATCH_SIZE = config['TRAINING']['batch_size']
NUM_WORKERS = config['TRAINING']['num_workers']
RESIZE = config['TRAINING']['resize']
LEARNING_RATE = config['TRAINING']['learning_rate']
train_path = f"{BASE_DIR}/{VERSION}/train_path.csv"
val_path = f"{BASE_DIR}/{VERSION}/val_path.csv"
test_path = f"{BASE_DIR}/{VERSION}/test_path.csv"

In [4]:
df_train = pd.read_csv(train_path)
df_val = pd.read_csv(val_path)
df_test = pd.read_csv(test_path)

In [5]:
train_dataset = Sentinel2Dataset(df_path=df_train,
                                 train=True, augmentation=False,
                                 img_size=RESIZE)

val_dataset = Sentinel2Dataset(df_path=df_val,
                               train=False, augmentation=False,
                               img_size=RESIZE)

test_dataset = Sentinel2Dataset(df_path=df_test,
                                 train=True, augmentation=False,
                                 img_size=RESIZE)

train_loader, val_loader = define_loaders(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        train=True,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

test_loader =  define_loaders(
        train_dataset=test_dataset,
        val_dataset=None,
        train=False,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )


In [6]:
from utils.torch import load_model_weights


model = define_model(name="Unet", encoder_name="resnet34",
                     in_channel=3, out_channels=3, activation=None)

In [7]:
weights_path ="/home/ubuntu/project/sentinel-2-ai-processor/results/checkpoints/best_model.pth"
model = load_model_weights(model=model, filename=weights_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


 -> Loading encoder weights from /home/ubuntu/project/sentinel-2-ai-processor/results/checkpoints/best_model.pth



In [8]:
metrics_dict = {
    'mse': [],
}

In [None]:
model.eval()
test_loss = 0.0  # Changed variable name from val_loss to test_loss
criterion = nn.MSELoss()
with torch.no_grad():
    with tqdm(total=len(test_dataset), colour='#f4d160') as t:
        t.set_description('testing')  # Changed from 'validation' to 'testing'

        for batch_idx, (x_data, y_data) in enumerate(test_loader):
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            valid_mask = (y_data >= 0)

            # Forward pass
            outputs = model(x_data)
            loss = criterion(outputs[valid_mask], y_data[valid_mask])

            # Update statistics
            batch_loss = loss.item()
            test_loss += batch_loss

            # Update progress bar
            t.set_postfix(loss=batch_loss)
            t.update(x_data.size(0))

avg_test_loss = test_loss / len(test_loader)
metrics_dict['test_loss'] = avg_test_loss  # You might need to update this line
# or metrics_dict['mse'].append(avg_test_loss)
print(f'Test Loss: {avg_test_loss}')

In [None]:
# Validation phase

metrics_dict = {
    'val_loss': [],
    'val_psrn':[]

}
model.eval()
val_loss = 0.0
criterion = nn.MSELoss()
with torch.no_grad():
    with tqdm(total=len(val_dataset), ncols=100, colour='#f4d160') as t:
        t.set_description('validation')

        for batch_idx, (x_data, y_data) in enumerate(val_loader):
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            valid_mask = (y_data >= 0)


            # Forward pass
            outputs = model(x_data)
            loss = criterion(outputs[valid_mask], y_data[valid_mask] )


            # Update statistics
            batch_loss = loss.item()
            val_loss += batch_loss

            metric = PeakSignalNoiseRatio()
            metric.update(outputs[valid_mask], y_data[valid_mask])
            metrics_dict['val_psrn'].append(float(metric.compute().cpu().numpy()))


            # Update progress bar
            t.set_postfix(loss=batch_loss)
            t.update(x_data.size(0))

avg_val_loss = val_loss / len(val_loader)
metrics_dict['val_loss'].append(avg_val_loss)
metrics_dict['val_psrn'] = np.sum(metrics_dict['val_psrn'])/len(metrics_dict['val_psrn'])


In [None]:
metrics_dict

In [None]:
## Test PSNR


torcheval.metrics.image.psnr.PeakSignalNoiseRatio

In [None]:
metrics_dict = {
    'val_loss': [],
    'val_psrn':[],


}

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchmetrics.image import PeakSignalNoiseRatio

# Assume these are already defined:
# - model: Your model
# - val_loader: DataLoader for your validation dataset
# - device: torch.device("cuda") or torch.device("cpu")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.MSELoss()

# We first assume that the output has the shape (N, C, H, W).
# For example, if you expect 3 channels:
num_channels = 3  # or set to the number of channels in your output

# Create one metric instance per channel for both libraries and move them to the correct device.
psnr_channels = [
    PeakSignalNoiseRatio(data_range=1.0).to(device) for _ in range(num_channels)
]


total_val_loss = 0.0

model.eval()
with torch.no_grad():
    with tqdm(total=len(val_loader), ncols=100, colour='#f4d160') as t:
        t.set_description('validation')
        for batch_idx, (x_data, y_data) in enumerate(val_loader):
            # Move inputs and targets to the device
            x_data = x_data.to(device)
            y_data = y_data.to(device)

            # Forward pass
            outputs = model(x_data)  # Expected shape: (N, C, H, W)

            # Apply a valid mask if needed. Here we assume that each pixel in each channel
            # that is < 0 in the ground truth should be ignored.
            # We'll update the metrics for each channel independently.
            for c in range(num_channels):
                # Extract channel c for both outputs and targets.
                outputs_c = outputs[:, c, :, :]
                y_c = y_data[:, c, :, :]

                # Create channel-wise valid mask. This mask is True for valid pixels.
                valid_mask_c = (y_c >= 0)

                # Select only the valid pixels.
                outputs_valid_c = outputs_c[valid_mask_c]
                y_valid_c = y_c[valid_mask_c]

                # Update the metrics for this channel.
                psnr_channels[c].update(outputs_valid_c, y_valid_c)

            # (Optional) Compute loss for your purposes. Here, for simplicity, we compute the loss
            # on all channels combined using the valid mask across the whole target.
            # Adjust this if you want a channel-wise loss.
            # Here, we create a combined valid mask over all channels.
            valid_mask_all = (y_data >= 0)
            loss = criterion(outputs[valid_mask_all], y_data[valid_mask_all])
            batch_loss = loss.item()
            total_val_loss += batch_loss

            t.set_postfix(loss=batch_loss)
            t.update(x_data.size(0))

avg_val_loss = total_val_loss / len(val_loader)
print(f"Average validation loss: {avg_val_loss}")

# Compute and print final PSNR values for each channel.
for c in range(num_channels):
    final_psnr_torcheval = psnr_torcheval_channels[c].compute().item()
    final_psnr_torchmetrics = psnr_torchmetrics_channels[c].compute().item()
    print(f"Channel {c}: PSNR (torcheval): {final_psnr_torcheval:.4f}, PSNR (torchmetrics): {final_psnr_torchmetrics:.4f}")

validation: : 800it [00:20, 38.72it/s, loss=0.00171]                                                ?, ?it/s, loss=0.0128][0m

Average validation loss: 0.0059565612766891715
Channel 0: PSNR (torcheval): 22.2952, PSNR (torchmetrics): 22.2952
Channel 1: PSNR (torcheval): 22.2825, PSNR (torchmetrics): 22.2825
Channel 2: PSNR (torcheval): 22.1734, PSNR (torchmetrics): 22.1734





In [28]:
psnr_torcheval_channels

[<torcheval.metrics.image.psnr.PeakSignalNoiseRatio at 0x78a17c31e600>,
 <torcheval.metrics.image.psnr.PeakSignalNoiseRatio at 0x78a04b3d0620>,
 <torcheval.metrics.image.psnr.PeakSignalNoiseRatio at 0x78a16b835130>]