In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import tqdm

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import matplotlib.pyplot as plt
import ml_collections
from ml_collections import config_dict

from models import noise_schedules
from models import transformer, scores, vdm
from models import diffusion_utils

%matplotlib inline

In [2]:
def get_config():
    """Assemble the nested ``ConfigDict`` describing this training run.

    Returns:
        A ``config_dict.ConfigDict`` whose top-level entries are ``workdir``,
        ``data``, ``vdm`` (which also holds the ``score_model`` and
        ``noise_schedule`` sub-configs), ``training``, ``optimizer`` and
        ``scheduler``.
    """

    def _section(**fields):
        # Fill a fresh ConfigDict one attribute at a time, in keyword order.
        section = config_dict.ConfigDict()
        for key, value in fields.items():
            setattr(section, key, value)
        return section

    cond_names = [
        "halo_mvir", "halo_mstar", "center_subhalo_mvir", "center_subhalo_mstar",
        "center_subhalo_vpeculiar", "center_subhalo_vmax_tilde", "log_num_subhalos",
        "inv_wdm_mass", "log_sn1", "log_sn2", "log_agn1"
    ]
    n_cond = len(cond_names)

    # VDM args
    context_dim = 16
    use_context_embedding = True
    model_cfg = _section(
        d_in=9,
        d_cond=n_cond,
        d_context_embedding=context_dim,
        timesteps=0,
        antithetic_time_sampling=True,
        use_encdec=False,
        embed_context=use_context_embedding,
    )
    model_cfg.score_model = _section(
        name='transformer',
        d_t_embedding=16,
        d_model=128,
        d_mlp=256,
        # the score model is given the embedded context width when the
        # context is embedded, and the raw conditioning width otherwise
        d_cond=context_dim if use_context_embedding else n_cond,
        n_layers=6,
        n_heads=4,
    )
    model_cfg.noise_schedule = _section(
        name="learned_linear",
        gamma_min=-16.0,
        gamma_max=10.0,
    )

    # training and loss args
    max_steps = 50_000
    train_cfg = _section(
        batch_size=128,
        max_steps=max_steps,
        noise_scale=1e-3,
        beta=1.0,
        rotation_augmentation=True,
        n_pos_dim=3,
        n_vel_dim=3,
        add_mass_recon_loss=True,
        i_mass_start=6,
        i_mass_stop=8,
    )

    # optimizer and scheduler args
    learning_rate = 5e-4
    optim_cfg = _section(
        name="AdamW",
        lr=learning_rate,
        betas=[0.9, 0.999],
        weight_decay=0.01,
        grad_clip=0.5,
    )
    sched_cfg = _section(
        name="WarmUpCosineDecayLR",
        init_value=0.0,
        peak_value=learning_rate,   # the schedule peaks at the optimizer lr
        warmup_steps=5_000,
        decay_steps=max_steps,      # decay spans the full training length
    )

    return _section(
        workdir='trained_models/test2',
        data=_section(conditioning_parameters=cond_names),
        vdm=model_cfg,
        training=train_cfg,
        optimizer=optim_cfg,
        scheduler=sched_cfg,
    )

# build the run configuration once; the preprocessing and training cells read
# their settings (conditioning columns, batch size, model args, ...) from it
config = get_config()

In [3]:
# read the dataset
# The dataset directory can be overridden through the WDM_DATA_DIR environment
# variable so the notebook is not tied to one filesystem; when the variable is
# unset the original cluster location is used.
datadir = os.environ.get(
    'WDM_DATA_DIR',
    '/mnt/ceph/users/tnguyen/dark_camels/point-cloud-diffusion-datasets/processed_datasets/final-WDM-datasets',
)

# each split is an .npz archive (read in the next cell via its 'features' and
# 'mask' entries) plus a CSV table holding the conditioning columns
train_data = np.load(os.path.join(datadir, 'train_galprop.npz'))
train_cond_table = pd.read_csv(os.path.join(datadir, 'train_galprop_cond.csv'))
val_data = np.load(os.path.join(datadir, 'val_galprop.npz'))
val_cond_table = pd.read_csv(os.path.join(datadir, 'val_galprop_cond.csv'))

