In [3]:
import os

import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision import transforms


class CoralBleachingDataset(Dataset):
    """Image dataset for coral bleaching classification driven by a metadata CSV.

    Each CSV row describes one image, which is expected at
    ``<root_dir>/<label>/<image name>``; the label value doubles as the
    sub-directory name.

    Args:
        root_dir: Directory containing one sub-directory per label value.
        metadata_file: Path to the CSV with one row per image.
        transform: Optional callable applied to the loaded RGB PIL image.
        image_col: Column holding the image file name.
        label_col: Column holding the class label (also the sub-directory name).
        location_col: Column holding the location.
        date_col: Column holding the date.
        watch_loca: Name of a Coral Reef Watch location column. Stored on the
            instance but not read by ``__getitem__``.
        tempe: Column whose value is returned under the ``'SST@90th_HS'`` key.
    """

    def __init__(self, root_dir, metadata_file, transform=None,
                 image_col='name', label_col='label', location_col='location', date_col='date',
                 watch_loca='CoralReefWatch location', tempe='SST@90th_HS'):
        self.root_dir = root_dir
        self.metadata = pd.read_csv(metadata_file)
        self.transform = transform
        self.image_col = image_col
        self.label_col = label_col
        self.location_col = location_col
        self.watch_loca = watch_loca
        self.tempe = tempe
        self.date_col = date_col

    def __len__(self):
        """Number of rows in the metadata CSV."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image and metadata for row ``idx``.

        Returns:
            dict with keys ``'image'``, ``'label'``, ``'location'``, ``'date'``
            and ``'SST@90th_HS'``.

        Raises:
            FileNotFoundError: if the image file for the row does not exist.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.metadata.iloc[idx]
        img_name = row[self.image_col]
        label = row[self.label_col]
        location = row[self.location_col]
        date = row[self.date_col]
        tempe = row[self.tempe]

        # str() so numeric labels / file names (as pandas may parse them)
        # still form a valid path instead of raising TypeError.
        img_path = os.path.join(self.root_dir, str(label), str(img_name))
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # The context manager guarantees the file handle is released;
        # convert() returns a new, fully loaded image that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'label': label,
            'location': location,
            'date': date,
            'SST@90th_HS': tempe
        }

def split_dataset(metadata_file, test_size=0.2, random_state=42, stratify_col=None, output_dir=None):
    """Split a metadata CSV into train/validation CSV files.

    Args:
        metadata_file: Path to the full metadata CSV.
        test_size: Validation share (fraction or absolute count), as accepted
            by ``sklearn.model_selection.train_test_split``.
        random_state: Seed so the split is reproducible.
        stratify_col: Optional column name (e.g. ``'label'``) to stratify on so
            both splits keep the same class proportions. ``None`` (default)
            performs the original unstratified split.
        output_dir: Directory to write the CSVs into (created if missing).
            ``None`` (default) writes bare file names, i.e. into the current
            working directory, as before.

    Returns:
        Tuple ``(train_csv_path, val_csv_path)``.
    """
    metadata = pd.read_csv(metadata_file)
    stratify = metadata[stratify_col] if stratify_col is not None else None
    train_data, val_data = train_test_split(
        metadata, test_size=test_size, random_state=random_state, stratify=stratify
    )

    train_path, val_path = 'train_metadata.csv', 'val_metadata.csv'
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        train_path = os.path.join(output_dir, train_path)
        val_path = os.path.join(output_dir, val_path)

    train_data.to_csv(train_path, index=False)
    val_data.to_csv(val_path, index=False)
    return train_path, val_path

def collate_fn(batch):
    """Merge a list of dataset samples into one batch dict.

    ``None`` entries are dropped first. Images are stacked into a single
    tensor; every other field is gathered into a list in batch order.
    """
    samples = list(filter(lambda s: s is not None, batch))
    images = torch.stack([s['image'] for s in samples])
    gathered = {
        out_key: [s[in_key] for s in samples]
        for out_key, in_key in (
            ('labels', 'label'),
            ('locations', 'location'),
            ('dates', 'date'),
            ('SST@90th_HS', 'SST@90th_HS'),
        )
    }
    return {'images': images, **gathered}

# Set the path
root_dir = '../data/outputs/images/Curacao Coral Reef Assessment 2023 CUR'
metadata_file = os.path.join(root_dir, 'metadata.csv')

# Define image transformations: resize to 224x224 and normalise each channel
# with the commonly used ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Instantiate the full dataset and sanity-check one sample.
coral_dataset = CoralBleachingDataset(root_dir=root_dir, metadata_file=metadata_file, transform=transform)

sample = coral_dataset[0]
print(sample['image'].shape, sample['label'], sample['location'], sample['date'])

# 80/20 train/validation split; split_dataset writes the two CSVs and returns their paths.
train_metadata_file, val_metadata_file = split_dataset(metadata_file, test_size=0.2)

# Instantiate the split datasets
train_dataset = CoralBleachingDataset(root_dir=root_dir, metadata_file=train_metadata_file, transform=transform)
val_dataset = CoralBleachingDataset(root_dir=root_dir, metadata_file=val_metadata_file, transform=transform)

# Seeded generator so the training shuffle order is reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Peek at a single training batch.
batch = next(iter(train_loader))
print("Batch images shape:", batch['images'].shape)
print("Batch labels:", batch['labels'])
print("Batch locations:", batch['locations'])
print("Batch dates:", batch['dates'])
print("Batch SST@90th_HS:", batch['SST@90th_HS'])

torch.Size([3, 224, 224]) healthy Curacao 2023-10-26
Batch images shape: torch.Size([32, 3, 224, 224])
Batch labels: ['bleached', 'healthy', 'bleached', 'healthy', 'healthy', 'healthy', 'bleached', 'healthy', 'bleached', 'bleached', 'healthy', 'healthy', 'healthy', 'bleached', 'bleached', 'healthy', 'healthy', 'healthy', 'healthy', 'bleached', 'healthy', 'healthy', 'healthy', 'healthy', 'bleached', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy', 'healthy']
Batch locations: ['Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao', 'Curacao']
Batch dates: ['2023-10-27', '2023-10-23', '2023-10-29', '2023-10-23', '2023-10-27', '2023-10-17', '2023-10-21', '2023-10-23', '2023-10-27', '2023-10-27

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleCNN(nn.Module):
    """Small three-block CNN classifier for square RGB images.

    Each block is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    side is divided by 8 before the two fully connected layers.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input images. Defaults to 224,
            matching the ``Resize((224, 224))`` transform used for the data.
    """

    def __init__(self, num_classes=2, image_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Three 2x2 poolings (each flooring) shrink the side by 8: 224 -> 28.
        feature_side = image_size // 8
        self.fc1 = nn.Linear(64 * feature_side * feature_side, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch dimension, so a
        # wrong-sized input fails loudly in fc1 instead of being silently
        # folded into a different batch size (as view(-1, ...) would do).
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate the model, define the loss function and optimiser.
# Seed torch's RNG first so the random weight initialisation is reproducible.
torch.manual_seed(42)
model = SimpleCNN(num_classes=2)
# NOTE(review): in the recorded run, validation accuracy stays at ~70.9%, which
# equals the 'healthy' share of the validation set (217/306 in the confusion
# matrix). The model barely beats always predicting 'healthy'. Consider a
# class-weighted loss (CrossEntropyLoss(weight=...)) or a stratified split.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [7]:
from tqdm import tqdm

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def encode_labels(label_names):
    """Map label strings to class indices on `device`: 'healthy' -> 1, anything else -> 0."""
    return torch.tensor([1 if name == 'healthy' else 0 for name in label_names],
                        dtype=torch.long, device=device)


# NOTE: this cell keeps training whatever `model` / `optimizer` currently hold.
# Re-running it without re-running the model cell resumes from the trained
# weights instead of starting fresh.
for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}] - Training started")
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        images = batch['images'].to(device)
        labels = encode_labels(batch['labels'])

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean per-batch loss; max(..., 1) guards against an empty loader.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Training loss: {epoch_loss:.4f}")

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['images'].to(device)
            labels = encode_labels(batch['labels'])
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # NaN rather than ZeroDivisionError when the validation loader is empty.
    val_accuracy = 100 * correct / total if total else float('nan')
    print(f"Epoch [{epoch+1}/{num_epochs}] - Validation Accuracy: {val_accuracy:.2f}%")

Training Epoch 1/10: 100%|██████████| 39/39 [03:19<00:00,  5.12s/it]


Epoch [1/10] - Training loss: 0.6106


Validation: 100%|██████████| 10/10 [00:45<00:00,  4.53s/it]


Epoch [1/10] - Validation Accuracy: 70.92%
Epoch [2/10] - Training started


Training Epoch 2/10: 100%|██████████| 39/39 [03:23<00:00,  5.23s/it]


Epoch [2/10] - Training loss: 0.6015


Validation: 100%|██████████| 10/10 [00:49<00:00,  4.91s/it]


Epoch [2/10] - Validation Accuracy: 70.92%
Epoch [3/10] - Training started


Training Epoch 3/10: 100%|██████████| 39/39 [03:22<00:00,  5.19s/it]


Epoch [3/10] - Training loss: 0.5811


Validation: 100%|██████████| 10/10 [00:48<00:00,  4.84s/it]


Epoch [3/10] - Validation Accuracy: 70.59%
Epoch [4/10] - Training started


Training Epoch 4/10: 100%|██████████| 39/39 [03:22<00:00,  5.20s/it]


Epoch [4/10] - Training loss: 0.5403


Validation: 100%|██████████| 10/10 [00:48<00:00,  4.84s/it]


Epoch [4/10] - Validation Accuracy: 70.92%
Epoch [5/10] - Training started


Training Epoch 5/10: 100%|██████████| 39/39 [03:30<00:00,  5.39s/it]


Epoch [5/10] - Training loss: 0.5011


Validation: 100%|██████████| 10/10 [00:54<00:00,  5.46s/it]


Epoch [5/10] - Validation Accuracy: 70.92%
Epoch [6/10] - Training started


Training Epoch 6/10: 100%|██████████| 39/39 [03:47<00:00,  5.84s/it]


Epoch [6/10] - Training loss: 0.3613


Validation: 100%|██████████| 10/10 [00:54<00:00,  5.41s/it]


Epoch [6/10] - Validation Accuracy: 70.92%
Epoch [7/10] - Training started


Training Epoch 7/10: 100%|██████████| 39/39 [03:42<00:00,  5.69s/it]


Epoch [7/10] - Training loss: 0.2598


Validation: 100%|██████████| 10/10 [00:55<00:00,  5.53s/it]


Epoch [7/10] - Validation Accuracy: 69.61%
Epoch [8/10] - Training started


Training Epoch 8/10: 100%|██████████| 39/39 [03:45<00:00,  5.79s/it]


Epoch [8/10] - Training loss: 0.1419


Validation: 100%|██████████| 10/10 [00:49<00:00,  4.94s/it]


Epoch [8/10] - Validation Accuracy: 71.57%
Epoch [9/10] - Training started


Training Epoch 9/10: 100%|██████████| 39/39 [03:40<00:00,  5.64s/it]


Epoch [9/10] - Training loss: 0.0901


Validation: 100%|██████████| 10/10 [00:54<00:00,  5.48s/it]


Epoch [9/10] - Validation Accuracy: 56.86%
Epoch [10/10] - Training started


Training Epoch 10/10: 100%|██████████| 39/39 [03:37<00:00,  5.58s/it]


Epoch [10/10] - Training loss: 0.0345


Validation: 100%|██████████| 10/10 [00:50<00:00,  5.03s/it]

Epoch [10/10] - Validation Accuracy: 70.26%





Epoch [1/10] - Training started


Training Epoch 1/10: 100%|██████████| 39/39 [03:29<00:00,  5.36s/it]


Epoch [1/10] - Training loss: 0.0101


Validation: 100%|██████████| 10/10 [00:47<00:00,  4.80s/it]


Epoch [1/10] - Validation Accuracy: 68.30%
Epoch [2/10] - Training started


Training Epoch 2/10:   0%|          | 0/39 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [9]:
# Gather true and predicted class indices over the whole validation set
# (encoding: 'healthy' -> 1, any other label -> 0).
model.eval()
all_labels, all_predictions = [], []

with torch.no_grad():
    for val_batch in val_loader:
        targets = torch.tensor(
            [int(name == 'healthy') for name in val_batch['labels']],
            dtype=torch.long,
        ).to(device)
        logits = model(val_batch['images'].to(device))
        preds = torch.max(logits, 1)[1]

        all_labels.extend(targets.cpu().numpy())
        all_predictions.extend(preds.cpu().numpy())


In [13]:
import pandas as pd
from sklearn.metrics import confusion_matrix

# Class indices follow the encoding used in training/evaluation:
# 0 = bleached (any non-'healthy' label), 1 = healthy.
# labels=[0, 1] keeps the matrix 2x2 even when a class is absent from both the
# targets and the predictions. Otherwise confusion_matrix returns a 1x1 array
# that no longer matches the two-entry index below.
cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])
cm_df = pd.DataFrame(cm, index=['Bleached', 'Healthy'], columns=['Bleached', 'Healthy'])
# scikit-learn convention: rows are true classes, columns are predictions.
cm_df.index.name = 'True'
cm_df.columns.name = 'Predicted'

print("Confusion Matrix (as DataFrame):")
print(cm_df)

Confusion Matrix (as DataFrame):
          Bleached  Healthy
Bleached        24       65
Healthy         32      185
