In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from argparse import Namespace
import time

In [2]:
class Statistics:
    """Collect per-batch values of named metrics.

    Every metric name maps to a list holding its values in the order they
    were recorded, so position ``i`` of a list is the ``i``-th value passed
    to ``update`` for that metric.
    """

    def __init__(self):
        # metric name -> list of recorded values, in recording order
        self.metrics = dict()

    def update(self, metric_name, new_value):
        """Append ``new_value`` to the history of ``metric_name``.

        The history list is created on first use.
        """
        self.metrics.setdefault(metric_name, []).append(new_value)

    def get_metric(self, metric_name):
        """Return the list of values recorded for ``metric_name``, or None."""
        return self.metrics.get(metric_name)

    def batch_count(self):
        """Return the length of the longest metric history (0 if nothing recorded)."""
        return max((len(values) for values in self.metrics.values()), default=0)

    def batch_metrics(self, batch_num):
        """Return ``{metric_name: value}`` for the 0-based index ``batch_num``.

        Metrics whose history is too short for ``batch_num`` (and any negative
        index) are omitted rather than raising.
        """
        return {
            metric_name: values[batch_num]
            for metric_name, values in self.metrics.items()
            if 0 <= batch_num < len(values)
        }

    def metric_average(self, metric_name):
        """Return the arithmetic mean of ``metric_name``'s values as a float.

        Returns None when the metric was never recorded or its history is
        empty (avoids a ZeroDivisionError on an empty list).
        """
        values = self.metrics.get(metric_name)
        if not values:
            return None
        return float(sum(values) / len(values))

In [3]:
def pixel_accuracy(prediction, truth):
    """Return the fraction of elements where ``prediction`` equals ``truth``.

    Args:
        prediction: tensor of predicted labels (in this notebook, the argmax
            of the network output over dim 1).
        truth: tensor of ground-truth labels, same shape as ``prediction``.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if the shapes differ (``torch.eq`` would otherwise
            broadcast and silently produce a meaningless ratio), or if the
            tensors are empty.
    """
    if prediction.shape != truth.shape:
        raise ValueError(
            f"pixel_accuracy: shape mismatch {tuple(prediction.shape)} "
            f"vs {tuple(truth.shape)}"
        )

    pixel_count = truth.numel()
    if pixel_count == 0:
        raise ValueError("pixel_accuracy: cannot compute accuracy of empty tensors")

    with torch.no_grad():
        # Summing a bool tensor yields an int64 count, so no overflow concerns.
        correct = torch.eq(prediction, truth).sum().item()

    return correct / pixel_count

def ioU(prediction, mask):
    """Return the mean intersection-over-union (Jaccard index) over class labels.

    For every label that occurs in ``prediction`` or ``mask``, the per-class
    IoU is ``|P == c  AND  M == c| / |P == c  OR  M == c|`` counted over all
    elements. The result is the unweighted mean of these per-class scores.
    Labels absent from both tensors are skipped because their IoU is 0/0.

    NOTE(review): assumes both tensors hold class labels (as produced by the
    argmax in ``Trainer``). For binary masks this averages foreground and
    background IoU; confirm that is the desired definition.

    Args:
        prediction: tensor of predicted labels.
        mask: tensor of ground-truth labels, same shape as ``prediction``.

    Returns:
        Mean IoU as a Python float in [0, 1].

    Raises:
        ValueError: if the shapes differ or the tensors are empty.
    """
    if prediction.shape != mask.shape:
        raise ValueError(
            f"ioU: shape mismatch {tuple(prediction.shape)} vs {tuple(mask.shape)}"
        )
    if mask.numel() == 0:
        raise ValueError("ioU: cannot compute IoU of empty tensors")

    with torch.no_grad():
        # Only labels that appear somewhere have a defined (non 0/0) score.
        labels = set(torch.unique(prediction).tolist()) | set(torch.unique(mask).tolist())
        scores = []
        for label in sorted(labels):
            pred_hit = prediction == label
            mask_hit = mask == label
            intersection = (pred_hit & mask_hit).sum().item()
            union = (pred_hit | mask_hit).sum().item()
            scores.append(intersection / union)

    return sum(scores) / len(scores)