In [4]:
EPS = 1e-6


def _standardize(values, mean, std, eps=EPS):
    """Return ``(values - mean) / (std + eps)``.

    ``eps`` only guards the division against a zero ``std``. It is deliberately
    not added to the numerator: doing so would offset every normalized value by
    ``eps / (std + eps)``, which the plain mean/std stored in ``norm_dict`` do
    not describe.
    """
    return (values - mean) / (std + eps)


# preprocess the data
train_x = train_data['features']
train_mask = train_data['mask']
train_cond = train_cond_table[config.data.conditioning_parameters].values
val_x = val_data['features']
val_mask = val_data['mask']
val_cond = val_cond_table[config.data.conditioning_parameters].values

# normalize the data and convert to torch tensors
# Feature statistics are reduced over the first two axes and use only the
# entries where the training mask is true; validation data is normalized with
# the training statistics.
mask_bool = train_mask.astype(bool)
x_mean = np.mean(train_x, axis=(0, 1), where=mask_bool[..., None])
x_std = np.std(train_x, axis=(0, 1), where=mask_bool[..., None])
cond_mean = np.mean(train_cond, axis=0)
cond_std = np.std(train_cond, axis=0)
norm_dict = {
    'x_mean': x_mean,
    'x_std': x_std,
    'cond_mean': cond_mean,
    'cond_std': cond_std
}

train_x = _standardize(train_x, x_mean, x_std)
train_cond = _standardize(train_cond, cond_mean, cond_std)
val_x = _standardize(val_x, x_mean, x_std)
val_cond = _standardize(val_cond, cond_mean, cond_std)

train_x = torch.tensor(train_x, dtype=torch.float32)
train_cond = torch.tensor(train_cond, dtype=torch.float32)
val_x = torch.tensor(val_x, dtype=torch.float32)
val_cond = torch.tensor(val_cond, dtype=torch.float32)
# invert mask due to torch convention vs jax
train_mask = ~torch.tensor(train_mask, dtype=torch.bool)
val_mask = ~torch.tensor(val_mask, dtype=torch.bool)

# Create dataloaders
# Each batch is a (features, conditioning, inverted mask) triple; only the
# training loader shuffles.
train_dset = TensorDataset(train_x, train_cond, train_mask)
val_dset = TensorDataset(val_x, val_cond, val_mask)
train_loader = DataLoader(
    train_dset, batch_size=config.training.batch_size, shuffle=True)
val_loader = DataLoader(
    val_dset, batch_size=config.training.batch_size, shuffle=False)

In [5]:
# Seed the Python, NumPy and torch RNGs before any weights are created so that
# initialisation, batch shuffling and the training run are repeatable. The seed
# is read from `config.seed` when present and falls back to 0 otherwise.
pl.seed_everything(config.get('seed', 0), workers=True)

# build the diffusion model from the run configuration and the normalization
# statistics computed during preprocessing
model = vdm.VariationalDiffusionModel(
    d_in=config.vdm.d_in,
    d_cond=config.vdm.d_cond,
    d_context_embedding=config.vdm.d_context_embedding,
    embed_context=config.vdm.embed_context,
    score_model_args=config.vdm.score_model,
    noise_schedule_args=config.vdm.noise_schedule,
    timesteps=config.vdm.timesteps,
    antithetic_time_sampling=config.vdm.antithetic_time_sampling,
    use_encdec=config.vdm.use_encdec,
    training_args=config.training,
    optimizer_args=config.optimizer,
    scheduler_args=config.scheduler,
    norm_dict=norm_dict
)

