In [None]:
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

import argparse
import os
import math 
import skimage
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import time
import pickle

from datetime import datetime
from pathlib import Path

# from data_classes.py_files.custom_datasets import *
# from data_classes.py_files.data_classes import *
from py_files.new_dataset import *

from py_files.cnn_model import *
from py_files.pigan_model import *

from py_files.pi_gan_functions import *

# %matplotlib qt

In [None]:
def show_runs(requirements, print_vars, last=None):
    """Print the arguments and plot the loss curves of saved training runs.

    Every entry of ``saved_runs`` (except ``old`` and ``saved``) is treated as a
    run name and loaded with ``load_args``. For each run whose arguments match
    ``requirements``, the selected arguments are printed, the lowest train/eval
    mask, PCMRA and Dice losses are reported, and the three loss histories are
    drawn side by side.

    Args:
        requirements: dict mapping argument names to required values. A run is
            shown only if every key is present in its saved arguments with the
            given value. An empty dict matches every run.
        print_vars: iterable of argument names to print per run. If falsy
            (``None`` or empty), every saved argument is printed.
        last: if truthy, only the last ``last`` runs (in sorted name order) are
            considered.
    """
    # (file stem, panel title, name used in the printed summary), in panel order.
    loss_specs = [
        ("mask_loss", "Mask Loss", "mask"),
        ("pcmra_loss", "PCMRA Loss", "pcmra"),
        ("dice_loss", "Dice Loss", "dice"),
    ]

    runs = sorted(os.listdir(path='saved_runs'))
    runs = [run for run in runs if run not in ("old", "saved")]
    if last:
        runs = runs[-last:]

    # Sentinel so a missing key never compares equal to a required value
    # (not even to a required None).
    missing = object()

    for run in runs:
        args = load_args(run, print_changed=False)
        run_args = vars(args)

        # Runs saved before an argument was introduced lack that key; treat
        # them as non-matching instead of raising KeyError and aborting.
        if not all(run_args.get(key, missing) == value
                   for key, value in requirements.items()):
            continue

        print(f"\n{run}\n")

        if print_vars:
            for key in print_vars:
                if key in run_args:
                    print(f"{key}: {run_args[key]}")
                else:
                    print(f"{key} not in ARGS")
        else:
            for key, item in run_args.items():
                print(f"{key}: {item}")

        print()

        # A missing .npy file yields None and leaves its panel empty.
        histories = [_load_loss_history(run, stem) for stem, _, _ in loss_specs]

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        fig.patch.set_facecolor('white')

        for ax, losses, (_, title, name) in zip(axes, histories, loss_specs):
            _plot_loss_history(ax, losses, title, name)

        plt.show()

        print("\n\n\n")


def _load_loss_history(run, stem):
    """Return ``saved_runs/<run>/<stem>.npy`` as an array, or None if it does not exist."""
    path = os.path.join('saved_runs', run, f'{stem}.npy')
    return np.load(path) if os.path.exists(path) else None


def _plot_loss_history(ax, losses, title, name):
    """Draw one loss history on ``ax`` and print its lowest train/eval values.

    Column 0 of ``losses`` is the epoch, column 1 the train loss and column 3
    the eval loss. The first row is skipped in the plot (NOTE(review): presumably
    an initial/epoch-0 entry that would dominate the y-scale — confirm) but is
    still included when searching for the minimum. Does nothing if ``losses``
    is None or has no rows.
    """
    if losses is None or len(losses) == 0:
        return

    ax.plot(losses[1:, 0], losses[1:, 1], label='Train loss')
    ax.plot(losses[1:, 0], losses[1:, 3], label='Eval loss')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    # Without this call the 'Train loss' / 'Eval loss' labels are never shown.
    ax.legend()

    i_train, i_eval = losses[:, 1].argmin(), losses[:, 3].argmin()
    print(f"Lowest train {name} loss at epoch {int(losses[i_train, 0])}:\t{round(losses[i_train, 1], 6)}")
    print(f"Lowest eval  {name} loss at epoch {int(losses[i_eval, 0])}:\t{round(losses[i_eval, 3], 6)}")
    print()


In [None]:
# Filter: only runs whose saved arguments equal every key/value pair here are
# shown; an empty dict matches all runs. Example: {"cnn_setup": -6}
requirements = dict()

# Saved arguments to print for each matching run.
# Set to None to print every saved argument instead.
print_vars = [
    "cnn_setup",
    "mapping_setup",
    "dim_hidden",
    "siren_hidden_layers",
    "dataset",
    "pretrained",
    "pretrained_lr_reset",
    "min_lr",
    "pcmra_first_omega_0",
    "first_omega_0",
    "translate_max_pixels",
    "mask_train_cnn",
    "siren_wd",
    "patience",
]

show_runs(requirements, print_vars, last=None)