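# Calculate the loss as a metric for the final model

For each sampled GStels / XGStels configuration, the model samples a set of Qsc input parameters conditioned on the configuration's normalized properties. Qsc is run on those parameters, and the Huber loss between Qsc's normalized outputs and the requested properties is averaged over every sample for which Qsc succeeds and the loss is finite.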

In [2]:
import pandas as pd 
import numpy as np
import torch 
from qsc import Qsc
from tqdm import tqdm
import numpy as np
import pandas as pd
import sys
from tabulate import tabulate
import jax

from MDNFullCovariance import MDNFullCovariance
from utils import sample_output, check_criteria, run_qsc, round_nfp

## Load Datasets

### GStels

In [3]:
fname_gstels = '../data/GStels/GStels.csv'
df_gstels = pd.read_csv(fname_gstels)

# Later cells slice columns 0:10 (labels) and 10: (features) for a 10-in / 10-out model.
assert df_gstels.shape[1] == 20, "expected 10 label columns followed by 10 feature columns"

# Subsample 10000 rows (fixed seed) to keep the Qsc round-trip evaluation fast.
# Capped at the dataset size: DataFrame.sample raises ValueError when asked for
# more rows than exist (sampling without replacement).
n_samples_gstels = min(10000, len(df_gstels))
df_gstels = df_gstels.sample(n_samples_gstels, random_state=42)

print("Shape of the dataset: ", df_gstels.shape)

Shape of the dataset:  (10000, 20)


### XGStels

In [4]:
# NOTE(review): the file name casing ('XGstels.csv') differs from the directory
# ('XGStels'); this matters on case-sensitive filesystems — confirm the real name.
fname_xgstels = '../data/XGStels/XGstels.csv'
df_xgstels = pd.read_csv(fname_xgstels)

# Later cells slice columns 0:10 (labels) and 10: (features) for a 10-in / 10-out model.
assert df_xgstels.shape[1] == 20, "expected 10 label columns followed by 10 feature columns"

# Subsample 10000 rows (fixed seed) to keep the Qsc round-trip evaluation fast.
# Capped at the dataset size: DataFrame.sample raises ValueError when asked for
# more rows than exist (sampling without replacement).
n_samples_xgstels = min(10000, len(df_xgstels))
df_xgstels = df_xgstels.sample(n_samples_xgstels, random_state=42)

print("Shape of the dataset: ", df_xgstels.shape)

Shape of the dataset:  (10000, 20)


## Load Model

In [5]:
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create model; these hyper-parameters must match the checkpoint loaded below.
model_params = {
    'input_dim': 10,
    'output_dim': 10,
    'num_gaussians': 62
}
model = MDNFullCovariance(**model_params).to(device)

# Candidate checkpoints and their matching normalization statistics. The single
# uncommented entry of each list is the one evaluated; earlier runs are kept for reference.
models = [
    # "../mdn_torch/models/MDNFullCovariance/2024_03_28_11_53_42.pth",
    # "../mdn_torch/models/MDNFullCovariance/2024_03_30_02_40_44.pth",
    # "../mdn_torch/models/MDNFullCovariance/2024_04_02_15_47_52.pth",
    # "../mdn_torch/models/MDNFullCovariance/2024_04_03_10_03_21.pth",
    # "../mdn_torch/models/MDNFullCovariance/2024_04_04_01_37_38.pth",
    # "../mdn_torch/models/MDNFullCovariance/2024_04_04_13_41_19.pth",
    "./models/MDNFullCovariance/2024_04_05_01_01_18.pth",
]
mean_stds = [
    # "../mdn_torch/models/mean_std.pth",
    # "../mdn_torch/models/mean_std_2.pth",
    # "../mdn_torch/models/mean_std_3.pth",
    # "../mdn_torch/models/mean_std_4.pth",
    # "../mdn_torch/models/mean_std_5.pth",
    # "../mdn_torch/models/mean_std_combined.pth",
    "./models/mean_std_5.pth",
]

model_path = models[0]
mean_std_path = mean_stds[0]

# Load model weights.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
# This notebook only runs inference. Turn off training-mode behaviour (dropout,
# batch-norm statistic updates). This is a no-op if the architecture uses neither.
model.eval()

# Load the normalization statistics (keys 'mean'/'std' for features,
# 'mean_labels'/'std_labels' for labels, as used by the cells below).
mean_std = torch.load(mean_std_path, map_location=torch.device(device))

## Loss for GStels

In [6]:
# Features 
features = df_gstels.values.astype(np.float32)[:, 10:]

# Labels
labels = df_gstels.values.astype(np.float32)[:, 0:10]

In [7]:
# Convert to tensor
features = torch.tensor(features).to(device)
labels = torch.tensor(labels).to(device)

In [8]:
# Normalize features
features_norm = (features - mean_std['mean']) / mean_std['std']

# Normalize labels
labels_norm = (labels - mean_std['mean_labels']) / mean_std['std_labels']

In [9]:
# Round-trip evaluation on GStels.
# For each sample, the model draws one set of Qsc input parameters conditioned on
# the sample's normalized features. Qsc is run on those parameters, and its
# normalized outputs are compared with the features via an element-wise Huber loss.
# Samples are skipped when Qsc (or the normalization/loss) raises, or when the total
# loss is NaN/inf. `counted` is the number of samples that enter the averages.
loss_fn = torch.nn.functional.huber_loss

# Names of the 10 normalized outputs, in feature-column order (also the postfix keys).
output_names = [
    "axis_length", "iota", "max_elongation", "min_L_grad_B", "min_R0",
    "r_singularity", "L_grad_grad_B", "B20_variation", "beta", "DMerc_times_r2",
]

# Sampling from the mixture is stochastic, so seed it to make the numbers reproducible.
# NOTE(review): assumes getMixturesSample draws from torch's RNG — confirm.
torch.manual_seed(42)

glob_loss_total = 0.0
glob_loss_per_output = np.zeros(len(output_names))
counted = 0
skipped = 0

# Iterating the tqdm object advances it; no manual progress_bar.update() is needed.
progress_bar = tqdm(range(len(df_gstels)), desc="Calculating loss")
for i in progress_bar:
    model_input = features_norm[i]
    with torch.no_grad():
        # Sample Qsc input parameters and undo the label normalization
        model_outputs = model.getMixturesSample(model_input.unsqueeze(0), device)
        model_outputs = model_outputs * mean_std["std_labels"].to(device) + mean_std["mean_labels"].to(device)
        model_outputs = round_nfp(model_outputs.cpu().numpy()[0])

        try:
            qsc_output = run_qsc(model_outputs)
            # Normalize the Qsc outputs with the same statistics as the model inputs
            qsc_output_normalized = (torch.tensor(qsc_output).to(device) - mean_std["mean"].to(device)) / mean_std["std"].to(device)
            # reduction="none" keeps one Huber term per output; their mean equals the default "mean" reduction
            per_output_loss = loss_fn(model_input, qsc_output_normalized, reduction="none")
            loss_total = per_output_loss.mean().item()
        except Exception:
            # run_qsc (or the steps after it) may raise for some sampled parameter sets
            skipped += 1
            continue

    # Only finite losses enter the averages
    if not np.isfinite(loss_total):
        skipped += 1
        continue
    glob_loss_total += loss_total
    glob_loss_per_output += per_output_loss.cpu().numpy()
    counted += 1

    # counted >= 1 here, so the running means are well defined
    mean_per_output = (glob_loss_per_output / counted).tolist()
    progress_bar.set_postfix(
        {"total_loss": glob_loss_total / counted, **dict(zip(output_names, mean_per_output)), "skipped": skipped}
    )

# Per-output loss sums under their individual names (divide by `counted` for the means)
(
    glob_loss_axis_length, glob_loss_iota, glob_loss_max_elongation,
    glob_loss_min_L_grad_B, glob_loss_min_R0, glob_loss_r_singularity,
    glob_loss_L_grad_grad_B, glob_loss_B20_variation, glob_loss_beta,
    glob_loss_DMerc_times_r2,
) = glob_loss_per_output.tolist()
assert counted + skipped == len(df_gstels)

Calculating loss: 100%|██████████| 10000/10000 [07:57<00:00, 20.95it/s, total_loss=0.274, axis_length=0.0525, iota=0.0255, max_elongation=0.00359, min_L_grad_B=0.299, min_R0=0.359, r_singularity=0.918, L_grad_grad_B=0.436, B20_variation=4.09e-5, beta=0.65, DMerc_times_r2=1.46e-8]   


## Loss for XGStels

In [10]:
# Features 
features = df_xgstels.values.astype(np.float32)[:, 10:]

# Labels
labels = df_xgstels.values.astype(np.float32)[:, 0:10]

In [11]:
# Convert to tensor
features = torch.tensor(features).to(device)
labels = torch.tensor(labels).to(device)

In [12]:
# Normalize features
features_norm = (features - mean_std['mean']) / mean_std['std']

# Normalize labels
labels_norm = (labels - mean_std['mean_labels']) / mean_std['std_labels']

In [13]:
# Round-trip evaluation on XGStels.
# For each sample, the model draws one set of Qsc input parameters conditioned on
# the sample's normalized features. Qsc is run on those parameters, and its
# normalized outputs are compared with the features via an element-wise Huber loss.
# Samples are skipped when Qsc (or the normalization/loss) raises, or when the total
# loss is NaN/inf. `counted` is the number of samples that enter the averages.
loss_fn = torch.nn.functional.huber_loss

# Names of the 10 normalized outputs, in feature-column order (also the postfix keys).
output_names = [
    "axis_length", "iota", "max_elongation", "min_L_grad_B", "min_R0",
    "r_singularity", "L_grad_grad_B", "B20_variation", "beta", "DMerc_times_r2",
]

# Sampling from the mixture is stochastic, so seed it to make the numbers reproducible.
# NOTE(review): assumes getMixturesSample draws from torch's RNG — confirm.
torch.manual_seed(42)

glob_loss_total = 0.0
glob_loss_per_output = np.zeros(len(output_names))
counted = 0
skipped = 0

# Iterate over the XGStels rows (not GStels). Iterating the tqdm object advances it,
# so no manual progress_bar.update() is needed.
progress_bar = tqdm(range(len(df_xgstels)), desc="Calculating loss")
for i in progress_bar:
    model_input = features_norm[i]
    with torch.no_grad():
        # Sample Qsc input parameters and undo the label normalization
        model_outputs = model.getMixturesSample(model_input.unsqueeze(0), device)
        model_outputs = model_outputs * mean_std["std_labels"].to(device) + mean_std["mean_labels"].to(device)
        model_outputs = round_nfp(model_outputs.cpu().numpy()[0])

        try:
            qsc_output = run_qsc(model_outputs)
            # Normalize the Qsc outputs with the same statistics as the model inputs
            qsc_output_normalized = (torch.tensor(qsc_output).to(device) - mean_std["mean"].to(device)) / mean_std["std"].to(device)
            # reduction="none" keeps one Huber term per output; their mean equals the default "mean" reduction
            per_output_loss = loss_fn(model_input, qsc_output_normalized, reduction="none")
            loss_total = per_output_loss.mean().item()
        except Exception:
            # run_qsc (or the steps after it) may raise for some sampled parameter sets
            skipped += 1
            continue

    # Only finite losses enter the averages
    if not np.isfinite(loss_total):
        skipped += 1
        continue
    glob_loss_total += loss_total
    glob_loss_per_output += per_output_loss.cpu().numpy()
    counted += 1

    # counted >= 1 here, so the running means are well defined
    mean_per_output = (glob_loss_per_output / counted).tolist()
    progress_bar.set_postfix(
        {"total_loss": glob_loss_total / counted, **dict(zip(output_names, mean_per_output)), "skipped": skipped}
    )

# Per-output loss sums under their individual names (divide by `counted` for the means)
(
    glob_loss_axis_length, glob_loss_iota, glob_loss_max_elongation,
    glob_loss_min_L_grad_B, glob_loss_min_R0, glob_loss_r_singularity,
    glob_loss_L_grad_grad_B, glob_loss_B20_variation, glob_loss_beta,
    glob_loss_DMerc_times_r2,
) = glob_loss_per_output.tolist()
assert counted + skipped == len(df_xgstels)

Calculating loss: 100%|██████████| 10000/10000 [07:33<00:00, 22.05it/s, total_loss=0.208, axis_length=0.0398, iota=0.0274, max_elongation=0.000792, min_L_grad_B=0.342, min_R0=0.366, r_singularity=0.597, L_grad_grad_B=0.442, B20_variation=0.00556, beta=0.264, DMerc_times_r2=0.000128]