callbacks = [
    # stop once `val_loss` has gone 1000 validation checks without improving
    pl.callbacks.EarlyStopping(
        monitor='val_loss', patience=1000, mode='min', verbose=True),
    # keep the 10 checkpoints with the lowest `val_loss` (full training state,
    # not weights only)
    pl.callbacks.ModelCheckpoint(
        monitor='val_loss', save_top_k=10, mode='min', save_weights_only=False),
    # record the learning rate at step granularity
    pl.callbacks.LearningRateMonitor("step"),
]
train_logger = pl_loggers.TensorBoardLogger(config.workdir)

trainer = pl.Trainer(
    default_root_dir=config.workdir,
    max_steps=config.training.max_steps,
    # 'auto' still selects the GPU when one is available, but no longer errors
    # out on a machine without one
    accelerator='auto',
    callbacks=callbacks,
    logger=train_logger,
    enable_progress_bar=True,
    # evaluate under torch.no_grad() rather than torch.inference_mode()
    inference_mode=False,
    gradient_clip_val=config.optimizer.grad_clip,
)

# train the model
trainer.fit(model, train_loader, val_loader)

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/home/tnguyen/miniconda3/envs/geometric/lib/pyth ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/home/tnguyen/miniconda3/envs/geometric/lib/pyth ...
Missing logger folder: trained_models/test2/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | 

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/mnt/home/tnguyen/miniconda3/envs/geometric/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 7/7 [00:00<00:00, 18.62it/s, v_num=0]

Metric val_loss improved. New best score: 225.326


Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 17.12it/s, v_num=0]

Metric val_loss improved by 17.676 >= min_delta = 0.0. New best score: 207.650


Epoch 3: 100%|██████████| 7/7 [00:00<00:00, 18.32it/s, v_num=0]

Metric val_loss improved by 1.514 >= min_delta = 0.0. New best score: 206.136


Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 17.34it/s, v_num=0]

Metric val_loss improved by 2.299 >= min_delta = 0.0. New best score: 203.837


Epoch 5: 100%|██████████| 7/7 [00:00<00:00, 16.54it/s, v_num=0]

Metric val_loss improved by 6.450 >= min_delta = 0.0. New best score: 197.387


Epoch 6: 100%|██████████| 7/7 [00:00<00:00, 16.66it/s, v_num=0]

Metric val_loss improved by 2.892 >= min_delta = 0.0. New best score: 194.495


Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 16.43it/s, v_num=0]

Metric val_loss improved by 1.255 >= min_delta = 0.0. New best score: 193.240


Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 17.48it/s, v_num=0]

Metric val_loss improved by 4.927 >= min_delta = 0.0. New best score: 188.313


Epoch 14: 100%|██████████| 7/7 [00:00<00:00, 17.55it/s, v_num=0]

Metric val_loss improved by 2.576 >= min_delta = 0.0. New best score: 185.738


Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 17.63it/s, v_num=0]

Metric val_loss improved by 3.102 >= min_delta = 0.0. New best score: 182.635


Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 17.25it/s, v_num=0]

Metric val_loss improved by 4.612 >= min_delta = 0.0. New best score: 178.024


Epoch 31: 100%|██████████| 7/7 [00:00<00:00, 18.76it/s, v_num=0]

Metric val_loss improved by 6.008 >= min_delta = 0.0. New best score: 172.016


Epoch 32: 100%|██████████| 7/7 [00:00<00:00, 19.34it/s, v_num=0]

Metric val_loss improved by 0.638 >= min_delta = 0.0. New best score: 171.378


Epoch 38: 100%|██████████| 7/7 [00:00<00:00, 18.37it/s, v_num=0]

Metric val_loss improved by 1.157 >= min_delta = 0.0. New best score: 170.221


Epoch 39: 100%|██████████| 7/7 [00:00<00:00, 18.61it/s, v_num=0]

Metric val_loss improved by 7.339 >= min_delta = 0.0. New best score: 162.882


Epoch 54: 100%|██████████| 7/7 [00:00<00:00, 15.65it/s, v_num=0]

Metric val_loss improved by 2.648 >= min_delta = 0.0. New best score: 160.234


