In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split

# Custom Dataset that loads 4-channel satellite images (RGB + vegetation layer stored in the same file)
class PovertyDataset(Dataset):
    """
    Dataset of satellite images paired with scalar poverty targets.

    Every image is opened with PIL and converted to "RGBA", so each sample is a
    float32 array of shape (H, W, 4) before `transform` is applied. The fourth
    band is meant to hold the vegetation layer; this assumes the file already
    stores it as its fourth channel (a separately stored layer would have to be
    loaded and stacked by the caller).

    Args:
        images_filenames: Sequence of image file paths (list, array, Series, ...).
        targets: Sequence of poverty values (floats), one per image.
        transform: Optional callable applied to the (H, W, 4) float32 array.

    Raises:
        ValueError: If `images_filenames` and `targets` differ in length.
    """
    def __init__(self, images_filenames, targets, transform=None):
        # Store both inputs in purely positional containers. A pandas Series
        # returned by train_test_split keeps its original (shuffled) index, so
        # indexing it with `[idx]` would be a label lookup and pick the wrong
        # row or raise KeyError.
        self.images_filenames = list(images_filenames)
        self.targets = np.asarray(targets, dtype=np.float32)

        if len(self.images_filenames) != len(self.targets):
            raise ValueError(
                f"Got {len(self.images_filenames)} image files but "
                f"{len(self.targets)} targets; they must have the same length."
            )

        self.transform = transform

    def __len__(self):
        """Return the number of (image, target) pairs."""
        return len(self.images_filenames)

    def __getitem__(self, idx):
        """Load sample `idx` and return `(image, target)` with a float32 target."""
        img_path = self.images_filenames[idx]
        with Image.open(img_path) as img:
            # Force exactly 4 bands; the pixel data is copied into a numpy
            # array before the file handle is closed.
            img = img.convert("RGBA")
            img_np = np.array(img, dtype=np.float32)

        target = self.targets[idx]

        # Compare against None explicitly: a callable that defines __len__
        # (e.g. an empty container of transforms) may be falsy.
        if self.transform is not None:
            img_np = self.transform(img_np)

        return img_np, np.float32(target)


# Define Transforms
class ToTensor:
    """Turn an (H, W, C) numpy array into a PyTorch tensor laid out as (C, H, W)."""

    # Axis order that moves the trailing channel axis to the front.
    _CHW_ORDER = (2, 0, 1)

    def __call__(self, sample):
        hwc_tensor = torch.from_numpy(sample)
        return hwc_tensor.permute(*self._CHW_ORDER)

class MinMaxScale:
    """
    Linearly rescale values from [min_val, max_val] to approximately [0, 1].

    The same pair of bounds is applied to every element of the input, so set
    min_val/max_val to the value range of your data.
    """
    def __init__(self, min_val=0.0, max_val=255.0):
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, sample_tensor):
        # The small epsilon keeps the divisor non-zero when min_val == max_val.
        span = self.max_val - self.min_val + 1e-8
        shifted = sample_tensor - self.min_val
        return shifted / span


# Preprocessing pipeline: (H, W, C) array -> (C, H, W) tensor -> values rescaled from [0, 255] to [0, 1]
preprocessing_steps = [
    ToTensor(),
    MinMaxScale(min_val=0, max_val=255),
]
transform = T.Compose(preprocessing_steps)


# Train/Test Split + Dataloaders
# Tunable settings for the split and the loaders
TEST_SIZE = 0.2   # fraction of samples held out for evaluation
BATCH_SIZE = 8
SEED = 42         # shared by the split and the training shuffle

images_filenames = [
    # "path/to/image1.png",
    # "path/to/image2.png",
    # ...
]
targets = [
    # poverty_value1,
    # poverty_value2,
    # ...
]

# Fail early with an actionable message: with the placeholders above left
# empty, train_test_split would otherwise raise a much less obvious error.
if len(images_filenames) == 0:
    raise ValueError(
        "`images_filenames` is empty: fill in the image paths and their "
        "poverty targets before creating the train/test split."
    )

# Split data (seeded, so the same files land in each split on every run)
train_files, test_files, train_targets, test_targets = train_test_split(
    images_filenames, targets, test_size=TEST_SIZE, random_state=SEED
)

# Create Datasets
train_dataset = PovertyDataset(train_files, train_targets, transform=transform)
test_dataset  = PovertyDataset(test_files,  test_targets,  transform=transform)

