## I. Load Libraries

In [23]:
import torch

In [24]:
from torchvision.models import mobilenet_v2
import torch.nn.functional as F # for bce loss function
from torch.utils.data import DataLoader
import torch.nn as nn

#import matplotlib.pyplot as plt
import numpy as np
from typing import Callable, Tuple, List
from tqdm import tqdm

     
BATCH_SIZE = 32
INPUT_CHANNELS = 12
RES_BLOCK_INPUT_CHANNELS = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


## II. Data Prep
## III. Data Visualization Methods
do we need these??? ^

# Building Testing
## IV. Evaluation Metrics: done
- IoU: Intersection over Union. measures overlap between prediction region and actual (ground truth) region. Scores from 0 to 1, where 1 is perfect overlap and 0 is no overlap. Usually used to determine whether bounding box is correct
- Recall
- Precision

In [25]:
"""
IoU metric
Calculation of intersection over union metric.
    
Args:
    real_mask (Tensor): Ground-truth mask
    predicted_mask (Tensor): Mask predicted by model
Returns:
    (float): IoU metric value

CHANGES: 
    - changed "tf" to "torch" in method header and real_mask line
    - added comments
"""
def IoU_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float: 
    # replacing neg values: torch.where(condition, choose-True, choose-False)
    # when the value is pos (>=0), keep the value from real_mastorch. otherwise, replace with 0
    real_mask = torch.where(real_mask>=0, real_mask, 0)

    # calculates the intersection and union between real and predicted by using a log AND and OR functions from numpy
    intersection = np.logical_and(real_mask, predicted_mask)
    union = np.logical_or(real_mask, predicted_mask)

    # if there is no object in either mask (both are entirely 0s), return 1 since IoU for 
    # empty masks would be perfect
    if np.sum(union) == 0:
        return 1
    
    # else, calculate and return intersection over union (IoU)
    return np.sum(intersection) / np.sum(union)


"""
Calculation of recall metric.
    
Args:
    real_mask (Tensor): Ground-truth mask
    predicted_mask (Tensor): Mask predicted by model
Returns:
    (float): recall metric value

CHANGES:
    - changed tf to torch
"""
def recall_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float:

    real_mask = torch.where(real_mask < 0, 0, real_mask)
    
    true_positives = np.sum(np.logical_and(real_mask, predicted_mask))
    actual_positives = np.sum(real_mask)
    if actual_positives == 0:
        return 1
    
    return true_positives / actual_positives

"""
Calculation of precision metric.
    
Args:
    real_mask (Tensor): Ground-truth mask
    predicted_mask (Tensor): Mask predicted by model
Returns:
    (float): precision metric value

CHANGES:
    - changed tf to torch
"""
def precision_metric(real_mask: torch.Tensor, predicted_mask: torch.Tensor) -> float:
    real_mask = torch.where(real_mask < 0, 0, real_mask)
    
    true_positives = np.sum(np.logical_and(real_mask, predicted_mask))
    predicted_positives = np.sum(predicted_mask)
    if predicted_positives == 0:
        return 1
    
    return true_positives / predicted_positives

## V. Loss Functions: done
- dice coefficient: metric used to evaluate similarity between sets, particularly in image segmentation. Dice coeff is calculated to be 2 times the intersection of the ground truth and predicted, over ground truth plus predicted. 
$$
\text{Dice Coefficient} = \frac{2 \times |A \cap B|}{|A| + |B|}
$$

- weighted binary cross entropy

- BCE Dice loss 
regular dice loss:  
$$ 
\text{Dice Loss} = 1 - \text{Dice Coefficient}
$$  
BCE dice loss adds dice loss and wBCE loss

