In [1]:
import torch

In [None]:
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

import torch.nn.functional as F # for bce loss function
from torch.utils.data import DataLoader
from typing import Callable, Tuple
from tqdm import tqdm

     
# Number of samples per batch (notebook-wide configuration constant).
BATCH_SIZE = 32

## data prep
## data visualization methods
# do we need these?

## evaluation metrics (DONE)
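- IoU (Intersection over Union): measures the overlap between the predicted region and the actual (ground-truth) region. Scores range from 0 to 1, where 1 is perfect overlap and 0 is no overlap. It is commonly used to judge whether a predicted bounding box or segmentation mask is correct.
- Recall: the fraction of ground-truth positive pixels that the prediction also marks as positive (true positives / actual positives).
- Precision: the fraction of predicted positive pixels that are actually positive (true positives / predicted positives).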

In [3]:
def IoU_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float:
    """Compute the intersection-over-union (IoU) between two masks.

    Negative entries of ``real_mask`` are replaced with 0 first. After that,
    every non-zero entry of either mask is treated as foreground.

    Args:
        real_mask (Tensor): Ground-truth mask.
        predicted_mask (Tensor): Mask predicted by the model; must be
            broadcastable against ``real_mask``.

    Returns:
        float: ``|intersection| / |union|``. Returns 1.0 when the union is
        empty (neither mask contains foreground), since two empty masks
        agree perfectly.

    Port notes (TensorFlow -> PyTorch):
        - changed "tf" to "torch" in the signature and the real_mask line
        - reductions use torch ops instead of numpy: ``np.sum`` on a
          ``torch.Tensor`` forwards ``axis=``/``out=`` keywords that
          ``Tensor.sum`` does not accept
    """
    # torch.where(condition, value_if_true, value_if_false):
    # keep entries that are >= 0, replace everything else with 0.
    real_mask = torch.where(real_mask >= 0, real_mask, torch.zeros_like(real_mask))

    # Boolean foreground maps; any non-zero value counts as foreground.
    real_foreground = real_mask != 0
    predicted_foreground = predicted_mask != 0

    intersection = torch.logical_and(real_foreground, predicted_foreground).sum().item()
    union = torch.logical_or(real_foreground, predicted_foreground).sum().item()

    # No foreground in either mask: define IoU as a perfect score.
    if union == 0:
        return 1.0

    return intersection / union


def recall_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float:
    """Compute recall: true positives divided by ground-truth positives.

    Negative entries of ``real_mask`` are treated as 0 (not positive). Every
    remaining non-zero ground-truth entry is a positive, and every non-zero
    entry of ``predicted_mask`` is a predicted positive. Positives are
    counted as pixels; for binary {0, 1} masks this equals summing the mask.

    Args:
        real_mask (Tensor): Ground-truth mask.
        predicted_mask (Tensor): Mask predicted by the model; must be
            broadcastable against ``real_mask``.

    Returns:
        float: Recall value. Returns 1.0 when the ground truth contains no
        positives (nothing could have been missed).

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch
        - reductions use torch ops instead of numpy: ``np.sum`` on a
          ``torch.Tensor`` forwards ``axis=``/``out=`` keywords that
          ``Tensor.sum`` does not accept
    """
    # Zero out negative ground-truth entries before looking for positives.
    real_mask = torch.where(real_mask < 0, torch.zeros_like(real_mask), real_mask)

    real_positive = real_mask != 0
    predicted_positive = predicted_mask != 0

    true_positives = torch.logical_and(real_positive, predicted_positive).sum().item()
    actual_positives = real_positive.sum().item()

    if actual_positives == 0:
        return 1.0

    return true_positives / actual_positives

def precision_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float:
    """Compute precision: true positives divided by predicted positives.

    Negative entries of ``real_mask`` are treated as 0 (not positive). Every
    remaining non-zero ground-truth entry is a positive, and every non-zero
    entry of ``predicted_mask`` is a predicted positive. Positives are
    counted as pixels; for binary {0, 1} masks this equals summing the mask.

    Args:
        real_mask (Tensor): Ground-truth mask.
        predicted_mask (Tensor): Mask predicted by the model; must be
            broadcastable against ``real_mask``.

    Returns:
        float: Precision value. Returns 1.0 when the model predicted no
        positives at all (no false positives were produced).

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch
        - reductions use torch ops instead of numpy: ``np.sum`` on a
          ``torch.Tensor`` forwards ``axis=``/``out=`` keywords that
          ``Tensor.sum`` does not accept
    """
    # Zero out negative ground-truth entries before looking for positives.
    real_mask = torch.where(real_mask < 0, torch.zeros_like(real_mask), real_mask)

    real_positive = real_mask != 0
    predicted_positive = predicted_mask != 0

    true_positives = torch.logical_and(real_positive, predicted_positive).sum().item()
    predicted_positives = predicted_positive.sum().item()

    if predicted_positives == 0:
        return 1.0

    return true_positives / predicted_positives

## loss functions (DONE)
- Dice coefficient: a metric for the similarity between two sets, widely used in image segmentation. It is twice the size of the intersection of the ground truth and the prediction, divided by the sum of their sizes.
$$
\text{Dice Coefficient} = \frac{2 \times |A \cap B|}{|A| + |B|}
$$

- weighted binary cross entropy

- BCE Dice loss 
regular dice loss:  
$$ \text{Dice Loss} = 1 - \text{Dice Coefficient}$$  
The BCE Dice loss is the sum of the Dice loss and the weighted BCE (wBCE) loss.

In [4]:
def dice_coef(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Compute the per-sample Dice loss (``1 - Dice coefficient``).

    Despite its name, this function returns the Dice *loss*, not the
    coefficient itself.

    Both tensors are flattened to ``(N, -1)``, where ``N`` is the size of the
    first dimension of ``y_true``, so any batch size works (including a final
    partial batch).

    Args:
        y_true (Tensor): Ground-truth values; dimension 0 is the batch.
        y_pred (Tensor): Predicted values with the same number of elements
            per sample as ``y_true``.

    Returns:
        Tensor: Shape ``(N,)``; the Dice loss of each element of the batch.

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch in the signature
        - changed K to torch throughout
        - batch dimension is read from the input instead of BATCH_SIZE
    """
    # Small constant that keeps the ratio defined when both masks are empty.
    smooth = 1e-6

    batch = y_true.shape[0]
    y_true_f = y_true.reshape(batch, -1)
    y_pred_f = y_pred.reshape(batch, -1)

    intersection = torch.sum(y_true_f * y_pred_f, dim=1)
    total = torch.sum(y_true_f, dim=1) + torch.sum(y_pred_f, dim=1)

    return 1 - (2. * intersection + smooth) / (total + smooth)


def weighted_bincrossentropy(true: torch.Tensor, pred: torch.Tensor, weight_zero: float = 0.01, weight_one: float = 1) -> torch.Tensor:
    """Calculate weighted binary cross entropy with fixed class weights.

    This can be useful for unbalanced categories. Adjust the weights
    depending on what is required.

    For example, if there are 10x as many positive classes as negative
    classes and you set weight_zero = 1.0, weight_one = 0.1, then false
    positives will be penalized 10 times as much as false negatives.

    Args:
        true (Tensor): Ground-truth values, at least 2-D (the mean is taken
            over dimension 1).
        pred (Tensor): Predicted probabilities in [0, 1], same shape as
            ``true``.
        weight_zero (float): Weight of class 0 (no-fire).
        weight_one (float): Weight of class 1 (fire).

    Returns:
        Tensor: Weighted binary cross entropy averaged over dimension 1.

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch in the signature
        - changed K to torch throughout
        - changed the keras BCE method to
          torch.nn.functional.binary_cross_entropy, whose argument order is
          (input=prediction, target=ground truth)
    """
    # binary_cross_entropy needs a floating-point target of the same dtype.
    true = true.to(pred.dtype)

    # reduction='none' keeps the element-wise losses so they can be weighted
    # individually before averaging.
    bin_crossentropy = F.binary_cross_entropy(pred, true, reduction='none')

    # Weight each element by the weight of its ground-truth class.
    weights = true * weight_one + (1. - true) * weight_zero
    weighted_bin_crossentropy = weights * bin_crossentropy

    return torch.mean(weighted_bin_crossentropy, dim=1)


def bce_dice_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Combined weighted-BCE + Dice loss, summed over the batch.

    Both tensors are flattened to ``(N, -1)`` using the size of the first
    dimension of ``y_true`` before the weighted binary cross entropy is
    computed. The unflattened tensors are passed to ``dice_coef``.

    Args:
        y_true (Tensor): Ground-truth values; dimension 0 is the batch.
        y_pred (Tensor): Predicted values with the same number of elements
            per sample as ``y_true``.

    Returns:
        Tensor: Scalar tensor holding the sum over the batch of
        (weighted BCE + Dice loss). Note this is a sum, not a mean.

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch
        - originally returned keras.reduce_weighted_loss(loss), but without
          additional args all that did was perform a sum operation. Replaced
          it with torch.sum since there is no torch equivalent to
          reduce_weighted_loss
        - the flattening here reads the batch dimension from the input
          instead of BATCH_SIZE
    """
    batch = y_true.shape[0]
    y_true_f = y_true.reshape(batch, -1)
    y_pred_f = y_pred.reshape(batch, -1)

    per_sample_loss = weighted_bincrossentropy(y_true_f, y_pred_f) + dice_coef(y_true, y_pred)
    return torch.sum(per_sample_loss)

# Evaluation loop

In [None]:
def evaluate_model(prediction_function: Callable[[torch.Tensor], torch.Tensor],
                   eval_dataset: DataLoader) -> Tuple[float, float, float, float]:
    """Evaluate a model's predictions over every batch of a data loader.

    For each batch, IoU, recall and precision are computed per sample from
    channel 0 of the labels, and one BCE-Dice loss value is computed for the
    whole batch. Evaluation runs under ``torch.no_grad()``.

    Parameters:
        prediction_function (Callable[[torch.Tensor], torch.Tensor]):
            Function for model inference. Its output is indexed as
            ``predictions[i, :, :]`` (per the original comment, shape
            (N, W, H)).
        eval_dataset (DataLoader): Loader yielding ``(inputs, labels)``
            batches. Labels are indexed as ``labels[i, :, :, 0]``, i.e. they
            are expected to be 4-D with the channel last.

    Returns:
        Tuple[float, float, float, float]: Mean IoU, mean recall and mean
        precision over all samples, and the mean of the per-batch loss
        values. All four are NaN when the loader yields no batches.

    Port notes (TensorFlow -> PyTorch):
        - changed tf to torch
        - eval_dataset is a torch.utils.data.DataLoader instead of a
          tf.data.Dataset
        - changed tf.expand_dims(tf.cast(predictions, tf.float32), axis=-1)
          to predictions.float().unsqueeze(-1) for the loss input
        - metric and loss values are converted to Python floats before
          averaging, so the means do not depend on numpy converting tensors
    """
    IoU_measures = []
    recall_measures = []
    precision_measures = []
    losses = []

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in tqdm(eval_dataset):
            # Prediction shape (N, W, H)
            predictions = prediction_function(inputs)

            for i in range(inputs.shape[0]):
                real_mask = labels[i, :, :, 0]
                predicted_mask = predictions[i, :, :]
                IoU_measures.append(float(IoU_metric(real_mask, predicted_mask)))
                recall_measures.append(float(recall_metric(real_mask, predicted_mask)))
                precision_measures.append(float(precision_metric(real_mask, predicted_mask)))

            # Negative label values are replaced with 0 before the loss.
            labels_cleared = torch.where(labels < 0, torch.zeros_like(labels), labels)
            batch_loss = bce_dice_loss(labels_cleared, predictions.float().unsqueeze(-1))
            losses.append(float(batch_loss))

    # Empty loader: return NaNs explicitly instead of averaging empty lists.
    if not losses:
        nan = float("nan")
        return nan, nan, nan, nan

    mean_IoU = float(np.mean(IoU_measures))
    mean_recall = float(np.mean(recall_measures))
    mean_precision = float(np.mean(precision_measures))
    mean_loss = float(np.mean(losses))
    return mean_IoU, mean_recall, mean_precision, mean_loss

# Model

# Plot Loss functions

# Testing

# Metrics on test set