Epoch 57: 100%|██████████| 7/7 [00:00<00:00, 16.42it/s, v_num=0]

Metric val_loss improved by 2.439 >= min_delta = 0.0. New best score: 157.794


Epoch 66: 100%|██████████| 7/7 [00:00<00:00, 16.99it/s, v_num=0]

Metric val_loss improved by 8.637 >= min_delta = 0.0. New best score: 149.157


Epoch 77: 100%|██████████| 7/7 [00:00<00:00, 17.25it/s, v_num=0]

Metric val_loss improved by 0.480 >= min_delta = 0.0. New best score: 148.677


Epoch 84: 100%|██████████| 7/7 [00:00<00:00, 18.42it/s, v_num=0]

Metric val_loss improved by 3.036 >= min_delta = 0.0. New best score: 145.641


Epoch 91: 100%|██████████| 7/7 [00:00<00:00, 17.05it/s, v_num=0]

Metric val_loss improved by 5.275 >= min_delta = 0.0. New best score: 140.367


Epoch 93: 100%|██████████| 7/7 [00:00<00:00, 17.31it/s, v_num=0]

Metric val_loss improved by 2.492 >= min_delta = 0.0. New best score: 137.875


Epoch 94: 100%|██████████| 7/7 [00:00<00:00, 17.97it/s, v_num=0]

Metric val_loss improved by 0.107 >= min_delta = 0.0. New best score: 137.768


Epoch 95: 100%|██████████| 7/7 [00:00<00:00, 17.51it/s, v_num=0]

Metric val_loss improved by 1.031 >= min_delta = 0.0. New best score: 136.737


Epoch 107: 100%|██████████| 7/7 [00:00<00:00, 17.59it/s, v_num=0]

Metric val_loss improved by 4.112 >= min_delta = 0.0. New best score: 132.625


Epoch 117: 100%|██████████| 7/7 [00:00<00:00, 17.53it/s, v_num=0]

Metric val_loss improved by 1.965 >= min_delta = 0.0. New best score: 130.660


Epoch 120: 100%|██████████| 7/7 [00:00<00:00, 16.04it/s, v_num=0]

Metric val_loss improved by 3.266 >= min_delta = 0.0. New best score: 127.395


Epoch 121: 100%|██████████| 7/7 [00:00<00:00, 16.22it/s, v_num=0]

Metric val_loss improved by 0.415 >= min_delta = 0.0. New best score: 126.979


Epoch 128: 100%|██████████| 7/7 [00:00<00:00, 16.43it/s, v_num=0]

Metric val_loss improved by 2.192 >= min_delta = 0.0. New best score: 124.787


Epoch 129: 100%|██████████| 7/7 [00:00<00:00, 16.09it/s, v_num=0]

Metric val_loss improved by 2.152 >= min_delta = 0.0. New best score: 122.635


Epoch 137: 100%|██████████| 7/7 [00:00<00:00, 16.02it/s, v_num=0]

Metric val_loss improved by 2.879 >= min_delta = 0.0. New best score: 119.756


Epoch 145: 100%|██████████| 7/7 [00:00<00:00, 16.09it/s, v_num=0]

Metric val_loss improved by 5.734 >= min_delta = 0.0. New best score: 114.022


Epoch 156: 100%|██████████| 7/7 [00:00<00:00, 16.20it/s, v_num=0]

Metric val_loss improved by 5.904 >= min_delta = 0.0. New best score: 108.118


Epoch 173: 100%|██████████| 7/7 [00:00<00:00, 16.21it/s, v_num=0]

Metric val_loss improved by 0.461 >= min_delta = 0.0. New best score: 107.657


Epoch 174: 100%|██████████| 7/7 [00:00<00:00, 16.14it/s, v_num=0]

Metric val_loss improved by 2.321 >= min_delta = 0.0. New best score: 105.337


Epoch 178: 100%|██████████| 7/7 [00:00<00:00, 15.95it/s, v_num=0]

