In [18]:
import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


class CoralBleachingDataset(Dataset):
    """Map-style dataset of coral images indexed by a metadata CSV.

    Each CSV row describes one image. The image file is looked up at
    ``<root_dir>/<label>/<image name>``, i.e. one sub-directory per label value.

    Args:
        root_dir: Directory containing one sub-directory per label value.
        metadata_file: Path to a CSV containing (at least) the image, label,
            location and date columns named below.
        transform: Optional callable applied to the loaded RGB PIL image
            (e.g. a ``torchvision.transforms.Compose``).
        image_col, label_col, location_col, date_col: Column names in the CSV.

    Raises:
        KeyError: If any of the configured columns is missing from the CSV.
    """

    def __init__(self, root_dir, metadata_file, transform=None,
                 image_col='name', label_col='label', location_col='location', date_col='date'):
        self.root_dir = root_dir
        self.metadata = pd.read_csv(metadata_file)
        self.transform = transform
        self.image_col = image_col
        self.label_col = label_col
        self.location_col = location_col
        self.date_col = date_col

        # __getitem__ reads all four columns, so fail fast here rather than on the first sample.
        required = [image_col, label_col, location_col, date_col]
        missing = [col for col in required if col not in self.metadata.columns]
        if missing:
            raise KeyError(f"Metadata file {metadata_file} is missing column(s): {missing}")

    def __len__(self):
        """Number of rows in the metadata CSV."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Row position in the metadata (an int or a 0-d tensor).

        Returns:
            dict with keys 'image' (the transformed image, or the RGB PIL image
            if no transform was given), 'label', 'location' and 'date' (raw CSV values).

        Raises:
            FileNotFoundError: If the image file does not exist.
        """
        # Some samplers hand over a 0-d tensor; iloc needs a plain int.
        if torch.is_tensor(idx):
            idx = idx.item()

        # Materialize the row once instead of once per column.
        row = self.metadata.iloc[idx]
        img_name = row[self.image_col]
        label = row[self.label_col]
        location = row[self.location_col]
        date = row[self.date_col]

        # str() so numeric labels / file names read from the CSV still form a valid path.
        img_path = os.path.join(self.root_dir, str(label), str(img_name))
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # The context manager closes the file handle; convert() returns a fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'label': label,
            'location': location,
            'date': date,
        }

def split_dataset(metadata_file, test_size=0.2, random_state=42, output_dir=None, stratify_col=None):
    """Split a metadata CSV into train and validation CSV files.

    Args:
        metadata_file: Path to the full metadata CSV.
        test_size: Validation share, passed straight to ``train_test_split``
            (float fraction or int row count).
        random_state: Seed for ``train_test_split`` so the split is reproducible.
        output_dir: Directory in which to write the two CSVs. ``None`` (default)
            keeps the original behaviour: bare file names, i.e. the current
            working directory. The directory is created if needed.
        stratify_col: Optional column whose value proportions should be kept
            the same in both splits (useful for imbalanced labels). ``None``
            (default) performs a plain random split.

    Returns:
        (train_path, val_path): paths of the written CSV files.
    """
    metadata = pd.read_csv(metadata_file)
    stratify = metadata[stratify_col] if stratify_col is not None else None
    train_data, val_data = train_test_split(
        metadata, test_size=test_size, random_state=random_state, stratify=stratify)

    train_path = 'train_metadata.csv'
    val_path = 'val_metadata.csv'
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        train_path = os.path.join(output_dir, train_path)
        val_path = os.path.join(output_dir, val_path)

    train_data.to_csv(train_path, index=False)
    val_data.to_csv(val_path, index=False)
    return train_path, val_path

def collate_fn(batch):
    """Collate dataset samples into a single batch dict for a DataLoader.

    ``None`` samples are dropped. Images are combined with ``torch.stack``, so
    they must all have the same shape (e.g. after a fixed ``Resize``). Labels,
    locations and dates are kept as plain Python lists.

    Args:
        batch: List of sample dicts with keys 'image', 'label', 'location',
            'date', or ``None`` entries.

    Returns:
        dict with keys 'images' (stacked tensor, batch dimension first),
        'labels', 'locations' and 'dates'.

    Raises:
        RuntimeError: If no non-``None`` samples remain.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        # torch.stack([]) would fail with an opaque message; state the actual cause.
        raise RuntimeError("collate_fn received a batch with no valid (non-None) samples")
    return {
        'images': torch.stack([s['image'] for s in samples]),
        'labels': [s['label'] for s in samples],
        'locations': [s['location'] for s in samples],
        'dates': [s['date'] for s in samples],
    }

# Data location: metadata.csv sits at the root, images under <root_dir>/<label>/<name>.
root_dir = '../data/outputs/images/Altieri Biscayne Bay'
metadata_file = os.path.join(root_dir, 'metadata.csv')

# Standard ImageNet per-channel statistics (presumably for an ImageNet-pretrained
# backbone later on — confirm).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# A fixed 224x224 resize gives every image tensor the same shape, which collate_fn's torch.stack needs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

coral_dataset = CoralBleachingDataset(
    root_dir=root_dir,
    metadata_file=metadata_file,
    transform=transform,
)

# Smoke test: load the first sample and show its tensor shape plus metadata.
sample = coral_dataset[0]
print(sample['image'].shape, sample['label'], sample['location'], sample['date'])

torch.Size([3, 224, 224]) healthy Biscayne Bay 2022-05-01


In [19]:

# Reproducibility: one seed drives both the train/val split and the training shuffle order.
SEED = 42
BATCH_SIZE = 32

train_metadata_file, val_metadata_file = split_dataset(metadata_file, test_size=0.2, random_state=SEED)

train_dataset = CoralBleachingDataset(root_dir=root_dir, metadata_file=train_metadata_file, transform=transform)
val_dataset = CoralBleachingDataset(root_dir=root_dir, metadata_file=val_metadata_file, transform=transform)

# A seeded generator makes the shuffled batch order (and the batch printed below) repeatable across runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn,
                          generator=torch.Generator().manual_seed(SEED))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Inspect a single training batch.
batch = next(iter(train_loader))
print("Batch images shape:", batch['images'].shape)
print("Batch labels:", batch['labels'])
print("Batch locations:", batch['locations'])
print("Batch dates:", batch['dates'])

Batch images shape: torch.Size([23, 3, 224, 224])
Batch labels: ['dead', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'bleached', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'dead', 'bleached', 'dead', 'bleached', 'dead', 'bleached']
Batch locations: ['Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay', 'Biscayne Bay']
Batch dates: ['2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-01', '2022-05-0