def dice(prediction, mask):
    """Return the mean Dice coefficient (F1 overlap) over class labels.

    For every label ``c`` that occurs in ``prediction`` or ``mask``, the
    per-class score is ``2 * |P == c AND M == c| / (|P == c| + |M == c|)``
    counted over all elements. The result is the unweighted mean of these
    scores. Labels absent from both tensors are skipped because their score
    is 0/0.

    NOTE(review): assumes both tensors hold class labels (as produced by the
    argmax in ``Trainer``). For binary masks this averages foreground and
    background Dice; confirm that is the desired definition.

    Args:
        prediction: tensor of predicted labels.
        mask: tensor of ground-truth labels, same shape as ``prediction``.

    Returns:
        Mean Dice coefficient as a Python float in [0, 1].

    Raises:
        ValueError: if the shapes differ or the tensors are empty.
    """
    if prediction.shape != mask.shape:
        raise ValueError(
            f"dice: shape mismatch {tuple(prediction.shape)} vs {tuple(mask.shape)}"
        )
    if mask.numel() == 0:
        raise ValueError("dice: cannot compute Dice of empty tensors")

    with torch.no_grad():
        # Only labels that appear somewhere have a non-zero denominator.
        labels = set(torch.unique(prediction).tolist()) | set(torch.unique(mask).tolist())
        scores = []
        for label in sorted(labels):
            pred_hit = prediction == label
            mask_hit = mask == label
            intersection = (pred_hit & mask_hit).sum().item()
            total = pred_hit.sum().item() + mask_hit.sum().item()
            scores.append(2 * intersection / total)

    return sum(scores) / len(scores)

In [4]:
class Trainer:
    metric_name_Tloss = "train_loss"
    metric_name_Vloss = "val_loss"
    metric_name_acc = "accuracy"
    metric_name_IoU = "IoU"
    metric_name_dice = "dice"

    def __init__(self, model: nn.Module, config: Namespace):

        # Config and its parameters
        try:
            lrate = config.learning_rate
            (beta1, beta2) = config.betas
            wd = config.weight_decay
            self.batch_size = config.batch_size
        except AttributeError as e:
            raise Exception(f'Parameter "{e.name}" NOT found!')

        # Select GPU device
        self.device = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )

        print(f"Using {self.device} device for training")

        # Move model to available device
        self.network = model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lrate, betas=(beta1, beta2), weight_decay=wd)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Statistics (metrics for each epoch)
        self.stats = Statistics()

    # Create Data Loaders
    def load_dataset(self, train_data, val_data, test_data):
        self.train_data = DataLoader(train_data, batch_size= self.batch_size, shuffle= True)
        self.val_data = DataLoader(val_data, batch_size= self.batch_size, shuffle= True)
        self.test_data = DataLoader(test_data, batch_size= self.batch_size, shuffle= False)

    # Forward Pass - create prediction and return its error (loss)
    def forward_pass(self, input, ground_truth):
        prediction = self.network(input)
        loss = self.loss_fn(prediction, ground_truth)
        return loss
    
    # Backward Pass - update parameters (weights, bias)
    def backward_pass(self, loss_value):
        self.optimizer.zero_grad()
        loss_value.backward()
        self.optimizer.step()

    def train_model(self):
        self.network.train()

        # Train model (dataset = train_data)
        start = time.time()

        for x, y in self.train_data:
            x, y = x.to(self.device), y.to(self.device)

            loss = self.forward_pass(x, y)

            # Save batch loss
            self.stats.update(self.metric_name_Tloss, loss)

            self.backward_pass(loss)

        end = time.time()
        print(f"Train time in sec = {end - start}")

        self.network.eval()

        # Evaulate model by calculating loss (dataset = val_data)
        start = time.time()

        with torch.no_grad():
            for x, y in self.val_data:
                x, y = x.to(self.device), y.to(self.device)

                loss = self.forward_pass(x, y)

                # Save batch loss
                self.stats.update(self.metric_name_Vloss, loss)

        end = time.time()
        print(f"Validation time in sec = {end - start}")

        # Evaulate model by calculating metrics (dataset = test_data)
        start = time.time()

        with torch.no_grad():
            for x, y in self.test_data:
                x, y = x.to(self.device), y.to(self.device)

                pred = self.network(x)
                classes = torch.argmax(pred, dim = 1)

                # TODO - calculate metrics
                self.stats.update(self.metric_name_acc, pixel_accuracy(classes, y))
                self.stats.update(self.metric_name_IoU, 1)
                self.stats.update(self.metric_name_dice, 2)

        end = time.time()
        print(f"Test time in sec = {end - start}")
