In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import os 

In [3]:
from torch.utils.data import Dataset

class WaterAccessDataset(Dataset):
    """Sentinel-2 tile images paired with per-tile tabular features and a score.

    Each CSV row describes one tile. Item ``i`` is ``((image, tabular), label)``:

    * ``image`` -- the file ``<image_dir>/sentinel2_<tile_id>.png`` opened as
      RGB, passed through ``transform`` when one is given (a PIL image otherwise);
    * ``tabular`` -- 1-D ``float32`` tensor of every CSV column not listed in
      ``exclude_columns``, in CSV column order (see ``feature_columns``);
    * ``label`` -- 0-d ``float32`` tensor holding the row's ``score``.

    Args:
        csv_path: Path to the per-tile CSV. It must contain ``tile_id``,
            ``score`` and every column named in ``exclude_columns``.
        image_dir: Directory holding the tile images.
        transform: Optional callable applied to the RGB PIL image.
        exclude_columns: Columns removed before building ``tabular``.
            Defaults to ``DEFAULT_EXCLUDE_COLUMNS``.
        image_name_template: ``str.format`` template for the image file name;
            it receives ``tile_id`` as a keyword.

    Raises:
        KeyError: at construction, if ``tile_id``/``score`` or an excluded
            column is missing from the CSV.
        ValueError: at construction, if a remaining feature column cannot be
            cast to float.
    """

    # Non-feature columns: the tile identifier, the target ``score`` and the
    # other columns kept out of the model's tabular input.
    DEFAULT_EXCLUDE_COLUMNS = (
        'tile_id', 'score', 'system:index', 'category_bonus',
        'distance_weighted_score', 'norm_distance_weighted',
        'num_sources', 'pressure_score', 'random', 'water_point_population',
        'water_source_category', '.geo',
    )

    def __init__(self, csv_path, image_dir, transform=None,
                 exclude_columns=None,
                 image_name_template="sentinel2_{tile_id}.png"):
        self.data = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform
        self.image_name_template = image_name_template

        if exclude_columns is None:
            exclude_columns = self.DEFAULT_EXCLUDE_COLUMNS

        # Build everything __getitem__ needs once, instead of slicing a pandas
        # row (Series.drop + astype) on every item. This is much cheaper in a
        # DataLoader loop, and a missing or non-numeric column now fails here
        # rather than mid-epoch.
        features = self.data.drop(columns=list(exclude_columns))
        self.feature_columns = list(features.columns)
        self._features = features.to_numpy(dtype='float32')
        self._labels = self.data['score'].to_numpy(dtype='float32')
        self._tile_ids = self.data['tile_id'].tolist()

    def __len__(self):
        """Number of tiles (rows in the CSV)."""
        return len(self.data)

    @staticmethod
    def _load_image(path):
        """Open ``path`` as an RGB PIL image and release the file handle."""
        # convert() returns a new, fully loaded image, so the file can be
        # closed as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``((image, tabular), label)`` for row ``index``.

        Raises IndexError for an out-of-range index, as ``iloc`` did.
        """
        tile_id = self._tile_ids[index]

        img_path = os.path.join(
            self.image_dir, self.image_name_template.format(tile_id=tile_id))
        image = self._load_image(img_path)
        if self.transform is not None:
            image = self.transform(image)

        # torch.tensor copies, so callers never share storage with the
        # precomputed arrays.
        tab = torch.tensor(self._features[index])
        label = torch.tensor(self._labels[index], dtype=torch.float32)

        return (image, tab), label

In [4]:
# --- image transformations ---
# Resize every tile to a fixed input size and normalise with the ImageNet
# channel statistics that torchvision's pretrained weights expect.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [7]:
# Full dataset: one item per row of the tile CSV, images resized/normalised
# by `transform`.
dataset = WaterAccessDataset(
    csv_path='tile_features.csv',
    image_dir='earth_engine/converted_png',
    transform=transform,
)

In [8]:
# --- test sample ---
# Fetch one item and sanity-check the contract the model will rely on.
sample = dataset[0]
(image, tabular), label = sample

print("image shape:", image.shape)
print("tabular shape:", tabular.shape)
print("label (the score):", label)

# Fail loudly on a fresh Run-All if the item layout ever drifts.
assert image.shape == (3, 224, 224), image.shape  # RGB after the 224x224 resize
assert tabular.dtype == torch.float32 and tabular.dim() == 1, tabular.shape
assert label.dtype == torch.float32 and label.dim() == 0, label

image shape: torch.Size([3, 224, 224])
tabular shape: torch.Size([6])
label (the score): tensor(1.1505)
