In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import models
from torchvision.transforms import ToTensor
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2

In [2]:
def split_dataset(csv_file):
    data = pd.read_csv(csv_file)

    train_set = list()
    test_set = list()

    # Add proportionate data from each category to maintain the ratio of the overall dataset
    for label in data['label'].unique():
        label_data = data[data['label'] == label]

        train_label_data, test_label_data = train_test_split(label_data, test_size=0.3, random_state=42)

        # Append the splits
        train_set.append(train_label_data)
        test_set.append(test_label_data)

    train_df = pd.concat(train_set)
    test_df = pd.concat(test_set)

    print("Training Set:")
    print(train_df, end="\n\n")
    print("Test Set:")
    print(test_df)

    return train_df, test_df

train_df, test_df = split_dataset('images/Curacao Coral Reef Assessment 2023 CUR/metadata.csv')

Training Set:
                                name        date location     label  \
657   CUR_CRA_040_20231018_T4_16.JPG  2023-10-18  Curacao   healthy   
681   CUR_CRA_049_20231017_T3_14.JPG  2023-10-17  Curacao   healthy   
729   CUR_CRA_060_20231016_T1_15.JPG  2023-10-16  Curacao   healthy   
1092  CUR_CRA_110_20231031_T1_02.JPG  2023-10-31  Curacao   healthy   
594   CUR_CRA_033_20231024_T3_05.JPG  2023-10-24  Curacao   healthy   
...                              ...         ...      ...       ...   
373   CUR_CRA_017_20231027_T5_13.JPG  2023-10-27  Curacao  bleached   
859   CUR_CRA_080_20231103_T2_02.JPG  2023-11-03  Curacao  bleached   
1025  CUR_CRA_100_20231101_T4_03.JPG  2023-11-01  Curacao  bleached   
1279  CUR_CRA_119_20231023_T4_12.JPG  2023-10-23  Curacao  bleached   
368   CUR_CRA_017_20231027_T5_01.JPG  2023-10-27  Curacao  bleached   

     CoralReefWatch location  SST@90th_HS  
657              abc_islands        29.95  
681              abc_islands        30.21  
7

In [3]:
class CoralDataset(Dataset):
    def __init__(self, df, transform=None):
        self.transform = transform
        self.image_files = df['name'].tolist()
        self.ssts = df["SST@90th_HS"].to_numpy(dtype=np.float32)
        self.labels = df["label"].tolist()

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = "{}/images/Curacao Coral Reef Assessment 2023 CUR/{}/{}".format(os.getcwd(), self.labels[idx], self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        sst = self.ssts[idx]

        if self.transform:
            image = self.transform(image)

        return image, sst

# Resize images to match FashionMNIST format
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

training_data = CoralDataset(train_df, transform=transform)
test_data = CoralDataset(test_df, transform=transform)

In [9]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

class ResNetWithFC(nn.Module):
    def __init__(self):
        super(ResNetWithFC, self).__init__()

        # Load a pre-trained ResNet model (e.g., ResNet18)
        self.resnet = models.resnet18(pretrained=True)

        # Modify the final fully connected layer
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, 1)

    def forward(self, x):
        x = self.resnet(x)
        return x

Using cpu device


In [10]:
loss_fn = nn.MSELoss()

In [24]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X).squeeze()
        print(pred, y)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            # Commenting this out to reduce clutter in output
            # print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [25]:
def test(dataloader, model, loss_fn, best_loss, best_model):
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X).squeeze()
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches

    print(f"Test Error: Avg loss: {test_loss:>8f}")

    if test_loss<best_loss:
      print("Found new best model with average test loss {}\n".format(test_loss))
      best_loss = test_loss
      best_model = model

    return best_model, best_loss

In [26]:
epochs = 50

def train_model(lr, batch_size=32):
  # Initialization for best loss set to a high value
  best_loss = 100000
  best_model = None

  model = ResNetWithFC().to(device)

  optimizer = torch.optim.SGD(model.parameters(), lr=lr)

  train_dataloader = DataLoader(training_data, batch_size=batch_size)
  test_dataloader = DataLoader(test_data, batch_size=batch_size)

  for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)

    best_model, best_loss = test(test_dataloader, model, loss_fn, best_loss, best_model)

  print("Done!")

  return best_model, test_dataloader

In [27]:
best_model, test_dataloader = train_model(0.001)

Epoch 1
-------------------------------
tensor([-0.8937,  0.1686,  0.2749, -1.0925, -0.5593, -0.5054, -0.5729, -0.1538,
        -1.0839, -0.4300, -0.3997, -0.6546, -0.2019, -0.0750, -0.4182,  0.0831,
        -0.4167, -0.7566, -0.7901, -0.2507,  0.8245, -0.2011,  0.2457, -0.4136,
        -0.5468, -0.1922, -0.5551,  0.0629, -0.4411, -0.4952, -0.1535, -0.5184],
       grad_fn=<SqueezeBackward0>) tensor([29.9500, 30.2100, 29.8600, 29.9500, 30.0900, 30.2100, 30.2000, 30.1800,
        30.2000, 30.2700, 30.1800, 30.1800, 29.9300, 30.1200, 30.2000, 30.1200,
        29.9600, 30.2000, 30.2000, 29.8500, 30.1200, 30.1200, 30.2100, 29.9600,
        29.9300, 30.1800, 29.9600, 30.0200, 29.9600, 30.2700, 30.2000, 30.1800])
tensor([23.3224, 25.5949, 23.1670, 25.3952, 23.5958, 23.4440, 27.9381, 22.3106,
        23.4160, 27.5503, 26.2615, 25.3041, 25.0339, 24.1332, 30.9447, 20.2273,
        22.9862, 23.6702, 22.6703, 25.4975, 27.2241, 24.2516, 26.0247, 20.4091,
        24.9624, 24.2757, 24.1092, 25.3664,

KeyboardInterrupt: 