Metric val_loss improved by 2.966 >= min_delta = 0.0. New best score: 102.371


Epoch 182: 100%|██████████| 7/7 [00:00<00:00, 16.09it/s, v_num=0]

Metric val_loss improved by 0.145 >= min_delta = 0.0. New best score: 102.226


Epoch 191: 100%|██████████| 7/7 [00:00<00:00, 16.94it/s, v_num=0]

Metric val_loss improved by 1.418 >= min_delta = 0.0. New best score: 100.808


Epoch 194: 100%|██████████| 7/7 [00:00<00:00, 17.20it/s, v_num=0]

Metric val_loss improved by 3.776 >= min_delta = 0.0. New best score: 97.032


Epoch 205: 100%|██████████| 7/7 [00:00<00:00, 16.95it/s, v_num=0]

Metric val_loss improved by 8.557 >= min_delta = 0.0. New best score: 88.475


Epoch 228: 100%|██████████| 7/7 [00:00<00:00, 16.75it/s, v_num=0]

Metric val_loss improved by 0.846 >= min_delta = 0.0. New best score: 87.629


Epoch 243: 100%|██████████| 7/7 [00:00<00:00, 16.35it/s, v_num=0]

Metric val_loss improved by 2.271 >= min_delta = 0.0. New best score: 85.358


Epoch 245: 100%|██████████| 7/7 [00:00<00:00, 16.40it/s, v_num=0]

Metric val_loss improved by 2.649 >= min_delta = 0.0. New best score: 82.708


Epoch 253: 100%|██████████| 7/7 [00:00<00:00, 16.01it/s, v_num=0]

Metric val_loss improved by 0.953 >= min_delta = 0.0. New best score: 81.755


Epoch 255: 100%|██████████| 7/7 [00:00<00:00, 16.11it/s, v_num=0]

Metric val_loss improved by 3.971 >= min_delta = 0.0. New best score: 77.784


Epoch 259: 100%|██████████| 7/7 [00:00<00:00, 15.82it/s, v_num=0]

Metric val_loss improved by 5.896 >= min_delta = 0.0. New best score: 71.888


Epoch 281: 100%|██████████| 7/7 [00:00<00:00, 17.28it/s, v_num=0]

Metric val_loss improved by 0.423 >= min_delta = 0.0. New best score: 71.465


Epoch 298: 100%|██████████| 7/7 [00:00<00:00, 17.23it/s, v_num=0]

Metric val_loss improved by 3.907 >= min_delta = 0.0. New best score: 67.558


Epoch 312: 100%|██████████| 7/7 [00:00<00:00, 17.81it/s, v_num=0]

Metric val_loss improved by 5.531 >= min_delta = 0.0. New best score: 62.028


Epoch 329: 100%|██████████| 7/7 [00:00<00:00, 17.61it/s, v_num=0]

Metric val_loss improved by 0.051 >= min_delta = 0.0. New best score: 61.976


Epoch 330: 100%|██████████| 7/7 [00:00<00:00, 17.02it/s, v_num=0]

Metric val_loss improved by 3.346 >= min_delta = 0.0. New best score: 58.631


Epoch 356: 100%|██████████| 7/7 [00:00<00:00, 15.60it/s, v_num=0]

Metric val_loss improved by 0.266 >= min_delta = 0.0. New best score: 58.365


Epoch 361: 100%|██████████| 7/7 [00:00<00:00, 15.82it/s, v_num=0]

Metric val_loss improved by 0.724 >= min_delta = 0.0. New best score: 57.641


Epoch 363: 100%|██████████| 7/7 [00:00<00:00, 16.01it/s, v_num=0]

Metric val_loss improved by 1.800 >= min_delta = 0.0. New best score: 55.841


Epoch 365: 100%|██████████| 7/7 [00:00<00:00, 17.20it/s, v_num=0]

Metric val_loss improved by 3.355 >= min_delta = 0.0. New best score: 52.486


