# OTUS |  $ p p > t \bar{t} $ 

This notebook applies OTUS to our second test case: semi-leptonic $t \bar{t}$ decay.

Our physical latent-space is the $e^-$, $\bar{\nu}_e$, $b$, $\bar{b}$, $u$, $\bar{d}$ 4-momentum information produced by the program MadGraph.

Our data-space data is the $e^-$, $MET$, $jet1$, $jet2$, $jet3$, $jet4$ 4-momentum information produced by the program Delphes. The jets are ordered in descending $p_T$.

We arrange this information into 24-dimensional vectors:

- Latent space (z): [$p^{\mu}_{e-}$,$p^{\mu}_{\bar{\nu}_e}$,$p^{\mu}_{b}$,$p^{\mu}_{\bar{b}}$,$p^{\mu}_{u}$,$p^{\mu}_{\bar{d}}$]
- Data space (x): [$p^{\mu}_{e^-}$,$p^{\mu}_{MET}$,$p^{\mu}_{jet1}$,$p^{\mu}_{jet2}$,$p^{\mu}_{jet3}$,$p^{\mu}_{jet4}$]

where $p^{\mu}=[p_x, p_y, p_z, E]$ is the 4-momentum of the given particle.
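
As a minimal sketch of this layout, a flat 24-dimensional event can be viewed as six objects with four components each. The helper name `split_event` is hypothetical and is not part of the OTUS utilities; the values are placeholders, not physics data.

```python
import numpy as np

# Hypothetical helper for illustration only; not part of the OTUS utilities.
def split_event(vec):
    """View flat 24-dim events as (..., 6, 4) arrays of [p_x, p_y, p_z, E] rows."""
    vec = np.asarray(vec)
    assert vec.shape[-1] == 24, "expected 6 objects x 4 components"
    return vec.reshape(*vec.shape[:-1], 6, 4)

event = np.arange(24, dtype=np.float32)  # placeholder values, not physics data
particles = split_event(event)
first_particle = particles[0]            # [p_x, p_y, p_z, E] of the first object
energies = particles[:, 3]               # E of all six objects
```

The same reshape works on a whole batch of shape `(n_events, 24)`, which then becomes `(n_events, 6, 4)`.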

###### Additional Losses and Constraints:
We impose the following additional losses and constraints in this problem.

As in the $p p > Z > e^+ e^-$ test case, we explicitly enforce the Minkowski metric in the output of the networks. Namely, the networks predict the 3-momenta ($\vec{p}$) of the particles. Energy information is then restored using the Minkowski metric: $E^2 = |\vec{p}|^2 + m^2$.
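
The energy-restoration step can be sketched as follows. This is an illustrative NumPy version under the assumption that each particle slot has a fixed mass; the function name `restore_energy` is hypothetical, and the actual networks in `models` may implement this step differently.

```python
import numpy as np

# Hypothetical helper; the OTUS models may implement this step differently.
def restore_energy(p3, masses):
    """Append E = sqrt(|p|^2 + m^2) to 3-momenta.

    p3:     array of shape (n_events, n_particles, 3)
    masses: array of shape (n_particles,), one fixed mass per particle slot
    """
    p3 = np.asarray(p3, dtype=np.float64)
    masses = np.asarray(masses, dtype=np.float64)
    energy = np.sqrt(np.sum(p3 ** 2, axis=-1) + masses ** 2)
    return np.concatenate([p3, energy[..., None]], axis=-1)

# One event with two particles: a massless one and one with m = 5.
p3 = np.array([[[3.0, 4.0, 0.0], [0.0, 0.0, 12.0]]])
masses = np.array([0.0, 5.0])
p4 = restore_energy(p3, masses)  # rows are [p_x, p_y, p_z, E]
```

Because the energy is computed rather than predicted, every output 4-momentum satisfies the mass-shell relation by construction.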

We also explicitly enforce the lower $p_T$ threshold on jets, $p_T > 20$ GeV. Only decoder-generated samples that pass this threshold are used to calculate losses, which requires modifying the data-space loss term slightly. Additionally, to help stabilize training, we choose a ResNet architecture for both our encoder and decoder networks.
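
The jet threshold described above can be sketched as a boolean mask over events. This is an assumption-laden illustration, not the `threshold_check` function imported later in this notebook: the name `jet_pt_mask` is hypothetical, and the sketch assumes the threshold applies to all four jets in the data-space layout given earlier.

```python
import numpy as np

PT_MIN = 20.0  # GeV, the jet threshold quoted above

# Hypothetical stand-in; the notebook itself uses ppttbar_constraints.threshold_check.
def jet_pt_mask(x, pt_min=PT_MIN):
    """Boolean mask of events whose four jets all have p_T > pt_min.

    Assumes x has shape (n_events, 24) in the layout
    [lepton, MET, jet1, jet2, jet3, jet4] with [p_x, p_y, p_z, E] per object.
    """
    jets = np.asarray(x).reshape(-1, 6, 4)[:, 2:, :]  # keep the four jet slots
    pt = np.hypot(jets[..., 0], jets[..., 1])         # p_T = sqrt(p_x^2 + p_y^2)
    return np.all(pt > pt_min, axis=1)

x = np.zeros((2, 24))
x[:, 8::4] = 30.0   # p_x = 30 GeV for every jet in both events
x[1, 20] = 5.0      # jet4 of the second event drops below threshold
mask = jet_pt_mask(x)
passing_rate = mask.mean()  # fraction of events that pass
```

The mean of the mask plays the role of the passing rate that is printed as `prate` in the training logs.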

See the paper for more details: https://arxiv.org/abs/2101.08944.

# Load Required Libraries

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import numpy as np
import os

root_dir = '../../../'

#-- Add utilityFunctions/ to easily use utility .py files --#
import sys
sys.path.append(os.path.join(root_dir, "utilityFunctions/"))

#-- Determine if using GPU or CPU --#
os.environ['CUDA_VISIBLE_DEVICES'] = '2'  # Set to '-1' to disable GPU
from configs import device, data_dims

print('Using device:', device)

Using device: cuda


# Meta Parameters

In [2]:
data_directory    = os.path.join(root_dir, "data/")
dataset_name      = 'ppttbar'

#-- Set random seeds --#
seed = 2
torch.manual_seed(seed)
np.random.seed(seed)

#-- Raw or standardized inputs/outputs --#
# If True, model inputs/outputs should be in the "raw" (unstandardized) space
raw_io = True  

#-- Set data type --#
from configs import float_type
print('Using data type: ', float_type)

Using data type:  float32


# Load Data

In [3]:
from func_utils import get_dataset, standardize
from torch.utils.data import DataLoader

#-- Get training and validation dataset --#
dataset = get_dataset(dataset_name, data_dir=data_directory)
z_data, x_data = dataset['z_data'], dataset['x_data']
print("Data total shapes: ",z_data.shape, x_data.shape)

x_dim = int(x_data.shape[1])
z_dim = int(z_data.shape[1])

#-- Split into training and validation sets --#
train_size = 222761
val_size = 40000  # Validation set used to evaluate/tune models

x_train = x_data[:train_size, :]
x_val = x_data[train_size:train_size+val_size, :]

z_train = z_data[:train_size, :]
z_val = z_data[train_size:train_size+val_size, :]

#-- Convert data to proper type --#
x_train, x_val, z_train, z_val = list(map(lambda x: x.astype(float_type), [x_train, x_val, z_train, z_val]))

#-- Obtain mean and std information --#
# This is needed to standardize/unstandardize data
x_train_mean, x_train_std = np.mean(x_train, axis=0), np.std(x_train, axis=0)
z_train_mean, z_train_std = np.mean(z_train, axis=0), np.std(z_train, axis=0)

# If raw_io == False, then standardize the data with training set statistics
if not raw_io:
    x_train = (x_train - x_train_mean) / x_train_std
    x_val = (x_val - x_train_mean) / x_train_std
    z_train = (z_train - z_train_mean) / z_train_std
    z_val = (z_val - z_train_mean) / z_train_std

#-- Set evaluation parameters --#
eval_batch_size = 20000  # Always use high batch size on validation set to accurately assess performance
eval_loaders = DataLoader(dataset=x_val, batch_size=eval_batch_size, shuffle=True), \
               DataLoader(dataset=z_val, batch_size=eval_batch_size, shuffle=True)

print("z_train shape, x_train shape: ", z_train.shape, x_train.shape)
print("z_val   shape, x_val   shape: ", z_val.shape, x_val.shape)

Data total shapes:  (262761, 24) (262761, 24)
z_train shape, x_train shape:  (222761, 24) (222761, 24)
z_val   shape, x_val   shape:  (40000, 24) (40000, 24)


### Define target invariant masses (for both training and validation data)

Invariant mass relation: $m^2 = E^2 - |\vec{p}|^2$. For objects with ill-defined mass (MET and Jets) we fix $m=0$.
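
A short sketch of this relation, useful for checking that a 4-momentum is consistent with its target mass. The function name `invariant_mass` is hypothetical; the clipping at zero is an added safeguard against floating-point round-off and is not something the source specifies.

```python
import numpy as np

# Hypothetical helper for illustration; not part of the OTUS utilities.
def invariant_mass(p4):
    """m = sqrt(max(E^2 - |p|^2, 0)) for 4-momenta [p_x, p_y, p_z, E]."""
    p4 = np.asarray(p4, dtype=np.float64)
    m2 = p4[..., 3] ** 2 - np.sum(p4[..., :3] ** 2, axis=-1)
    return np.sqrt(np.clip(m2, 0.0, None))  # clip guards against round-off

# A b quark (m = 4.7 GeV, as in z_inv_masses) and a massless object.
b_quark = np.array([3.0, 4.0, 0.0, np.sqrt(25.0 + 4.7 ** 2)])
massless = np.array([0.0, 0.0, 10.0, 10.0])
m_b = invariant_mass(b_quark)
m_0 = invariant_mass(massless)
```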

In [4]:
x_inv_masses = np.zeros(6)
z_inv_masses = np.array([0., 0., 4.7, 4.7, 0., 0.])

# Train

## Import Training Specific Libraries and Functions

In [5]:
import torch
from torch import optim
import torch.nn as nn
from ppttbar_constraints import threshold_check
from ppttbar_utils import train_and_val

## Define Meta Network Parameters

In [6]:
from models import Autoencoder, StochasticResNet

## Define Model and Hyperparameters

###### Latent loss function:
Finite sample approximation of Sliced Wasserstein Distance (SWD) between $p(z)$ and $p_E(z) = \int_x p(x) p_E(z|x)$

- $L_{latent}(Z, \tilde{Z}) = \frac{1}{L M} \sum_{l=1}^{L} \sum_{m=1}^{M} c((\theta_l \cdot z_m)_{sorted}, (\theta_l \cdot \tilde{z}_m)_{sorted})$

where $c(\cdot, \cdot) = |\cdot - \cdot|^2$
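
The formula above can be sketched in NumPy as follows. This is an illustrative version under stated assumptions, not the implementation used by `train_and_val`: the function name `sliced_wasserstein` is hypothetical, the projection directions are drawn as normalized Gaussian vectors, and both samples are assumed to have the same size $M$.

```python
import numpy as np

# Hypothetical sketch; the loss inside train_and_val may differ in detail.
def sliced_wasserstein(z, z_tilde, num_slices=50, rng=None):
    """Finite-sample SWD with squared cost between two equal-size samples."""
    rng = np.random.default_rng(rng)
    z, z_tilde = np.asarray(z, float), np.asarray(z_tilde, float)
    assert z.shape == z_tilde.shape, "sketch assumes equal sample sizes"
    # Random unit vectors theta_l, one per row (L = num_slices).
    theta = rng.normal(size=(num_slices, z.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project onto each direction, then sort each 1D projection independently.
    proj = np.sort(z @ theta.T, axis=0)              # shape (M, L)
    proj_tilde = np.sort(z_tilde @ theta.T, axis=0)  # shape (M, L)
    # Averaging over both axes gives the 1 / (L M) double sum.
    return np.mean((proj - proj_tilde) ** 2)

rng = np.random.default_rng(0)
z = rng.normal(size=(256, 24))
swd_same = sliced_wasserstein(z, z, rng=1)           # identical samples
swd_shifted = sliced_wasserstein(z, z + 1.0, rng=1)  # shifted copy
```

Sorting each projection makes the loss independent of the order of the samples, which is why it can compare unpaired batches of $z$ and $\tilde{z}$.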

###### Data loss function:
- $L_{data}(X, \tilde{X}) = \frac{1}{M} \sum_{m=1}^M \left[\frac{1_S(\tilde{x}_m)}{p_D(S)} c(x_m, \tilde{x}_m)\right]$

where $c(\cdot, \cdot) = |\cdot - \cdot|^2$; $1_S(x)$ is the indicator function of $S$ so that it equals $1$ if $x \in S$, and $0$ otherwise, and $p_D(S) := \int dt p_D(t) 1_S(t)$ normalizes this distribution.
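
A minimal sketch of this masked loss, assuming $p_D(S)$ is estimated by the batch passing rate (the mean of the indicator). The function name `masked_data_loss` and the convention of returning zero when no sample passes are assumptions for illustration, not taken from the OTUS code.

```python
import numpy as np

# Hypothetical sketch; names and the empty-batch convention are assumptions.
def masked_data_loss(x, x_tilde, mask):
    """Squared-error data loss restricted to decoder samples inside S.

    mask[m] plays the role of 1_S(x_tilde_m); its mean is the finite-sample
    estimate of p_D(S), i.e. the passing rate.
    """
    x, x_tilde = np.asarray(x, float), np.asarray(x_tilde, float)
    mask = np.asarray(mask, bool)
    p_s = mask.mean()
    if p_s == 0:
        return 0.0  # assumption: define the loss as 0 when nothing passes
    cost = np.sum((x - x_tilde) ** 2, axis=1)  # c(x_m, x_tilde_m)
    return float(np.mean(mask * cost / p_s))

x = np.zeros((4, 2))
x_tilde = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0], [0.0, 0.0]])
mask = np.array([True, True, False, False])
loss = masked_data_loss(x, x_tilde, mask)
```

With the passing rate as the estimate of $p_D(S)$, this reduces to the mean cost over the passing samples only.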

###### Full loss function:
- $L_{tot} = \beta L_{data}(X, \tilde{X}) + \lambda L_{latent}(Z, \tilde{Z})$ 

###### Core Hyperparameters
The hyperparameter definitions are as follows:

- num_hidden_layers: The number of hidden layers in both the encoder and decoder networks
- dim_per_hidden_layer: The dimensions per hidden layer in both the encoder and decoder networks
- lr: The learning rate of the networks
- lamb: The $\lambda$ coefficient in front of the latent loss term
- num_slices: Number of random projections used for computing SWD
- epochs: The number of epochs used during training

Hyperparameters for other losses that were tried; using them during main training is currently discouraged:

- tau: Coefficient in front of the alternate data-space loss ("alt_x_loss"), which is the SWD between $p(x)$ and $p_D(x):=\int_z p(z) p_D(x|z)$
- rho: Coefficient in front of an additional decoder constraint loss (based on soft-penalty approach to learning hard thresholds/ttbar_constraints)

###### Joint Training Hyperparameters
- beta: Coefficient in front of data loss, $L_{data}$ 
- nu_e: Coefficient in front of the encoder "anchor loss"
- nu_d: Coefficient in front of the decoder "anchor loss"

In [7]:
# Note: most of the unspecified hyperparameters are set to 0 by default

# common configs
joint_step_config = {
    'lr': 0.001,
    'beta': 1.,  # coefficient in front of data loss, E[c(x, x reconstructed)], where c is typically the 2-norm.
    'lamb': 20.,  # coefficient in front of latent loss, SWD between p(z) and p_E(z):=\int_x p(x) p_E(z|x)
    'tau': 0,  # coefficient in front of "alt_x_loss", which is the SWD between p(x) and p_D(x):=\int_z p(z) p_D(x|z);
               # this loss is not part of the original WAE formulation and is not used during joint training.
    'rho': 0,  # coefficient in front of decoder constraint loss (based on soft-penalty approach to learning hard thresholds/ttbar_constraints)
    'nu_e': 0,  # coefficient in front of encoder "anchor loss"
    'nu_d': 0,  # coefficient in front of decoder "anchor loss"
    'epochs': 1000,
    'log_freq': 100,
}


decoder_finetuning_config = {
    'beta': 0,  # coefficient in front of data loss in (S)WAE objective
    'tau': 1, # coefficient in front of "alt_x_loss", which is the SWD between p(x) and p_D(x)
    'lamb': 0, # disable latent loss
    'rho': 0, # no x_constraint_loss (no longer used)
    'nu_e': 0,  # anchor loss
    'nu_d': 0,
    'lr': 0.0001,  # reduced lr for fine-tuning
    'epochs': 10,
    'log_freq': 1,
}


hidden_layer_dims = [64, 64]
activation = torch.nn.ReLU
from models import Autoencoder, StochasticResNet
model = Autoencoder(x_dim, z_dim, ConditionalModel=StochasticResNet, encoder_hidden_layer_dims=hidden_layer_dims,
                    stoch_enc=True, stoch_dec=True, activation=activation, raw_io=raw_io,
                    x_inv_masses=x_inv_masses, x_stats=np.stack([x_train_mean, x_train_std]),
                    z_inv_masses=z_inv_masses, z_stats=np.stack([z_train_mean, z_train_std]),
                    # ResNet settings:
                    io_residual=True,
                    res_mlp_depth=2
                            )

In [8]:
# Print model 
model

Autoencoder(
  (encoder): StochasticResNet(
    (nn): Sequential(
      (0): Linear(in_features=42, out_features=64, bias=True)
      (1): ResBlock(
        (module): Sequential(
          (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU()
          (5): Linear(in_features=64, out_features=64, bias=True)
        )
      )
      (2): Linear(in_features=64, out_features=18, bias=True)
    )
  )
  (decoder): StochasticResNet(
    (nn): Sequential(
      (0): Linear(in_features=42, out_features=64, bias=True)
      (1): ResBlock(
        (module): Sequential(
          (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): B

In [9]:
train_batch_size = 20000
train_loaders = DataLoader(dataset=x_train, batch_size=train_batch_size, shuffle=True), \
                DataLoader(dataset=z_train, batch_size=train_batch_size, shuffle=True)

# Note that the Z and X data loaders are both shuffled, so (z, x) samples no longer match up (unlike in the source
# data arrays), as is demanded by our unsupervised problem setup.

## Loop through different hyperparameters and train/eval a model for each

In [10]:
save_dir = '.'
verbose = True
lambs = [0.001, 0.01, 0.1, 1, 10, 100, 1000]

for lamb in lambs:
    joint_step_config['lamb'] = lamb
    print(f'Training with lamb={lamb}')

    # Reset seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create a new model and optimizer
    model = Autoencoder(x_dim, z_dim, ConditionalModel=StochasticResNet, encoder_hidden_layer_dims=hidden_layer_dims,
                    stoch_enc=True, stoch_dec=True, activation=activation, raw_io=raw_io,
                    x_inv_masses=x_inv_masses, x_stats=np.stack([x_train_mean, x_train_std]),
                    z_inv_masses=z_inv_masses, z_stats=np.stack([z_train_mean, z_train_std]),
                    # ResNet settings:
                    io_residual=True,
                    res_mlp_depth=2
                            )
    optimizer = optim.Adam(model.parameters(), lr=joint_step_config['lr'])

    history = {}
    # Joint SWAE training
    log_freq = joint_step_config['log_freq']
    eval_losses, history = train_and_val(model, train_loaders, eval_loaders, joint_step_config, optimizer, verbose=verbose, prev_hist=history, log_freq=log_freq)
    
    # Decoder fine-tuning with SWD(p(x), p_D(x)) loss only
    ## Reduce lr
    for g in optimizer.param_groups:
        g['lr'] = decoder_finetuning_config["lr"]

    log_freq = decoder_finetuning_config['log_freq']
    eval_losses, history = train_and_val(model, train_loaders, eval_loaders, decoder_finetuning_config, optimizer, verbose=verbose, prev_hist=history, log_freq=log_freq)
    
    print('Final val losses:', eval_losses)
    
    # Save history in JSON-lines format
    ## convert pytorch float tensors into plain numpy float arrs in history
    for key, val in history.items():
        if isinstance(val, (list, np.ndarray)) :
            if not isinstance(np.sum(val), (int, np.integer)):  # my crude test to see if this is an array of float type
                history[key] = [float(n) for n in val]
    import pandas as pd
    df = pd.DataFrame(history)
    df.to_json(os.path.join(save_dir, f'history-lamb={lamb}.jsonl'), orient='records', lines=True)
    
    
    # Save trained model
    save_path = os.path.join(save_dir, f'swae-lamb={lamb}.pkl')
    torch.save(model.state_dict(), save_path)
    print('Saved model weights to', save_path)
    
    
    # Evaluate and plot results
    from plot_utils import plotFunction
    from ppttbar_constraints import threshold_check
    #-- Reset random seeds --#
    torch.manual_seed(seed)
    np.random.seed(seed)
    #-- Convert model to run on CPU and run on validation data --#
    model.to('cpu')
    model.encoder.output_stats.to('cpu')
    model.decoder.output_stats.to('cpu')

    all_arrs = {'train': {}, 'val': {}}  # This will store all numpy arrays of interest
    all_arrs['train']['x'] = x_train
    all_arrs['train']['z'] = z_train
    all_arrs['val']['x']   = x_val
    all_arrs['val']['z']   = z_val

    for data_key in 'train', 'val':
        arrs = all_arrs[data_key]
        arrs['z_decoded']       = model.decode(torch.from_numpy(arrs['z'])) # p_D(x) = \int_z p(z) p_D(x|z)  "x_pred_truth"
        arrs['x_encoded']       = model.encode(torch.from_numpy(arrs['x'])) # p_E(z) = \int_x p(x) p_E(z|x)  "z_pred"
        arrs['x_reconstructed'] = model.decode(arrs['x_encoded'])           # p_D(y) = \int_x \int_z p(x) p_E(z|x) p_D(y|z) "x_pred"

        # Feed the same z input to the decoder multiple times and study the stochasticity of the output
        num_repeats = 100
        num_diff_zs = 100
        arrs['z_rep']         = np.array([np.repeat(arrs['z'][i:i+1], num_repeats, axis=0) for i in range(num_diff_zs)]) # "z_fixed"
        z_rep_tensor          = torch.from_numpy(arrs['z_rep'])                                                          # tmp
        arrs['z_decoded_rep'] = np.array([model.decode(z_rep_tensor[i]).detach().numpy() for i in range(num_diff_zs)])   # "x_pred_truth_fixed"
        arrs['x_rep']         = np.array([np.repeat(arrs['x'][i:i+1], num_repeats, axis=0) for i in range(num_diff_zs)]) # "x_fixed"

        # Convert all results to numpy arrays
        for (field, arr) in arrs.items():
            if isinstance(arr, torch.Tensor):
                arrs[field] = arr.detach().numpy()
                
    # Create new arrays from model output that passes cuts
    ## Just do it on validation set since we only record results on val set
#     for data_key in ['train', 'val']:
    for data_key in ['val']:
        arrs = all_arrs[data_key]
        for field in ('z_decoded', 'x_reconstructed'):
            arr = arrs[field]

            if raw_io:
                arr_raw = arr
            else:
                arr_raw = (arr * x_train_std) + x_train_mean

            # Keep only events that pass threshold constraint
            good_mask = threshold_check(arr_raw)
            if verbose:
                print('Passing rate of', field, good_mask.mean())
            arr_raw = arr_raw[good_mask] 

            if raw_io:
                arr = arr_raw
            else:
                arr = (arr_raw - x_train_mean) / x_train_std

            arrs[field+'_'] = arr

            if field == 'z_decoded':
                arrs['z_decoded_good_mask'] = good_mask
                arrs['z_'] = arrs['z'][good_mask]
            else:
                arrs['x_encoded_'] = arrs['x_encoded'][good_mask]
                
    data_key = 'val'
    arrs = all_arrs[data_key]
    ## Set overall plotting limits
    if raw_io:
        x_display_lims = [
            [(-250, 250), (-250, 250), (-700, 700), (0, 700)],
            [(-250, 250), (-250, 250), (-2900, 2900), (0, 2900)],
            [(-300, 300), (-300, 300), (-700, 700), (0, 700)],
            [(-250, 250), (-250, 250), (-700, 700), (0, 700)],
            [(-200, 200), (-200, 200), (-700, 700), (0, 700)],
            [(-100, 100), (-100, 100), (-700, 700), (0, 700)]
        ]
        z_display_lims = [
            [(-250, 250), (-250, 250), (-700, 700), (0, 700)],
            [(-250, 250), (-250, 250), (-1000, 1000), (0, 1000)], # Only differs from x_display_lims here
            [(-300, 300), (-300, 300), (-700, 700), (0, 700)],
            [(-250, 250), (-250, 250), (-700, 700), (0, 700)],
            [(-200, 200), (-200, 200), (-700, 700), (0, 700)],
            [(-100, 100), (-100, 100), (-700, 700), (0, 700)]
        ]
    else:  # Standardized data
        x_display_lims = z_display_lims = [[(-2.5,2.5),(-2.5,2.5),(-5,5),(-2,5)] for _ in range(6)]
        
    ## Z marginals
    dataList = [arrs['z'], arrs['x_encoded']]
    pltDim   = (6,4)
    numBins  = 50
    binsList = []
    for i in range(pltDim[0]):
        for j in range(pltDim[1]):

            low  = z_display_lims[i][j][0]
            high = z_display_lims[i][j][1]

            binsList.append(np.linspace(low, high, numBins))

    particleNameList = [r'$e^-$', r'$\bar{\nu}_e$', r'$b$', r'$\bar{b}$', r'$u$', r'$\bar{d}$']
    fig = plotFunction(dataList = dataList, pltDim = pltDim, binsList = binsList, particleNameList = particleNameList, show=False)
    fig.savefig(os.path.join(save_dir, f'Z_marginals-lamb={lamb}.png'), bbox_inches='tight')
    plt.close(fig)
    
    ## X marginals
    dataList = [arrs['x'], arrs['x_reconstructed_'], arrs['z_decoded_']] # Use only passing events
    pltDim   = (6,4)
    numBins  = 50
    binsList = []
    for i in range(pltDim[0]):
        for j in range(pltDim[1]):

            low  = x_display_lims[i][j][0]
            high = x_display_lims[i][j][1]

            binsList.append(np.linspace(low, high, numBins))

    particleNameList = [r'$e^-$', r'$MET$', r'$Jet1$', r'$Jet2$', r'$Jet3$', r'$Jet4$']
    fig = plotFunction(dataList = dataList, pltDim = pltDim, binsList = binsList, particleNameList = particleNameList, show=False)
    fig.savefig(os.path.join(save_dir, f'X_marginals-lamb={lamb}.png'), bbox_inches='tight')
    plt.close(fig)

    # Evaluate derived quantities
    from top_masses import ttbar_masses   
    for field in 'x', 'z_decoded', 'x_reconstructed':

        # Check if this has been calculated for x already 
        if field == 'x' and isinstance(arrs.get(field+'_masses'), np.ndarray): 
            continue

        arr = arrs[field]

        if raw_io:
            arr_raw = arr
        else:
            arr_raw = (arr * x_train_std) + x_train_mean

        # Keep only events that pass threshold constraint
        good_mask = threshold_check(arr_raw)
        if verbose:
            print('Passing rate of', field, good_mask.mean())
        arr_raw = arr_raw[good_mask]

        # Calculate invariant mass derived quantities
        masses = []
        mass_keys = 'mttbar', 'mwlep', 'mwhad', 'mtoplep', 'mtophad'
        for x in arr_raw:
            mttbar, mwlep, mwhad, mtoplep, mtophad = ttbar_masses(list(x))
            masses.append(np.array([mttbar, mwlep, mwhad, mtoplep, mtophad]))
        masses = np.asarray(masses)
        arrs[field+'_masses'] = masses
    # Plot derived quantities
    dataList = [arrs['x_masses'], arrs['x_reconstructed_masses'],arrs['z_decoded_masses']] 
    pltDim   = (2,4)
    numBins  = 50
    mass_lims = [(0,1500), (78, 175), (0,250), (50,350), (50,800)]
    binsList = []
    for i in range(len(mass_lims)):
        low  = mass_lims[i][0]
        high = mass_lims[i][1]

        binsList.append(np.linspace(low, high, numBins))
    particleNameList = []
    nameList = [r'$M_{t \bar{t}}$', r'$M_{W, leptonic}$', r'$M_{W, hadronic}$', r'$M_{t, leptonic}$', r'$M_{t, hadronic}$']

    # Create plot
    fig = plotFunction(dataList = dataList, pltDim = pltDim, binsList=binsList, particleNameList=particleNameList, nameList=nameList, show=False)
    fig.savefig(os.path.join(save_dir, f'X_derived-lamb={lamb}.png'), bbox_inches='tight')
    plt.close(fig)
    
    # Transport plots
    from plot_utils import fullTransportPlot
    nzList    = np.repeat(20,24).tolist()
    nxList    = nzList
    # z_display_lims and x_display_lims defined above
    limzList = sum(z_display_lims, [])
    limxList  = sum(x_display_lims, [])
    pltDim    = (6,4)
    titleList = [r'$p_x$',r'$p_y$',r'$p_z$',r'$E$','','','','','','','','','','','','','','','','','','','','']
    fig = fullTransportPlot(arrs['z_'], arrs['z_decoded_'], nzList=nzList, nxList=nxList, limzList=limzList, limxList=limxList, pltDim=pltDim, titleList=titleList, show=False)
    fig.savefig(os.path.join(save_dir, f'z-z_decoded-transport-lamb={lamb}.png'), bbox_inches='tight')
    plt.close(fig)
    fig = fullTransportPlot(arrs['x_encoded_'], arrs['x_reconstructed_'], nzList=nzList, nxList=nxList, limzList=limzList, limxList=limxList, pltDim=pltDim, titleList=titleList, show=False)
    fig.savefig(os.path.join(save_dir, f'x_encoded-x_reconstructed-transport-lamb={lamb}.png'), bbox_inches='tight')
    plt.close(fig)
    
    print()
    print()

Training with lamb=0.001
{'lr': 0.001, 'beta': 1.0, 'lamb': 0.001, 'tau': 0, 'rho': 0, 'nu_e': 0, 'nu_d': 0, 'epochs': 1000, 'log_freq': 100}
epoch:	0
train -- loss:5503.24, prate:0.681, x_loss:5465.75, z_loss:37496.6, anchor_loss:0
eval -- loss:78743.2, prate:0.671, x_loss:4721.69, z_loss:40052.5, alt_x_loss:38690.7, anchor_loss:0.442782
epoch:	100
train -- loss:32.2845, prate:0.681, x_loss:3.19206, z_loss:29092.5, anchor_loss:0
eval -- loss:69477.3, prate:0.672, x_loss:7.37862, z_loss:28035.2, alt_x_loss:41442.1, anchor_loss:0.0736043
epoch:	200
train -- loss:10.2287, prate:0.664, x_loss:0.808117, z_loss:9420.58, anchor_loss:0
eval -- loss:34573.7, prate:0.681, x_loss:1.2351, z_loss:8471.29, alt_x_loss:26102.4, anchor_loss:0.0695682
epoch:	300
train -- loss:5.06961, prate:0.684, x_loss:1.46862, z_loss:3600.99, anchor_loss:0
eval -- loss:20625.1, prate:0.692, x_loss:1.49353, z_loss:3461.33, alt_x_loss:17163.8, anchor_loss:0.173854
epoch:	400
train -- loss:2.77892, prate:0.69, x_loss:0

eval -- loss:505.069, prate:0.684, x_loss:42609.9, z_loss:151.609, alt_x_loss:353.46, anchor_loss:0.141293
epoch:	6
train -- loss:1018.81, prate:0.67, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:462.57, prate:0.679, x_loss:41429.5, z_loss:163.64, alt_x_loss:298.93, anchor_loss:0.145193
epoch:	7
train -- loss:789.46, prate:0.675, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:417.667, prate:0.677, x_loss:41924.4, z_loss:165.167, alt_x_loss:252.5, anchor_loss:0.148406
epoch:	8
train -- loss:645.792, prate:0.676, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:431.832, prate:0.675, x_loss:42712.3, z_loss:156.565, alt_x_loss:275.267, anchor_loss:0.152875
epoch:	9
train -- loss:1042.28, prate:0.671, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:402.32, prate:0.67, x_loss:43605.6, z_loss:158.716, alt_x_loss:243.605, anchor_loss:0.155959
Final val losses: {'passing_rate': 0.6702749729156494, 'x_loss': 43605.61328125, 'z_loss': 158.71571350097656, 'alt_x_loss': 243.60467529296875, 'encoder_

epoch:	999
train -- loss:178.808, prate:0.788, x_loss:4.72284, z_loss:174.085, anchor_loss:0
eval -- loss:1403.62, prate:0.794, x_loss:8.01497, z_loss:46.209, alt_x_loss:1357.41, anchor_loss:0.148629
{'beta': 0, 'tau': 1, 'lamb': 0, 'rho': 0, 'nu_e': 0, 'nu_d': 0, 'lr': 0.0001, 'epochs': 10, 'log_freq': 1}
epoch:	0
train -- loss:1984.42, prate:0.788, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:796.508, prate:0.785, x_loss:648.922, z_loss:47.4478, alt_x_loss:749.06, anchor_loss:0.152284
epoch:	1
train -- loss:1197.99, prate:0.767, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:564.798, prate:0.771, x_loss:795.738, z_loss:48.3671, alt_x_loss:516.431, anchor_loss:0.162006
epoch:	2
train -- loss:1552.72, prate:0.763, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:437.328, prate:0.76, x_loss:1328.49, z_loss:47.9223, alt_x_loss:389.406, anchor_loss:0.162108
epoch:	3
train -- loss:819.026, prate:0.755, x_loss:0, z_loss:0, anchor_loss:0
eval -- loss:357.908, prate:0.751, x_loss:1652.66, z_l

eval -- loss:66880.4, prate:0.643, x_loss:3078.57, z_loss:57.831, alt_x_loss:66822.6, anchor_loss:0.782034
epoch:	400
train -- loss:19674.3, prate:0.647, x_loss:227.626, z_loss:194.467, anchor_loss:0
eval -- loss:69060.6, prate:0.641, x_loss:3796.38, z_loss:49.9916, alt_x_loss:69010.6, anchor_loss:0.786674
epoch:	500
train -- loss:19083.7, prate:0.668, x_loss:168.247, z_loss:189.154, anchor_loss:0
eval -- loss:63618.3, prate:0.654, x_loss:4984.55, z_loss:52.5487, alt_x_loss:63565.7, anchor_loss:0.779341
epoch:	600
train -- loss:19005.6, prate:0.693, x_loss:206.623, z_loss:187.989, anchor_loss:0
eval -- loss:67263.1, prate:0.667, x_loss:8450.16, z_loss:43.1745, alt_x_loss:67219.9, anchor_loss:0.761896
epoch:	700
train -- loss:13686.3, prate:0.685, x_loss:151.279, z_loss:135.351, anchor_loss:0
eval -- loss:58315.9, prate:0.674, x_loss:8199.62, z_loss:48.4401, alt_x_loss:58267.5, anchor_loss:0.711092
epoch:	800
train -- loss:12907.3, prate:0.692, x_loss:166.788, z_loss:127.405, anchor_los

Passing rate of z_decoded 0.590725
Passing rate of x_reconstructed 0.689375
Passing rate of x 0.99145
Passing rate of z_decoded 0.590725
Passing rate of x_reconstructed 0.689375


