In [4]:
import torch
import torch.nn as nn
from argparse import Namespace

In [5]:
class Statistics:
    """Record a history of values for named metrics.

    Each metric name maps to a list of values in the order they were
    recorded; index 0 is the first recorded value ("epoch 0").
    """

    def __init__(self):
        # metric name -> list of recorded values, in insertion order
        self.metrics = dict()

    def update(self, metric_name, new_value):
        """Append ``new_value`` to ``metric_name``'s history, creating it if needed."""
        self.metrics.setdefault(metric_name, []).append(new_value)

    def get_metric(self, metric_name):
        """Return the list of values recorded for ``metric_name``, or ``None`` if unknown."""
        return self.metrics.get(metric_name)

    def epoch_count(self, metric_name=None):
        """Return how many values have been recorded.

        With ``metric_name``: the length of that metric's history (0 if it was
        never recorded). Without it: the longest history across all metrics
        (0 if nothing has been recorded yet).
        """
        if metric_name is None:
            # ``default`` covers the "no metrics yet" case without shadowing
            # the builtin ``max`` as a local accumulator.
            return max((len(values) for values in self.metrics.values()), default=0)
        return len(self.metrics.get(metric_name, ()))

    def epoch_metrics(self, epoch_num):
        """Return ``{metric_name: value}`` for every metric that has an entry at ``epoch_num``.

        Epochs are 0-based. Negative or out-of-range indices are skipped (no
        Python-style negative indexing), so the result may be empty.
        """
        return {
            name: values[epoch_num]
            for name, values in self.metrics.items()
            if 0 <= epoch_num < len(values)
        }

    def metric_average(self, metric_name):
        """Return the mean of ``metric_name``'s values as a float.

        Returns ``None`` when the metric is unknown or has no values (an empty
        list would otherwise raise ``ZeroDivisionError``).
        """
        values = self.metrics.get(metric_name)
        if not values:
            return None
        return float(sum(values) / len(values))

In [6]:
def pixel_accuracy(prediction, mask):
    """Return the fraction of positions whose predicted class equals ``mask``.

    Args:
        prediction: class scores with the class axis at dim 1 (presumably
            ``(N, C, H, W)`` logits — confirm against callers).
        mask: integer class labels whose shape matches ``prediction`` with
            dim 1 removed (presumably ``(N, H, W)``).

    Returns:
        Accuracy in ``[0, 1]`` as a Python float.

    Raises:
        ZeroDivisionError: if ``mask`` has no elements.
    """
    with torch.no_grad():
        pixel_count = float(mask.numel())

        # softmax is monotonic along dim 1, so argmax over the raw scores picks
        # the same class; skipping it avoids allocating another full tensor.
        predicted_class = torch.argmax(prediction, dim=1)
        correct = torch.eq(predicted_class, mask).sum().item()

    return float(correct) / pixel_count

def IoU(prediction, mask):
    """Intersection-over-Union metric — placeholder, not implemented yet.

    Presumably meant to take the same ``prediction``/``mask`` arguments as
    ``pixel_accuracy``; currently ignores them and always returns ``None``.
    """
    return None

def dice(prediction, mask):
    """Dice coefficient metric — placeholder, not implemented yet.

    Presumably meant to take the same ``prediction``/``mask`` arguments as
    ``pixel_accuracy``; currently ignores them and always returns ``None``.
    """
    return None

In [7]:
class Trainer:
    """Train, validate and test a model with Adam and cross-entropy loss.

    Per-batch losses are recorded in ``self.stats`` (a ``Statistics``) as
    Python floats under ``"train_loss"`` and ``"val_loss"``.
    """

    def __init__(self, config: Namespace, model : nn.Module):
        """Select a device, move the model to it, and build optimizer/loss/stats.

        Args:
            config: must provide ``learning_rate``, ``betas`` (two values for
                Adam) and ``weight_decay``; ``batch_size`` is read later by
                ``load_dataset``.
            model: network to train; it is moved to the selected device.

        Raises:
            ValueError: if a required config attribute is missing. It is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        self.cfg = config

        try:
            lrate = self.cfg.learning_rate
            (beta1, beta2) = self.cfg.betas
            wd = self.cfg.weight_decay
        except AttributeError as e:
            # AttributeError.name (Python 3.10+) is the bare attribute name;
            # fall back to the full message on older interpreters.
            missing = getattr(e, "name", None) or e.args[0]
            raise ValueError(f'Parameter "{missing}" NOT found!') from e

        # Select the best available accelerator, falling back to CPU.
        self.device = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        print(f"Using {self.device} device for training")

        # Every method reads ``self.model``; ``self.network`` is kept as an
        # alias for backward compatibility with code that used that name.
        self.model = model.to(self.device)
        self.network = self.model

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=lrate, betas=(beta1, beta2), weight_decay=wd
        )
        self.loss_fn = nn.CrossEntropyLoss()

        # Per-batch metric history
        self.stats = Statistics()

    def load_dataset(self, train_data, val_data, test_data):
        """Wrap the three datasets in DataLoaders using ``cfg.batch_size``.

        Only the training loader shuffles. Validation and test keep dataset
        order, so their batch composition (and any average of per-batch
        losses) is reproducible between runs.
        """
        # Imported locally so this method does not depend on a module-level import.
        from torch.utils.data import DataLoader

        batch_size = self.cfg.batch_size
        self.train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        self.val_data = DataLoader(val_data, batch_size=batch_size, shuffle=False)
        self.test_data = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    def forward_pass(self, input, ground_truth):
        """Run the model on ``input`` and return its loss against ``ground_truth``."""
        prediction = self.model(input)
        return self.loss_fn(prediction, ground_truth)

    def backward_pass(self, loss_value):
        """Backpropagate ``loss_value`` and take one optimizer step."""
        self.optimizer.zero_grad()
        loss_value.backward()
        self.optimizer.step()

    def train(self, logger = None):
        """Run one pass over ``self.train_data``, updating the model's weights.

        Each batch's loss is recorded as a float under ``"train_loss"``.
        ``logger`` is accepted but currently unused.
        """
        self.model.train()

        for x, y in self.train_data:
            x, y = x.to(self.device), y.to(self.device)

            loss = self.forward_pass(x, y)

            # Store a plain float: keeping the tensor would keep every batch's
            # autograd graph (and its device memory) alive indefinitely.
            self.stats.update("train_loss", loss.item())

            self.backward_pass(loss)

    def evaluate(self):
        """Compute the loss on ``self.val_data`` without updating weights.

        Each batch's loss is recorded as a float under ``"val_loss"``.
        """
        self.model.eval()

        with torch.no_grad():
            for x, y in self.val_data:
                x, y = x.to(self.device), y.to(self.device)

                loss = self.forward_pass(x, y)
                self.stats.update("val_loss", loss.item())

    def test_model(self):
        """Run the model over ``self.test_data`` in eval mode (metrics not yet recorded)."""
        # Eval mode is needed so layers such as dropout/batch-norm behave
        # deterministically; the original left the training mode in place.
        self.model.eval()

        with torch.no_grad():
            for x, y in self.test_data:
                x, y = x.to(self.device), y.to(self.device)

                pred = self.model(x)

                # TODO - add metrics (``pred`` and ``y`` are currently unused)
                
    