# Create DataLoaders
# The training shuffle draws from its own seeded generator so the batch order
# is reproducible across kernel restarts, consistent with the seeded split.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Model definition
# ResNet backbone with an MLP regression head
class ResNetRegressor(nn.Module):
    """
    ResNet18 backbone followed by a small MLP that predicts one scalar per image.

    The backbone's first convolution is rebuilt so it accepts `in_channels`
    input bands. Weights for the channels it shares with the original conv are
    copied over; any extra channels are Xavier-initialised. The backbone's
    classifier is replaced by an identity so the MLP head receives the pooled
    features directly.

    Args:
        pretrained: If truthy, load pretrained weights for the backbone.
        in_channels: Number of input image channels (default 4).

    Raises:
        ValueError: If `in_channels` is smaller than 1.
    """
    def __init__(self, pretrained=True, in_channels=4):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")

        self.backbone = self._build_backbone(pretrained)

        # Rebuild the first conv with the requested number of input channels,
        # reusing the original layer's geometry instead of hard-coding it.
        old_conv = self.backbone.conv1
        new_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False,
        )

        # Edit the weights in place under no_grad rather than through `.data`,
        # which bypasses autograd's bookkeeping.
        shared = min(in_channels, old_conv.in_channels)
        with torch.no_grad():
            # Keep the existing filters for the channels both convs have.
            new_conv.weight[:, :shared] = old_conv.weight[:, :shared]
            # Channels without a counterpart in the original conv start from a
            # Xavier-normal initialisation.
            if in_channels > shared:
                nn.init.xavier_normal_(new_conv.weight[:, shared:])
        self.backbone.conv1 = new_conv

        # Read the feature width before dropping the classifier so the head
        # always matches the backbone (512 for ResNet18).
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # A small MLP mapping the pooled features to a single scalar
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)  # single output for regression
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Create a ResNet18, using the `weights=` API when torchvision has it."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` refers to.
        return models.resnet18(
            weights=weights_enum.IMAGENET1K_V1 if pretrained else None
        )

    def forward(self, x):
        """Return predictions of shape [batch_size, 1] for a batch of images `x`."""
        # Extract features via the CNN backbone (fc is an identity)
        features = self.backbone(x)
        # Pass features through the MLP
        return self.mlp(features)

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network and move its parameters to the chosen device
model = ResNetRegressor(pretrained=True).to(device)

# Mean-squared-error objective, optimised with Adam
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


# Training Loop
num_epochs = 5
train_losses, test_losses = [], []


def run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one full pass over `loader` and return the mean per-sample loss.

    Args:
        model: Network to train or evaluate.
        loader: DataLoader yielding `(images, targets)` batches.
        criterion: Loss function applied to `(outputs, targets)`.
        device: Device the batches are moved to.
        optimizer: If given, the model is put in train mode and updated after
            every batch; if None, the model is put in eval mode and gradients
            are disabled.

    Returns:
        Sum of batch losses weighted by batch size, divided by the dataset size.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    running_loss = 0.0
    with torch.set_grad_enabled(training):
        # Batch names deliberately differ from the notebook-level `targets`
        # list so the loop does not overwrite it.
        for batch_images, batch_targets in loader:
            batch_images = batch_images.to(device)
            # Reshape to [batch, 1] to match the model's output shape.
            batch_targets = batch_targets.to(device).view(-1, 1)

            if training:
                optimizer.zero_grad()
            outputs = model(batch_images)
            loss = criterion(outputs, batch_targets)
            if training:
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so a smaller final batch is averaged correctly.
            running_loss += loss.item() * batch_images.size(0)

    return running_loss / len(loader.dataset)


for epoch in range(num_epochs):
    # ---- TRAIN ----
    epoch_train_loss = run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(epoch_train_loss)

    # ---- EVAL ----
    epoch_test_loss = run_epoch(model, test_loader, criterion, device)
    test_losses.append(epoch_test_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {epoch_train_loss:.4f} | Test Loss: {epoch_test_loss:.4f}")



# Plot the per-epoch train and test loss curves on a single set of axes
epoch_ticks = range(1, num_epochs + 1)
fig, ax = plt.subplots()
ax.plot(epoch_ticks, train_losses, label='Train Loss')
ax.plot(epoch_ticks, test_losses, label='Test Loss')
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.set_title("Training vs. Test Loss")
ax.legend()
plt.show()