In [1]:
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
import matplotlib.pyplot as plt

sys.path.append("./source/")
import mgc_classifier as mgc_classifier
import dataset_loader as datasets


In [3]:
class train_and_validate_static:
    """
    class that implements training the network and
    outputting validation metrics
    """

    def __init__(
        self, model, datapath, criterion, optimizer, lr, batch_size, minibatch_size,
        num_workers
    ):

        # initialize the model
        self.model = model(
            out_channels=1, 
            num_conv_layers=3, 
            n_classes=40, 
            img_size=128
        )
        self.datapath = datapath
        self.optimizer = optimizer(self.model.parameters(), lr=lr)
        self.criterion = criterion

        # hyperparameters
        self.batch_size = batch_size
        self.minibatch_size = minibatch_size
        self.num_workers = num_workers

        # initialize the logging variables
        self.training_loss = []
        self.validation_loss = []
        self.validation_accuracy = []

        # training and validation dataloaders
        self.train_ratio = 0.8
        self.train_dataloader = None
        self.validation_dataloader = None

    def train(self, epochs=3):
        """
        train the network
        """

        # get the full dataset in the folder
        folder_dataset = datasets.PklDataset(self.datapath)
        print(len(folder_dataset))

        # split data into training and test

        train_data, validation_data = random_split(
            dataset=folder_dataset,
            lengths=[
                int(len(folder_dataset) * self.train_ratio),
                len(folder_dataset) - int(len(folder_dataset) * self.train_ratio),
            ],
        )
        print(len(train_data))

        # get the DataLoaders
        self.train_dataloader = DataLoader(dataset = train_data,
            batch_size = self.batch_size,
            shuffle = True, 
            num_workers = self.num_workers,
            collate_fn = datasets.PklDataset.collate_fn,
        )
        self.validation_dataloader = DataLoader(dataset = validation_data,
            batch_size = self.batch_size,
            shuffle = True, 
            num_workers = self.num_workers,
            collate_fn = datasets.PklDataset.collate_fn,
        )

        # iterate through the samples
        for batch_sample in self.train_dataloader:
            minibatch = datasets.MinibatchDataset(
                data = batch_sample,
            )
            minibatch_dataloader = DataLoader(
                dataset = minibatch,
                batch_size = self.minibatch_size,
                shuffle = True,
                num_workers = self.num_workers,
            )
            for inputs, targets in minibatch_dataloader:
                print(inputs.shape, targets.shape)


                # forward pass
                predictions = self.model(inputs)
                
                # loss 
                print(predictions.shape, targets.shape)
                loss = self.criterion(predictions, targets)
                
                
                br
            
            

    def validate(self):
        """
        validate the network predictions
        """
        return


test_model = mgc_classifier.MgcNet

root_path = Path("C:\machine_learning\MGC_classifier\data")
train_val_object = train_and_validate_static(
    model = mgc_classifier.MgcNet, 
    datapath = root_path,
    criterion = nn.BCELoss(reduction = 'none'),
    optimizer = torch.optim.Adam,
    lr = 1e-3,
    batch_size = 3, # just how much data can be loaded into memory at one time
    minibatch_size = 15, # what actually controls the batching size for training
    num_workers = 4
)

train_val_object.train()


1228
982
torch.Size([15, 1, 128, 128]) torch.Size([15, 40])
torch.Size([15, 40]) torch.Size([15, 40])


NameError: name 'br' is not defined