Epoch 393: 100%|██████████| 7/7 [00:00<00:00, 16.58it/s, v_num=0]

Metric val_loss improved by 1.117 >= min_delta = 0.0. New best score: 51.369


Epoch 400: 100%|██████████| 7/7 [00:00<00:00, 17.48it/s, v_num=0]

Metric val_loss improved by 2.555 >= min_delta = 0.0. New best score: 48.814


Epoch 405: 100%|██████████| 7/7 [00:00<00:00, 16.92it/s, v_num=0]

Metric val_loss improved by 1.872 >= min_delta = 0.0. New best score: 46.942


Epoch 408: 100%|██████████| 7/7 [00:00<00:00, 17.26it/s, v_num=0]

Metric val_loss improved by 0.318 >= min_delta = 0.0. New best score: 46.623


Epoch 417: 100%|██████████| 7/7 [00:00<00:00, 17.47it/s, v_num=0]

Metric val_loss improved by 0.033 >= min_delta = 0.0. New best score: 46.590


Epoch 420: 100%|██████████| 7/7 [00:00<00:00, 17.07it/s, v_num=0]

Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 46.559


Epoch 426: 100%|██████████| 7/7 [00:00<00:00, 12.15it/s, v_num=0]

Metric val_loss improved by 0.578 >= min_delta = 0.0. New best score: 45.981


Epoch 434: 100%|██████████| 7/7 [00:00<00:00, 17.40it/s, v_num=0]

Metric val_loss improved by 0.451 >= min_delta = 0.0. New best score: 45.529


Epoch 444: 100%|██████████| 7/7 [00:00<00:00, 17.23it/s, v_num=0]

Metric val_loss improved by 4.688 >= min_delta = 0.0. New best score: 40.841


Epoch 447: 100%|██████████| 7/7 [00:00<00:00, 16.94it/s, v_num=0]

Metric val_loss improved by 1.746 >= min_delta = 0.0. New best score: 39.096


Epoch 465: 100%|██████████| 7/7 [00:00<00:00, 17.23it/s, v_num=0]

Metric val_loss improved by 2.006 >= min_delta = 0.0. New best score: 37.090


Epoch 466: 100%|██████████| 7/7 [00:00<00:00, 16.93it/s, v_num=0]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 37.083


Epoch 496: 100%|██████████| 7/7 [00:00<00:00, 17.92it/s, v_num=0]

Metric val_loss improved by 1.489 >= min_delta = 0.0. New best score: 35.594


Epoch 510: 100%|██████████| 7/7 [00:00<00:00, 17.22it/s, v_num=0]

Metric val_loss improved by 1.545 >= min_delta = 0.0. New best score: 34.049


Epoch 522: 100%|██████████| 7/7 [00:00<00:00, 17.26it/s, v_num=0]

Metric val_loss improved by 4.119 >= min_delta = 0.0. New best score: 29.930


Epoch 561: 100%|██████████| 7/7 [00:00<00:00, 16.51it/s, v_num=0]

Metric val_loss improved by 1.278 >= min_delta = 0.0. New best score: 28.653


Epoch 564: 100%|██████████| 7/7 [00:00<00:00, 16.67it/s, v_num=0]

Metric val_loss improved by 3.125 >= min_delta = 0.0. New best score: 25.527


Epoch 580: 100%|██████████| 7/7 [00:00<00:00, 17.02it/s, v_num=0]

Metric val_loss improved by 3.318 >= min_delta = 0.0. New best score: 22.210


Epoch 605: 100%|██████████| 7/7 [00:00<00:00, 17.89it/s, v_num=0]

Metric val_loss improved by 3.662 >= min_delta = 0.0. New best score: 18.548


Epoch 660: 100%|██████████| 7/7 [00:00<00:00, 18.54it/s, v_num=0]

Metric val_loss improved by 0.444 >= min_delta = 0.0. New best score: 18.103


