In [64]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from argparse import Namespace

In [65]:
class Statistics:
    def __init__(self):
        self.metrics = dict()

    def update(self, metric_name, new_value):
        if metric_name in self.metrics:
            values = self.metrics[metric_name]
            values.append(new_value)

        else:
            values = [new_value]
            self.metrics.update({metric_name : values})

    def get_metric(self, metric_name):
        return self.metrics.get(metric_name)
    
    def epoch_count(self, metric_name = None):
        # MAX number of epoch in metrics
        if metric_name is None:
            max = 0
            for val in self.metrics.values():
                if len(val) > max:
                    max = len(val)

            return max 

        # Epoch count for metric name
        else:
            if metric_name in self.metrics:
                return len(self.metrics[metric_name])
            else: 
                return 0
     
    # First epoch is 0
    def epoch_metrics(self, epoch_num):
        result = dict()
        
        for metric_name, values in self.metrics.items():
            if (epoch_num >= 0) and (epoch_num < len(values)):
                metric_val = values[epoch_num]
                result.update({metric_name : metric_val})

        return result
    
    def metric_average(self, metric_name):
        if metric_name in self.metrics:
            values = self.metrics[metric_name]
            return float(sum(values) / len(values))
        
        else: 
            return None

In [66]:
def pixel_accuracy(prediction, mask):
    with torch.no_grad():
        pixel_count = float(mask.numel())

        prediction = torch.argmax(torch.softmax(prediction, dim=1), dim=1)
        correct = torch.eq(prediction, mask).int()

        accuracy = float(correct.sum()) / pixel_count

    return accuracy

def IoU(prediction, mask):
    pass

def dice(prediction, mask):
    pass

In [67]:
class Trainer:
    def __init__(self, config: Namespace, model: nn.Module):

        # Config and its parameters
        try:
            lrate = config.learning_rate
            (beta1, beta2) = config.betas
            wd = config.weight_decay
            self.batch_size = config.batch_size
        except AttributeError as e:
            raise Exception(f'Parameter "{e.name}" NOT found!')

        # Select GPU device
        self.device = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        print(f"Using {self.device} device for training")

        # Move model to available device
        self.network = model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lrate, betas=(beta1, beta2), weight_decay=wd)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Statistics (metrics for each epoch)
        self.stats = Statistics()

    # Create Data Loaders
    def load_dataset(self, train_data, val_data, test_data):
        self.train_data = DataLoader(train_data, batch_size= self.batch_size, shuffle= True)
        self.val_data = DataLoader(val_data, batch_size= self.batch_size, shuffle= True)
        self.test_data = DataLoader(test_data, batch_size= self.batch_size, shuffle= False)

    # Forward Pass - create prediction and return its error (loss)
    def forward_pass(self, input, ground_truth):
        prediction = self.model(input)
        loss = self.loss_fn(prediction, ground_truth)
        return loss
    
    # Backward Pass - update parameters (weights, bias)
    def backward_pass(self, loss_value):
        self.optimizer.zero_grad()
        loss_value.backward()
        self.optimizer.step()

    def train(self, logger = None):
        self.model.train()

        # Train model on each dataset batch (train_data)
        for x, y in self.train_data:
            x, y = x.to(self.device), y.to(self.device)

            loss = self.forward_pass(x, y)

            # Save batch loss
            self.stats.update("train_loss", loss)

            self.backward_pass(loss)
        
    def evaluate(self):
        self.model.eval()

        with torch.no_grad():
            for x, y in self.val_data:
                x, y = x.to(self.device), y.to(self.device)

                loss = self.forward_pass(x, y)

                # Save batch loss
                self.stats.update("val_loss", loss)

    def test_model(self):
        self.model.eval()

        with torch.no_grad():
            for x, y in self.test_data:
                x, y = x.to(self.device), y.to(self.device)

                pred = self.model(x)

                # TODO - add metrics
                
    


In [68]:
# path_images = tensor_images.pt
# path_labels = tensor_labels.pt

class Lizard_dataset(Dataset):
    def __init__(self, path_images, path_labels):
        self.images = []
        self.labels = []

        # Load images and labels as tensors
        images_t = torch.load(path_images)
        labels_t = torch.load(path_labels)

        img_count = images_t.size(dim=0)
        num_channels = images_t.size(dim=1)
        height = images_t.size(dim=2)
        width = images_t.size(dim=3)

        if ( img_count != labels_t.size(dim=0) or 
            num_channels != labels_t.size(dim=1) or 
            height != labels_t.size(dim=2) or 
            width != labels_t.size(dim=3)):

            print("Wrong tensor shapes!")
            print (f"- Images tensor shape : '{images_t.shape}'")
            print (f"- Labels tensor shape : '{labels_t.shape}'")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return (self.images[idx], self.labels[idx])

In [69]:
import net_config as cfg

x = torch.load('tensor_images.pt')
y = torch.load('tensor_labels.pt')

print(x.dtype)
print(y.dtype)

data = Lizard_dataset("tensor_images.pt", "tensor_labels.pt")

print(y)

print(cfg.a)
print(cfg.config_Unet)

t = Trainer(cfg.config_Unet, None)


torch.float64
torch.uint8
tensor([[[[208, 219, 175,  ..., 201, 221, 231],
          [214, 226, 167,  ..., 195, 215, 213],
          [186, 197, 154,  ..., 186, 192, 198],
          ...,
          [200, 192, 204,  ..., 207, 221, 234],
          [199, 202, 210,  ..., 234, 246, 245],
          [209, 209, 212,  ..., 252, 245, 223]],

         [[183, 195, 149,  ..., 145, 170, 180],
          [189, 202, 141,  ..., 144, 166, 165],
          [160, 172, 128,  ..., 141, 149, 157],
          ...,
          [155, 146, 152,  ..., 185, 201, 215],
          [152, 154, 159,  ..., 216, 229, 228],
          [168, 169, 172,  ..., 238, 230, 207]],

         [[234, 239, 207,  ..., 198, 222, 234],
          [244, 246, 201,  ..., 199, 222, 223],
          [226, 231, 194,  ..., 205, 211, 216],
          ...,
          [219, 209, 215,  ..., 222, 233, 244],
          [212, 214, 219,  ..., 243, 252, 252],
          [217, 216, 219,  ..., 254, 252, 236]]],


        [[[202, 203, 197,  ..., 237, 240, 232],
         

AttributeError: 'NoneType' object has no attribute 'to'