In [26]:
"""
Dice loss function calculator.
    
Args:
    y_true (Tensor): 
    y_pred (Tensor):
Returns:
    (Tensor): Dice loss for each element of a batch.

CHANGES:
    - changed tf to torch in method header
    - changed K to torch throughout
"""
def dice_coef(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    smooth = 1e-6
    y_true_f = torch.reshape(y_true, (BATCH_SIZE, -1))
    y_pred_f = torch.reshape(y_pred, (BATCH_SIZE, -1))
    intersection = torch.sum(y_true_f * y_pred_f, axis=1)
    return 1 - (2. * intersection + smooth) / (torch.sum(y_true_f, axis=1) + torch.sum(y_pred_f, axis=1) + smooth)


"""
Calculates weighted binary cross entropy. The weights are fixed.
    
This can be useful for unbalanced catagories.

Adjust the weights here depending on what is required.

For example if there are 10x as many positive classes as negative classes,
    if you adjust weight_zero = 1.0, weight_one = 0.1, then false positives 
    will be penalize 10 times as much as false negatives.

Args:
    true (Tensor): Ground-truth values
    pred (Tensor): Predited values
    weight_zero (float): Weight of class 0 (no-fire)
    weight_one (float): Weight of class 1 (fire)

Returns: 
    (float) : value for weighted binary cross entropy
CHANGES:
    - changed tf to torch in method header
    - changed K to torch throughout
    - changed keras BCE method to torch.nn.functional.binary_cross_entropy
    
"""
def weighted_bincrossentropy(true: torch.Tensor, pred: torch.Tensor, weight_zero: float = 0.01, weight_one: float = 1) -> float:
  
    # calculate the binary cross entropy
    # using torch.nn.functional.binary_cross_entropy, set reduction='none' to keep individual losses in a tensor
    # rather than taking mean 
    bin_crossentropy = F.binary_cross_entropy(true, pred, reduction='none')
    
    # apply the weights
    weights = true * weight_one + (1. - true) * weight_zero
    weighted_bin_crossentropy = weights * bin_crossentropy 
    
    return torch.mean(weighted_bin_crossentropy, axis=1)


"""
BCE loss function calculator.

Args:
    y_true (Tensor): 
    y_pred (Tensor):
Returns:
    (Tensor): Mean BCE Dice loss over a batch.

CHANGES:
    - changed tf to torch
    - originally returned keras.reduce_weighted_loss(loss), but without additional args all that did was 
      perform a sum operation. Replaced it with torch.sum since there is no torch equivalent to reduce_weighted_loss
"""
def bce_dice_loss(y_true: torch.Tensor, y_pred: torch.Tensor):    
    y_true_f = torch.reshape(y_true, (BATCH_SIZE, -1))
    y_pred_f = torch.reshape(y_pred, (BATCH_SIZE, -1))
    return torch.sum(weighted_bincrossentropy(y_true_f, y_pred_f) + dice_coef(y_true, y_pred))

# VI. Evaluation Loop: done

In [27]:
"""
Loads dataset according to file pattern and evaluates model's predictions on it.

Parameters:
    model (Callable[[tf.Tensor], tf.Tensor]): Function for model inference.
    eval_dataset (tf.dataDataset): Dataset for evaluation.

Returns:
    Tuple[float, float, float, float]: IoU score, recall score, precision score and mean loss.

CHANGES:
    - changed tf to torch
    - in method header, imported DataLoader from torch.utils and changed  eval_dataset: tf.data.Dataset) to DataLoader
    - changed tf.expand_dims(tf.cast(predictions, tf.float32), axis=-1) to predictions.float().unsqueeze(-1)
    in losses.append
"""
def evaluate_model(prediction_function: Callable[[torch.Tensor], torch.Tensor],
                   eval_dataset: DataLoader) -> Tuple[float, float, float, float]:
    IoU_measures = []
    recall_measures = []
    precision_measures = []
    losses = []
    
    for inputs, labels in tqdm(eval_dataset):
        # Prediction shape (N, W, H)
        predictions = prediction_function(inputs)
        for i in range(inputs.shape[0]):
            IoU_measures.append(IoU_metric(labels[i, :, :,  0], predictions[i, :, :]))
            recall_measures.append(recall_metric(labels[i, :, :,  0], predictions[i, :, :]))
            precision_measures.append(precision_metric(labels[i, :, :,  0], predictions[i, :, :]))
        labels_cleared = torch.where(labels < 0, 0, labels)
        losses.append(bce_dice_loss(labels_cleared, predictions.float().unsqueeze(-1)
))
            
    mean_IoU = np.mean(IoU_measures)
    mean_recall = np.mean(recall_measures)
    mean_precision = np.mean(precision_measures)
    mean_loss = np.mean(losses)
    return mean_IoU, mean_recall, mean_precision, mean_loss

# Modeling Technique
## VII. Torch Model: done

The layers chosen for skips from NDWS were:  
       a. 'block_1_expand_relu',   
       b. 'block_3_expand_relu',   
       c. 'block_6_expand_relu',   
       d. 'block_13_expand_relu',  
       e. 'block_16_project',  

which matched to the following dimensions (from model.summary()):  
        a. (16, 16, 96)  
        b. (8, 8, 144)  
        c. (4, 4, 192)  
        d. (2, 2, 576)  
        e. (1, 1, 320)  
  
equivalent layers in torch model, after adjusting model for 12   channel output, were (using forward hook to print sizes of   submodules):  
        a. features[2].conv[0] -> (96, 16, 16)  
        b. features[3].conv[0] -> (144, 8, 8)  
        c. features[6].conv[0] -> (192, 4, 4)  
        d. features[13].conv[0] -> (576, 2, 2)  
        e. features[17] -> (320, 1, 1)  



In [28]:
# replacing: 
# tf.keras.applications.MobileNetV2(input_shape=[32, 32, 12], include_top=False, weights=None)
class Downstack(nn.Module):
    def __init__(self, input_channels, kernel_size, stride, padding):
        super().__init__() # initialize the class as a pytorch module

        self.base_model = mobilenet_v2(weights=None) # torch imp only takes weights as parameter
        # torch auto runs the classifier in the forward pass so have to replace those layers with identity to
        # get equivalent to include_top=False
        self.base_model.classifier = nn.Identity() 

        # changing default input size from 3 (RGB) to 12 for wildfire images
        self.base_model.features[0][0] = nn.Conv2d(
            in_channels = input_channels,
            out_channels = RES_BLOCK_INPUT_CHANNELS,
            kernel_size = kernel_size,
            stride = stride,
            padding = padding,
            bias = False
        )

        # replacing layer_names and base_model_outputs in keras
        # since we needed to grab outputs at submodule layers and only needed 5 rn, 
        # easiest way was to use a forward hook fnc and manually register each layer
        self.skips= []
        def hook(module, input, output):
            self.skips.append(output)

        self.base_model.features[2].conv[0].register_forward_hook(hook) # eqv 'block_1_expand_relu'
        self.base_model.features[3].conv[0].register_forward_hook(hook) # eqv 'block_3_expand_relu'
        self.base_model.features[6].conv[0].register_forward_hook(hook) # eqv 'block_6_expand_relu'
        self.base_model.features[13].conv[0].register_forward_hook(hook) # eqv 'block_13_expand_relu'
        self.base_model.features[17].register_forward_hook(hook) # eqv 'block_16_project'

        
    def forward(self, x):
        self.base_model(x)
        return self.skips

In [29]:
# to create upstack: replacing pix2pix.upsample layers
# pix2pix layers are just a conv transpose layer, a batchnorm layer, and a relu layer 
# (dropout optional but not used in NDWS model)
def upsample(input, output, kernel_size=3, stride=2, padding=1):
    block = nn.Sequential(
        nn.ConvTranspose2d(in_channels=input, out_channels=output, kernel_size=kernel_size, 
                           stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(output),
        nn.ReLU()
    )
    nn.init.normal_(block[0].weight, mean=0.0, std=0.02) #initializer?
    return block    

In [30]:
# putting the downstack and upsampler together to make u-net like 
# conv autoencoder structure
class convAutoencoder(nn.Module):
    def __init__(self, input_channels):
        super.__init__() # initialize as pytorch module

        self.downstack = Downstack(input_channels, kernel_size=3, stride=2, padding=1)

        self.upstack = nn.Sequential(
                    upsample(1280, 512),
                    upsample(512, 256),
                    upsample(256, 128),
                    upsample(128, 64)
        )

    def forward(self, x):
        skips = self.downstack(x)
        x = skips[-1] # last layer of skips is bottleneck; where upsampler starts
        skips = reversed(skips[:-1]) # rearrange from deep->shallow, dropping the bottleneck
        
        # concatenate outputs of upstack and skips along channel dimension
        for up, skip in zip(self.upstack, skips):
            x = up(x)
            x = torch.cat([x, skip], dim = 1) # dim = 1 is channels acc to torch ordering
            
        return x


## VIII. Torch Training loop
 

not finished converting! changes made so far:  
- tf.data.Dataset to DataLoader in method header
- tf.keras.optimizers.Adam() to torch.optim.Adam() for optimizer
- tf to torch for .where()

In [31]:
"""
Trains a model using train dataset. (Save weights of model with best IoU)

Args:
    model (Model): Model to train.
    train_dataset (Dataset): Training dataset.
    epochs (int): Number of epochs
Returns:
    Tuple[List[float], List[float]]: Train losses and Validation losses

CHANGES:
    - tf.data.Dataset to DataLoader in method header
    - tf.keras.optimizers.Adam() to torch.optim.Adam() for optimizer
    - tf to torch for .where()
    - removed GradientTape() call, replace all gradient stuff with "optimizer.zero_grad(), loss.backward(), optimizer.step()"
"""

def train_model(model: Model, train_dataset: DataLoader, epochs:int=10) -> Tuple[List[float], List[float]]:
    loss_fn = bce_dice_loss
    optimizer = torch.optim.Adam()
    batch_losses = []
    val_losses = []
    best_IoU = 0.0
    
    for epoch in range(epochs):
        losses = []
        print(f'Epoch {epoch+1}/{epochs}')
        # Iterate through the dataset
        progress = tqdm(train_dataset)
        for images, masks in progress:
            #with tf.GradientTape() as tape:

            # Forward pass
            predictions = model(images, training=True)
            label = torch.where(masks < 0, 0, masks)
            # Compute the loss
            loss = loss_fn(label, predictions)
            losses.append(loss.numpy())
            progress.set_postfix({'batch_loss': loss.numpy()})

            # Compute gradients
            optimizer.zero_grad()
            loss.backware()
            optimizer.step()
            #gradients = tape.gradient(loss, model.trainable_variables)
            # Update the model's weights
            #optimizer.apply_gradients(zip(gradients, model.trainable_variables)) <-- what here?

        # Evaluate model
        print("Evaluation...")
        IoU, recall, precision, val_loss = evaluate_model(lambda x: torch.where(model.predict(x) > 0.5, 1, 0)[:,:,:,0], validation_dataset)
        print("Validation set metrics:")
        print(f"Mean IoU: {IoU}\nMean precision: {precision}\nMean recall: {recall}\nValidation loss: {val_loss}\n")
        # Save best model
        if IoU > best_IoU:
            best_IoU = IoU
            model.save_weights("best.h5")
        
        # Print the loss for monitoring
        print(f'Epoch: {epoch}, Train loss: {np.mean(losses)}')
        batch_losses.append(np.mean(losses))
        val_losses.append(val_loss)
    
    print(f"Best model IoU: {best_IoU}")
    return batch_losses, val_losses

# Set reproducability
tf.random.set_seed(1337)

segmentation_model = build_CNN_AE_model()
train_losses, val_losses = train_model(segmentation_model, train_dataset, epochs=15)

NameError: name 'Model' is not defined

## IX. Plot Loss functions

# Testing
## X. Metrics on test set

## XI. Comparision with statistical model

## XII. Inference on test set