Epoch 679: 100%|██████████| 7/7 [00:00<00:00, 19.47it/s, v_num=0]

Metric val_loss improved by 0.173 >= min_delta = 0.0. New best score: 17.931


Epoch 700: 100%|██████████| 7/7 [00:00<00:00, 18.25it/s, v_num=0]

Metric val_loss improved by 1.040 >= min_delta = 0.0. New best score: 16.891


Epoch 711: 100%|██████████| 7/7 [00:00<00:00, 18.09it/s, v_num=0]

Metric val_loss improved by 1.930 >= min_delta = 0.0. New best score: 14.961


Epoch 743: 100%|██████████| 7/7 [00:00<00:00, 16.95it/s, v_num=0]

Metric val_loss improved by 0.595 >= min_delta = 0.0. New best score: 14.365


Epoch 784: 100%|██████████| 7/7 [00:00<00:00, 19.13it/s, v_num=0]

Metric val_loss improved by 1.115 >= min_delta = 0.0. New best score: 13.251


Epoch 794: 100%|██████████| 7/7 [00:00<00:00, 19.62it/s, v_num=0]

Metric val_loss improved by 0.334 >= min_delta = 0.0. New best score: 12.917


Epoch 854: 100%|██████████| 7/7 [00:00<00:00, 18.39it/s, v_num=0]

Metric val_loss improved by 1.561 >= min_delta = 0.0. New best score: 11.357


Epoch 859: 100%|██████████| 7/7 [00:00<00:00, 18.04it/s, v_num=0]

Metric val_loss improved by 2.976 >= min_delta = 0.0. New best score: 8.381


Epoch 891: 100%|██████████| 7/7 [00:00<00:00, 18.11it/s, v_num=0]

Metric val_loss improved by 3.110 >= min_delta = 0.0. New best score: 5.271


Epoch 1254: 100%|██████████| 7/7 [00:00<00:00, 18.65it/s, v_num=0]

Metric val_loss improved by 0.114 >= min_delta = 0.0. New best score: 5.157


Epoch 1255: 100%|██████████| 7/7 [00:00<00:00, 19.40it/s, v_num=0]

Metric val_loss improved by 1.036 >= min_delta = 0.0. New best score: 4.121


Epoch 1272: 100%|██████████| 7/7 [00:00<00:00, 18.26it/s, v_num=0]

Metric val_loss improved by 1.105 >= min_delta = 0.0. New best score: 3.017


Epoch 1276: 100%|██████████| 7/7 [00:00<00:00, 19.36it/s, v_num=0]

Metric val_loss improved by 0.317 >= min_delta = 0.0. New best score: 2.699


Epoch 1301: 100%|██████████| 7/7 [00:00<00:00, 18.17it/s, v_num=0]

Metric val_loss improved by 2.905 >= min_delta = 0.0. New best score: -0.206


Epoch 1837: 100%|██████████| 7/7 [00:00<00:00, 19.58it/s, v_num=0]

Metric val_loss improved by 0.432 >= min_delta = 0.0. New best score: -0.638


Epoch 2339: 100%|██████████| 7/7 [00:00<00:00, 32.60it/s, v_num=0]

Metric val_loss improved by 0.130 >= min_delta = 0.0. New best score: -0.768


Epoch 2544: 100%|██████████| 7/7 [00:00<00:00, 32.90it/s, v_num=0]

Metric val_loss improved by 0.188 >= min_delta = 0.0. New best score: -0.956


Epoch 2607: 100%|██████████| 7/7 [00:00<00:00, 33.57it/s, v_num=0]

Metric val_loss improved by 0.883 >= min_delta = 0.0. New best score: -1.839


Epoch 3607: 100%|██████████| 7/7 [00:00<00:00, 18.16it/s, v_num=0]

Monitored metric val_loss did not improve in the last 1000 records. Best score: -1.839. Signaling Trainer to stop.


Epoch 3607: 100%|██████████| 7/7 [00:00<00:00, 17.78it/s, v_num=0]
