# __WanDB Experiment__
This file connects the _models.py_ and _trainer.py_ files and manages experiments created in wanDB. It also contains the dataset representation as a Dataset subclass (Lizard_dataset). Experiments are defined in the file NN-z2 (main file).


In [2]:
import wandb
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.transforms as T
import matplotlib.pyplot as plt
import gc
import os.path

In [3]:
import net_config as cfg
#from models import *
#from trainer import *

## wanDB run class

This class executes training epochs by calling trainer functions. It also logs metrics and decides when the model params are saved (locally).
This class contains: 
- Current wanDB run 
- Trainer
- Save interval (every n-th epoch)


In [4]:
class wanDB_run:
    """Run training epochs through a trainer and log the metrics to wandb.

    Attributes:
        run: The wandb run returned by ``wandb.init``.
        trainer: ``Trainer`` instance wrapping the model.
        save_interval: Save the model every n-th epoch; ``None`` disables
            both saving and loading of model parameters.
        datasets_loaded: Flag set by ``load_datasets`` and checked before
            training starts.
        batch_count: Running number of batches logged across all epochs.
        current_epoch: Number of the last started epoch.

    NOTE(review): ``Trainer`` must be in scope when this class is
    instantiated; it presumably comes from the ``trainer`` module -- verify
    that it is imported.
    """

    def __init__(self, run_name, run_id, model: nn.Module, save_interval=None):
        """Log in to wandb, start a run and prepare the trainer.

        Args:
            run_name: Display name of the wandb run.
            run_id: Identifier of the wandb run.
            model: Network handed over to ``Trainer``.
            save_interval: Save the model every n-th epoch, or ``None`` to
                never save or load model parameters.
        """
        wandb.login()

        # Close any run that is still active before a new one is started.
        wandb.finish()

        # Build the config before the run exists, so that a broken
        # configuration fails without leaving an open run behind.
        # Presumably a plain dict of hyperparameters -- TODO confirm.
        run_config = cfg.config_to_dict(cfg.config_Unet)

        # The config is passed to init(); assigning to the module attribute
        # `wandb.config` would only rebind that name and the values would
        # not be stored with the run.
        self.run = wandb.init(
            entity=cfg.project_entity,
            project=cfg.project_name,
            name=run_name,
            id=run_id,
            config=run_config,
        )

        self.trainer = Trainer(model)
        self.save_interval = save_interval
        self.datasets_loaded = False
        self.batch_count = 0

        # Continue from the stored model when saving is enabled and a model
        # file already exists; load_model() supplies the epoch to resume at.
        if (self.save_interval is not None) and os.path.isfile(cfg.model_path):
            self.current_epoch = self.trainer.load_model()
        else:
            self.current_epoch = 0

    def load_datasets(self, train_pathX, train_pathY, val_pathX, val_pathY, test_pathX, test_pathY):
        """Mark the datasets as loaded.

        NOTE(review): the six path arguments are currently unused; the
        call that would hand datasets to the trainer is commented out, so
        this method only sets the flag.
        """
        #self.trainer.load_dataset(trainData, valData, testData)
        self.datasets_loaded = True

    def _metric_names(self):
        """Map each wandb log key to the metric name used by the trainer stats."""
        return {
            "loss_train": cfg.metric_name_Tloss,
            "loss_val": cfg.metric_name_Vloss,
            "accuracy": cfg.metric_name_acc,
            "iou": cfg.metric_name_iou,
            "dice": cfg.metric_name_dice,
        }

    def execute_training(self, epoch_count, log_batch=False):
        """Train and evaluate the model for ``epoch_count`` epochs.

        Args:
            epoch_count: Number of epochs to run.
            log_batch: When true, log the metrics of every batch against a
                running "batch" counter; otherwise log the per-epoch
                averages against "epoch".

        Raises:
            AssertionError: If ``load_datasets`` has not been called.
        """
        # Explicit raise instead of `assert`, so the check is kept under -O.
        if not self.datasets_loaded:
            raise AssertionError("Datasets are NOT loaded")

        for _ in range(epoch_count):
            self.current_epoch += 1
            print(f"--Starting epoch {self.current_epoch}--")

            # Train model
            self.trainer.train_model()
            # Evaluate model
            self.trainer.evaluate_model()

            stats = self.trainer.stats
            names = self._metric_names()

            # All metrics of one batch/epoch go into a single log() call:
            # every call advances the wandb step, so separate calls would
            # spread one measurement over several steps.
            if log_batch:
                for i in range(stats.batch_count()):
                    self.batch_count += 1
                    m = stats.batch_metrics(i)

                    payload = {key: m.get(name) for key, name in names.items()}
                    payload["batch"] = self.batch_count
                    self.run.log(payload)
            else:
                # Get metrics average
                payload = {key: stats.metric_average(name) for key, name in names.items()}
                payload["epoch"] = self.current_epoch
                # Save metrics to wandb
                self.run.log(payload)

            stats.clear()
            gc.collect()

            # Save the model every `save_interval`-th epoch
            if (self.save_interval is not None) and (self.current_epoch % self.save_interval == 0):
                self.trainer.save_model(self.current_epoch)

            print(f"--Ending epoch {self.current_epoch}--")

    def stop_run(self):
        """Finish the wandb run and drop the trainer.

        The ``trainer`` attribute is deleted, so the instance cannot be
        used for further training afterwards.
        """
        self.run.finish()
        del self.trainer
        self.datasets_loaded = False

        gc.collect()


In [5]:
# Scratch cell: build the U-Net config dict, then push a random batch through `net`.
# NOTE(review): the recorded output of this cell is an AttributeError saying
# net_config has no attribute 'config_Unet', so the line below failed when run.
d = cfg.config_to_dict(cfg.config_Unet)


# Random input of shape (14, 3, 500, 500); presumably 14 RGB images of 500x500 -- TODO confirm.
t = torch.rand(14, 3, 500, 500)
print(t.shape)

# NOTE(review): `net` is not defined in the visible cells -- presumably a model built elsewhere.
t = net(t)
print(t.shape)

# Index of the largest value along dim 1 (presumably the class dimension -- TODO confirm).
classes = torch.argmax(t, dim = 1)
print(classes.shape)

AttributeError: module 'net_config' has no attribute 'config_Unet'