In [4]:
import pandas as pd
import loguru
from data.dataset import Sentinel2Dataset
from data.loader import define_loaders
from utils.utils import load_config
import torch
import torch.nn as nn
from tqdm import tqdm

from model_zoo.models import define_model

In [6]:
config = load_config(config_path="cfg/config.yaml")
BASE_DIR = config["DATASET"]["base_dir"]
VERSION = config['DATASET']['version']
BATCH_SIZE = config['TRAINING']['batch_size']
NUM_WORKERS = config['TRAINING']['num_workers']
RESIZE = config['TRAINING']['resize']
LEARNING_RATE = config['TRAINING']['learning_rate']
train_path = f"{BASE_DIR}/{VERSION}/train_path.csv"
val_path = f"{BASE_DIR}/{VERSION}/val_path.csv"
test_path = f"{BASE_DIR}/{VERSION}/test_path.csv"

In [7]:
df_train = pd.read_csv(train_path)
df_val = pd.read_csv(val_path)
df_test = pd.read_csv(test_path)

In [11]:
train_dataset = Sentinel2Dataset(df_path=df_train,
                                 train=True, augmentation=False,
                                 img_size=RESIZE)

val_dataset = Sentinel2Dataset(df_path=df_val,
                               train=False, augmentation=False,
                               img_size=RESIZE)

test_dataset = Sentinel2Dataset(df_path=df_test,
                                 train=True, augmentation=False,
                                 img_size=RESIZE)

train_loader, val_loader = define_loaders(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        train=True,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

test_loader =  define_loaders(
        train_dataset=test_dataset,
        val_dataset=None,
        train=False,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )


In [13]:
from utils.torch import load_model_weights


model = define_model(name="Unet", encoder_name="resnet34",
                     in_channel=3, out_channels=3, activation=None)

In [17]:
weights_path ="/home/ubuntu/project/sentinel-2-ai-processor/src/checkpoints/best_model.pth"

model = load_model_weights(model=model, filename=weights_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


 -> Loading encoder weights from /home/ubuntu/project/sentinel-2-ai-processor/src/checkpoints/best_model.pth



In [25]:
len(test_loader)

12

In [27]:
metrics_dict = {
    'mse': [],
}

In [None]:
model.eval()
test_loss = 0.0  # Changed variable name from val_loss to test_loss
criterion = nn.MSELoss()
with torch.no_grad():
    with tqdm(total=len(test_dataset), colour='#f4d160') as t:
        t.set_description('testing')  # Changed from 'validation' to 'testing'

        for batch_idx, (x_data, y_data) in enumerate(test_loader):
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            valid_mask = (y_data >= 0)

            # Forward pass
            outputs = model(x_data)
            loss = criterion(outputs[valid_mask], y_data[valid_mask])

            # Update statistics
            batch_loss = loss.item()
            test_loss += batch_loss

            # Update progress bar
            t.set_postfix(loss=batch_loss)
            t.update(x_data.size(0))

avg_test_loss = test_loss / len(test_loader)
metrics_dict['test_loss'] = avg_test_loss  # You might need to update this line
# or metrics_dict['mse'].append(avg_test_loss)
print(f'Test Loss: {avg_test_loss}')


testing:  96%|[38;2;244;209;96m█████████▌[0m| 192/200 [00:07<00:00, 24.02it/s, loss=0.0128] 

Test Loss: 0.007088003951745729





In [None]:
# # Save metrics
# import matplotlib.pyplot as plt
# import numpy as np

# # Plot loss curves
# plt.figure(figsize=(10, 6))
# plt.plot(metrics_dict['train_loss'], label='Train Loss')
# plt.plot(metrics_dict['val_loss'], label='Validation Loss')
# plt.xlabel('Epochs')
# plt.ylabel('Loss (MSE)')
# plt.title('Training and Validation Loss')
# plt.legend()
# plt.savefig(f"{save_path}/loss_curves.png")
# logger.info(f"Loss curves saved to {save_path}/loss_curves.png")

# # Optionally, test the model on a few validation samples
# model.eval()
# with torch.no_grad():
#     for i, (x_data, y_data) in enumerate(val_loader):
#         if i >= 12:
#             break

#         x_data = x_data.to(device)
#         y_data = y_data.to(device)

#         output = model(x_data)

#         # Convert to numpy for visualization
#         x_np = x_data.cpu().numpy()[0].transpose(1, 2, 0)  # First image in batch, CHW to HWC
#         y_np = y_data.cpu().numpy()[0].transpose(1, 2, 0)
#         pred_np = output.cpu().numpy()[0].transpose(1, 2, 0)
#         # Clip values to valid range for visualization
#         x_np = np.clip(x_np, 0, 1)
#         y_np = np.clip(y_np, 0, 1)
#         pred_np = np.clip(pred_np, 0, 1)

#         # Plot and save
#         fig, axs = plt.subplots(1, 3, figsize=(15, 5))
#         axs[0].imshow(x_np)
#         axs[0].set_title('L1C Input')
#         axs[1].imshow(pred_np)
#         axs[1].set_title('Model Output')
#         axs[2].imshow(y_np)
#         axs[2].set_title('L2A Ground Truth')

#         plt.savefig(f"{save_path}/sample_{i}_prediction.png")
#         plt.close()

# logger.info("Testing completed. Sample predictions saved.")
