In [1]:
# Variational Autoencoders, Hierarchical VAEs, and Diffusion Models
# Based on Calvin Luo's "Understanding Diffusion Models: A Unified Perspective"

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_digits
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Seed both RNGs with the same value so repeated runs are reproducible
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset switch: False selects the scikit-learn Digits data, True selects MNIST
USE_MNIST = False

In [2]:
# pi as a torch tensor (built from the NumPy constant) so it can be used in torch expressions
PI = torch.from_numpy(np.asarray(np.pi))
# Margin used to clamp probabilities away from 0 and 1 before taking logarithms
EPS = 1.e-5

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Element-wise log-likelihood of a categorical distribution.

    ``x`` is one-hot encoded over ``num_classes`` and multiplied with the log of
    ``p`` (clamped to ``[EPS, 1 - EPS]``), so only the observed class contributes.

    ``reduction`` may be ``'avg'`` or ``'sum'`` (applied over ``dim``); any other
    value returns the unreduced tensor.
    """
    safe_p = torch.clamp(p, EPS, 1. - EPS)
    selected = F.one_hot(x.long(), num_classes=num_classes) * torch.log(safe_p)

    if reduction == 'sum':
        return torch.sum(selected, dim)
    if reduction == 'avg':
        return torch.mean(selected, dim)
    return selected

def log_bernoulli(x, p, reduction=None, dim=None):
    """Element-wise log-likelihood of a Bernoulli distribution.

    ``p`` is clamped to ``[EPS, 1 - EPS]`` before the logs are taken.

    ``reduction`` may be ``'avg'`` or ``'sum'`` (applied over ``dim``); any other
    value returns the unreduced tensor.
    """
    safe_p = torch.clamp(p, EPS, 1. - EPS)
    ll_on = x * torch.log(safe_p)
    ll_off = (1. - x) * torch.log(1. - safe_p)
    elementwise = ll_on + ll_off

    if reduction == 'sum':
        return torch.sum(elementwise, dim)
    if reduction == 'avg':
        return torch.mean(elementwise, dim)
    return elementwise

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian N(x; mu, exp(log_var)).

    Each element receives its own ``-0.5 * log(2*pi)`` term, so summing the result
    over the feature dimension yields the joint log-density. (Adding ``D`` copies of
    the constant to every element, as was done before, over-counts it ``D`` times
    once the elements are summed.)

    Args:
        x: Points at which to evaluate the density.
        mu: Means, broadcastable against ``x``.
        log_var: Log-variances, broadcastable against ``x``.
        reduction: ``'avg'`` or ``'sum'`` to reduce over ``dim``; anything else
            returns the unreduced tensor.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The (optionally reduced) log-density tensor.
    """
    import math  # local import: keeps this helper independent of module-level constants

    log_p = -0.5 * (math.log(2.0 * math.pi) + log_var + torch.exp(-log_var) * (x - mu) ** 2)

    if reduction == 'avg':
        # Handle dim=None explicitly so the call does not depend on torch accepting a None dim
        return log_p.mean() if dim is None else torch.mean(log_p, dim)
    if reduction == 'sum':
        return log_p.sum() if dim is None else torch.sum(log_p, dim)
    return log_p

def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log-density of a standard normal N(x; 0, 1).

    Each element receives its own ``-0.5 * log(2*pi)`` term, so summing the result
    over the feature dimension yields the joint log-density. (Adding ``D`` copies of
    the constant to every element, as was done before, over-counts it ``D`` times
    once the elements are summed.)

    Args:
        x: Points at which to evaluate the density.
        reduction: ``'avg'`` or ``'sum'`` to reduce over ``dim``; anything else
            returns the unreduced tensor.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The (optionally reduced) log-density tensor.
    """
    import math  # local import: keeps this helper independent of module-level constants

    log_p = -0.5 * (math.log(2.0 * math.pi) + x ** 2)

    if reduction == 'avg':
        # Handle dim=None explicitly so the call does not depend on torch accepting a None dim
        return log_p.mean() if dim is None else torch.mean(log_p, dim)
    if reduction == 'sum':
        return log_p.sum() if dim is None else torch.sum(log_p, dim)
    return log_p

def kl_normal(mu_q, log_var_q, mu_p=None, log_var_p=None):
    """KL(q || p) between diagonal Gaussians, summed over dimension 1.

    When both ``mu_p`` and ``log_var_p`` are omitted, ``p`` is the standard normal;
    otherwise ``p`` is the Gaussian described by ``mu_p`` / ``log_var_p``.
    """
    prior_is_standard = (mu_p is None) and (log_var_p is None)

    if prior_is_standard:
        # Closed form against N(0, I)
        per_dim = -0.5 * (1 + log_var_q - mu_q.pow(2) - log_var_q.exp())
        return per_dim.sum(1)

    # General closed form between two diagonal Gaussians
    var_q = torch.exp(log_var_q)
    var_p = torch.exp(log_var_p)
    sq_mean_gap = (mu_q - mu_p).pow(2)
    per_dim = 0.5 * (log_var_p - log_var_q + (var_q + sq_mean_gap) / var_p - 1)
    return per_dim.sum(1)

In [3]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset with a fixed train / val / test split.

    Rows ``[0, 1000)`` form the training split, rows ``[1000, 1350)`` the
    validation split and the remaining rows the test split.
    """

    def __init__(self, mode='train', transforms=None, normalize=True):
        pixels = load_digits().data.astype(np.float32)

        # Optional rescaling of the raw intensities by 1/16 (into [0, 1])
        if normalize:
            pixels = pixels / 16.0

        # Any mode other than 'train' / 'val' falls through to the test split
        if mode == 'train':
            rows = slice(None, 1000)
        elif mode == 'val':
            rows = slice(1000, 1350)
        else:
            rows = slice(1350, None)

        self.data = pixels[rows]
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transforms(item) if self.transforms else item

def get_dataloaders(batch_size=64, use_mnist=False):
    """Build train / validation / test loaders for MNIST or Digits.

    Returns a 5-tuple ``(training_loader, val_loader, test_loader, process_batch,
    data_shape)``. ``process_batch`` converts a raw batch from the loaders into a
    flat ``(batch, features)`` tensor; ``data_shape`` is ``(1, 28, 28)`` for MNIST
    and ``(8, 8)`` for Digits.
    """
    if use_mnist:
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        mnist_train = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
        test_data = datasets.MNIST('./data', train=False, download=True, transform=to_tensor)

        # Hold out 20% of the official training set for validation
        n_train = int(0.8 * len(mnist_train))
        n_val = len(mnist_train) - n_train
        train_data, val_data = torch.utils.data.random_split(mnist_train, [n_train, n_val])

        def process_batch(batch):
            # MNIST batches are (images, labels); keep the images and flatten them
            images = batch[0]
            return images.view(images.size(0), -1)

        data_shape = (1, 28, 28)
    else:
        train_data, val_data, test_data = (Digits(mode=split) for split in ('train', 'val', 'test'))

        def process_batch(batch):
            # Digits batches are already flat feature tensors
            return batch

        data_shape = (8, 8)

    # Only the training loader shuffles
    training_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return training_loader, val_loader, test_loader, process_batch, data_shape

In [4]:
class Encoder(nn.Module):
    """VAE encoder: wraps a network whose output is split into mean and log-variance."""

    def __init__(self, encoder_net):
        super().__init__()
        self.encoder = encoder_net

    @staticmethod
    def reparameterization(mu, log_var):
        """Draw ``mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def encode(self, x):
        """Run the network and split its output in two halves along dim 1: (mu, log_var)."""
        params = self.encoder(x)
        mu_e, log_var_e = torch.chunk(params, 2, dim=1)
        return mu_e, log_var_e

    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Sample a latent, either from ``x`` or from explicitly given parameters.

        Raises ``ValueError`` when exactly one of ``mu_e`` / ``log_var_e`` is given.
        """
        if mu_e is None and log_var_e is None:
            mu_e, log_var_e = self.encode(x)
        elif mu_e is None or log_var_e is None:
            raise ValueError('mu and log-var can`t be None!')
        return self.reparameterization(mu_e, log_var_e)

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        """Element-wise Gaussian log-density of a latent under the encoder distribution.

        With ``x`` given, the parameters are re-encoded and a fresh ``z`` is sampled
        (any ``mu_e`` / ``log_var_e`` / ``z`` passed in is overwritten). Without ``x``,
        all three of ``mu_e``, ``log_var_e`` and ``z`` are required.
        """
        if x is None:
            if (mu_e is None) or (log_var_e is None) or (z is None):
                raise ValueError('mu, log-var and z can`t be None!')
        else:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)

        return log_normal_diag(z, mu_e, log_var_e)

    def forward(self, x, type='log_prob'):
        """Dispatch on ``type``: ``'log_prob'`` -> ``log_prob(x)``, ``'encode'`` -> ``sample(x)``."""
        assert type in ['encode', 'log_prob'], 'Type could be either encode or log_prob'
        if type == 'encode':
            return self.sample(x)
        return self.log_prob(x)


class Prior(nn.Module):
    """Standard normal prior over a ``latent_dim``-dimensional latent space."""

    def __init__(self, latent_dim, device=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.device = device
        # One-element parameter whose device is used when no explicit device was given
        self.dummy = nn.Parameter(torch.zeros(1))

    def sample(self, batch_size):
        """Draw ``batch_size`` latents from N(0, I).

        The explicit ``device`` given at construction wins; otherwise the device of
        the dummy parameter is used.
        """
        target = self.dummy.device if self.device is None else self.device
        return torch.randn((batch_size, self.latent_dim), device=target)

    def log_prob(self, z):
        """Element-wise standard-normal log-density of ``z``."""
        return log_standard_normal(z)

In [5]:
class Decoder(nn.Module):
    """VAE decoder: maps latents to the parameters of an output distribution.

    ``distribution`` selects the likelihood: ``'categorical'`` (needs ``num_vals``),
    ``'bernoulli'`` or ``'gaussian'``.
    """

    def __init__(self, decoder_net, distribution='categorical', num_vals=None):
        super().__init__()
        self.decoder = decoder_net
        self.distribution = distribution
        self.num_vals = num_vals

    def decode(self, z):
        """Return the distribution parameters for ``z`` as a list.

        * categorical: ``[probs]`` with shape ``(batch, features, num_vals)``
        * bernoulli:   ``[probs]``
        * gaussian:    ``[mu, log_var]`` (network output split in two along dim 1)
        """
        raw = self.decoder(z)
        kind = self.distribution

        if kind == 'bernoulli':
            return [torch.sigmoid(raw)]

        if kind == 'gaussian':
            mu_d, log_var_d = torch.chunk(raw, 2, dim=1)
            return [mu_d, log_var_d]

        if kind == 'categorical':
            batch = raw.shape[0]
            features = raw.shape[1] // self.num_vals
            logits = raw.view(batch, features, self.num_vals)
            return [torch.softmax(logits, 2)]

        raise ValueError('Distribution must be `categorical`, `bernoulli`, or `gaussian`')

    def sample(self, z):
        """Draw one sample from the decoded distribution for each latent in ``z``."""
        params = self.decode(z)
        kind = self.distribution

        if kind == 'categorical':
            probs = params[0]
            batch, features = probs.shape[0], probs.shape[1]
            # Flatten to (batch * features, num_vals) so multinomial draws one value per feature
            flat = probs.view(batch, -1, self.num_vals).view(-1, self.num_vals)
            return torch.multinomial(flat, num_samples=1).view(batch, features)

        if kind == 'bernoulli':
            return torch.bernoulli(params[0])

        if kind == 'gaussian':
            mu_d, log_var_d = params
            sigma = torch.exp(0.5 * log_var_d)
            noise = torch.randn_like(sigma)
            return mu_d + sigma * noise

        raise ValueError('Distribution must be `categorical`, `bernoulli`, or `gaussian`')

    def log_prob(self, x, z):
        """Log-likelihood of ``x`` under the distribution decoded from ``z``."""
        params = self.decode(z)
        kind = self.distribution

        if kind == 'categorical':
            # Sum over classes first, then over features
            per_feature = log_categorical(x, params[0], num_classes=self.num_vals,
                                          reduction='sum', dim=-1)
            return per_feature.sum(-1)

        if kind == 'bernoulli':
            return log_bernoulli(x, params[0], reduction='sum', dim=-1)

        if kind == 'gaussian':
            mu_d, log_var_d = params
            return log_normal_diag(x, mu_d, log_var_d, reduction='sum', dim=-1)

        raise ValueError('Distribution must be `categorical`, `bernoulli`, or `gaussian`')

    def forward(self, z, x=None, type='log_prob'):
        """Dispatch on ``type``: ``'log_prob'`` -> ``log_prob(x, z)``, ``'decode'`` -> ``sample(z)``."""
        assert type in ['decode', 'log_prob'], 'Type could be either decode or log_prob'
        if type == 'decode':
            return self.sample(z)
        return self.log_prob(x, z)

In [None]:
# VAE model: combines the Encoder, Decoder and Prior classes defined in the cells above
class VAE(nn.Module):
    """Variational autoencoder built from an encoder net, a decoder net and a standard-normal prior."""

    def __init__(self, encoder_net, decoder_net, latent_dim=16,
                 distribution_type='categorical', num_vals=256,
                 beta=1.0, name="vae", device=None):
        super().__init__()

        # Plain configuration attributes
        self.name = name
        self.latent_dim = latent_dim
        self.beta = beta  # weight on the KL term (beta-VAE)
        self.num_vals = num_vals
        self.distribution_type = distribution_type

        # Sub-modules
        self.encoder = Encoder(encoder_net=encoder_net)
        self.decoder = Decoder(decoder_net=decoder_net,
                               distribution=distribution_type,
                               num_vals=num_vals)
        self.prior = Prior(latent_dim=latent_dim, device=device)

    def forward(self, x, reduction='avg'):
        """Return the negative ELBO of ``x``.

        The per-sample ELBO is ``log p(x|z) - beta * KL(q(z|x) || N(0, I))``. It is
        summed over the batch when ``reduction == 'sum'`` and averaged otherwise.
        """
        mu_e, log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)

        log_likelihood = self.decoder.log_prob(x, z)
        kl_term = kl_normal(mu_e, log_var_e)
        elbo = log_likelihood - self.beta * kl_term

        reduce = torch.sum if reduction == 'sum' else torch.mean
        return -reduce(elbo)

    def encode(self, x):
        """Latent distribution parameters ``(mu, log_var)`` for ``x``."""
        return self.encoder.encode(x)

    def decode(self, z):
        """Output distribution parameters for the latents ``z``."""
        return self.decoder.decode(z)

    def reconstruct(self, x):
        """Encode ``x``, sample a latent and sample an output from the decoder."""
        mu_e, log_var_e = self.encoder.encode(x)
        latent = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)
        return self.decoder.sample(latent)

    def sample(self, batch_size=64):
        """Generate ``batch_size`` outputs by decoding standard-normal latents."""
        # Draw the latents directly on the device the model's parameters live on
        model_device = next(self.parameters()).device
        latent = torch.randn((batch_size, self.latent_dim), device=model_device)
        return self.decoder.sample(latent)

    def latent_traversal(self, x, dim=0, num_steps=10, range_=(-3, 3)):
        """Decode ``num_steps`` copies of the posterior mean of ``x`` while sweeping latent ``dim``.

        The chosen latent coordinate is replaced by evenly spaced values covering
        ``range_``; all other coordinates stay at the encoded mean.
        """
        mu_e, _ = self.encoder.encode(x.unsqueeze(0))
        latents = mu_e.repeat(num_steps, 1)

        sweep = torch.linspace(range_[0], range_[1], num_steps)
        latents[:, dim] = sweep

        return self.decoder.sample(latents)

In [8]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Using device: {device}")

# Training and visualization helpers; train_model and visualize_results read the device from the model's parameters
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100,
                max_patience=20, process_batch=lambda x: x):
    """Train ``model`` with early stopping on the validation loss.

    ``model(inputs)`` must return the scalar loss to minimize. Whenever the mean
    validation loss improves, the weights are saved to ``models/<model.name>_best.pt``
    and ``visualize_results`` is called. Training stops once the validation loss has
    failed to improve for ``max_patience`` consecutive epochs.

    Returns a dict with the per-epoch ``'train_loss'`` and ``'val_loss'`` lists and
    the zero-based ``'best_epoch'``.
    """
    device = next(model.parameters()).device
    history = {'train_loss': [], 'val_loss': [], 'best_epoch': 0}
    best_val_loss = float('inf')
    patience = 0

    for epoch in range(num_epochs):
        # ---- optimization pass over the training set ----
        model.train()
        train_batch_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            inputs = process_batch(batch).to(device)

            optimizer.zero_grad()
            loss = model(inputs)
            loss.backward()
            optimizer.step()

            train_batch_losses.append(loss.item())

        train_loss = sum(train_batch_losses) / len(train_batch_losses)
        history['train_loss'].append(train_loss)

        # ---- evaluation pass over the validation set (no gradients) ----
        model.eval()
        val_batch_losses = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                inputs = process_batch(batch).to(device)
                val_batch_losses.append(model(inputs).item())

        val_loss = sum(val_batch_losses) / len(val_batch_losses)
        history['val_loss'].append(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # ---- early-stopping bookkeeping ----
        improved = val_loss < best_val_loss
        if not improved:
            patience += 1
        else:
            best_val_loss = val_loss
            history['best_epoch'] = epoch
            patience = 0

            # Checkpoint the best weights seen so far
            os.makedirs("models", exist_ok=True)
            torch.save(model.state_dict(), f"models/{model.name}_best.pt")

            # Snapshot reconstructions / samples for this checkpoint
            visualize_results(model, val_loader, epoch, process_batch)

        if patience >= max_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return history

def visualize_results(model, data_loader, epoch, process_batch=lambda x: x, num_samples=16):
    """Generate reconstructions, samples and latent traversals and hand them to ``plot_results``.

    Args:
        model: Model exposing ``sample(batch_size=...)`` and ``name``; ``reconstruct``
            and ``latent_traversal`` (plus ``latent_dim``) are used when present.
        data_loader: Loader from which the first batch is taken as input examples.
        epoch: Epoch index, forwarded to ``plot_results`` for the file names.
        process_batch: Converts a raw batch from ``data_loader`` into a tensor.
        num_samples: Number of inputs kept from the batch and of samples generated.

    The model is switched to eval mode and left there.
    """
    device = next(model.parameters()).device
    model.eval()

    can_reconstruct = hasattr(model, 'reconstruct')
    can_traverse = hasattr(model, 'latent_traversal')

    # Create every directory plot_results may write to. The traversal directory
    # was previously missing, which made the traversal savefig fail on a fresh run.
    out_dirs = ["reconstructions", "samples"]
    if can_traverse:
        out_dirs.append("traversals")
    for sub_dir in out_dirs:
        os.makedirs(f"results/{model.name}/{sub_dir}", exist_ok=True)

    with torch.no_grad():
        # Use the first batch of the loader as the example inputs
        batch = next(iter(data_loader))
        inputs = process_batch(batch).to(device)[:num_samples]

        # Models without a reconstruct method (e.g. diffusion models) get no reconstruction row
        reconstructions = model.reconstruct(inputs) if can_reconstruct else None

        samples = model.sample(batch_size=num_samples).to(device)

        # Sweep up to the first four latent dimensions, starting from the first input
        if can_traverse:
            traversal_input = inputs[0]
            traversals = [model.latent_traversal(traversal_input, dim=i)
                          for i in range(min(4, model.latent_dim))]
        else:
            traversals = None

    plot_results(inputs.cpu(),
                 reconstructions.cpu() if reconstructions is not None else None,
                 samples.cpu(),
                 [t.cpu() for t in traversals] if traversals else None,
                 model.name, epoch)

def _as_image(sample):
    """Convert one sample tensor to a NumPy array, reshaping flat 64- / 784-element vectors to 8x8 / 28x28."""
    img = sample.cpu().numpy()
    if len(img.shape) == 1:
        if img.shape[0] == 64:  # Digits
            img = img.reshape(8, 8)
        elif img.shape[0] == 784:  # MNIST
            img = img.reshape(28, 28)
    return img


def _draw_image_row(fig, rows, cols, row, images, title):
    """Draw the first ``cols`` images into grid row ``row`` (zero-based) of ``fig``; title the first one."""
    for i in range(cols):
        ax = fig.add_subplot(rows, cols, row * cols + i + 1)
        ax.imshow(_as_image(images[i]), cmap='gray')
        ax.axis('off')
        if i == 0:
            ax.set_title(title)


def plot_results(inputs, reconstructions, samples, traversals=None, model_name="model", epoch=0):
    """Plot originals, reconstructions and generated samples, plus optional latent traversals.

    Two files are written (their directories are created if missing):

    * ``results/<model_name>/reconstructions/recon_epoch_<epoch>.png`` - a grid with
      the originals, the reconstructions (when given) and the generated samples.
    * ``results/<model_name>/traversals/traversal_epoch_<epoch>.png`` - one row per
      traversed latent dimension (only when ``traversals`` is non-empty).

    Args:
        inputs: Original samples; at most the first 8 are shown.
        reconstructions: Reconstructions aligned with ``inputs``, or ``None``.
        samples: Generated samples.
        traversals: Optional list with one batch of images per traversed dimension.
        model_name: Used to build the output paths.
        epoch: Used in the output file names.
    """
    has_recon = reconstructions is not None

    # The row count depends on whether a reconstruction row exists; traversals
    # live in their own figure and must not affect this grid.
    rows = 3 if has_recon else 2
    cols = min(8, len(inputs))

    fig = plt.figure(figsize=(12, 10))
    _draw_image_row(fig, rows, cols, 0, inputs, "Original")
    if has_recon:
        _draw_image_row(fig, rows, cols, 1, reconstructions, "Reconstruction")
    _draw_image_row(fig, rows, cols, 2 if has_recon else 1, samples, "Generated")

    # Save through the figure object: plt.savefig would write whichever figure is
    # current, which is the traversal figure once it has been created.
    recon_dir = f"results/{model_name}/reconstructions"
    os.makedirs(recon_dir, exist_ok=True)
    fig.tight_layout()
    fig.savefig(f"{recon_dir}/recon_epoch_{epoch}.png")

    if traversals:
        fig2 = plt.figure(figsize=(12, 10))
        num_dims = len(traversals)
        num_steps = traversals[0].shape[0]

        for dim in range(num_dims):
            for step in range(num_steps):
                ax = fig2.add_subplot(num_dims, num_steps, dim * num_steps + step + 1)
                ax.imshow(_as_image(traversals[dim][step]), cmap='gray')
                ax.axis('off')
                if step == 0 and dim == 0:
                    ax.set_title("Latent Traversals")

        traversal_dir = f"results/{model_name}/traversals"
        os.makedirs(traversal_dir, exist_ok=True)
        fig2.tight_layout()
        fig2.savefig(f"{traversal_dir}/traversal_epoch_{epoch}.png")

    plt.close('all')

def plot_training_history(history, model_name="model"):
    """Save the training / validation loss curves to ``results/<model_name>/training_history.png``.

    ``history`` must provide the ``'train_loss'`` and ``'val_loss'`` lists and the
    zero-based ``'best_epoch'``, which is marked with a dashed vertical line.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Epochs are displayed one-based
    epochs = range(1, len(history['train_loss']) + 1)
    ax.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    ax.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    ax.axvline(x=history['best_epoch'] + 1, color='g', linestyle='--', label='Best Model')

    ax.set_title(f'Training and Validation Loss - {model_name}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    # Write the figure to disk and release it
    os.makedirs(f"results/{model_name}", exist_ok=True)
    fig.savefig(f"results/{model_name}/training_history.png")
    plt.close(fig)

Using device: cuda


In [9]:
# Output folders for checkpoints and figures
for out_dir in ("models", "results"):
    os.makedirs(out_dir, exist_ok=True)

# Data loaders for the dataset selected by USE_MNIST
batch_size = 64
(train_loader, val_loader, test_loader,
 process_batch, data_shape) = get_dataloaders(batch_size=batch_size, use_mnist=USE_MNIST)

# Flattened input size and image shape: 28x28 for MNIST, 8x8 for Digits
input_dim, input_shape = (28 * 28, (1, 28, 28)) if USE_MNIST else (8 * 8, (1, 8, 8))

# Hyperparameters
latent_dim = 16
hidden_dim = 256
learning_rate = 1e-3
num_epochs = 100
max_patience = 20

# Likelihood: Bernoulli, for data normalized to [0, 1]
likelihood_type = 'bernoulli'
# Number of discrete values, only relevant for a categorical likelihood
# (256 for MNIST, 17 for Digits)
num_vals = 256 if USE_MNIST else 17

In [11]:
# Create MLP Encoder and Decoder for VAE
encoder_mlp = nn.Sequential(
    nn.Linear(input_dim, hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(hidden_dim, 2 * latent_dim)  # Mean and log_var
)

# Width of the decoder's last layer depends on the likelihood:
#   bernoulli   -> one probability per input feature
#   gaussian    -> mean and log-variance per feature (Decoder.decode splits the output in two)
#   categorical -> num_vals class scores per feature
if likelihood_type == 'bernoulli':
    decoder_out_dim = input_dim
elif likelihood_type == 'gaussian':
    decoder_out_dim = 2 * input_dim
else:
    decoder_out_dim = num_vals * input_dim

decoder_mlp = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.LeakyReLU(),
    nn.Linear(hidden_dim, decoder_out_dim)
)

# Create VAE model
vae_model = VAE(
    encoder_net=encoder_mlp,
    decoder_net=decoder_mlp,
    latent_dim=latent_dim,
    distribution_type=likelihood_type,
    num_vals=num_vals,
    name="vae_mlp",
    device=device
).to(device)

# Create optimizer
vae_optimizer = torch.optim.Adam(vae_model.parameters(), lr=learning_rate)

# Train VAE (early stopping on the validation loss)
print("Training basic VAE...")
vae_history = train_model(
    vae_model, vae_optimizer, train_loader, val_loader,
    num_epochs=num_epochs, max_patience=max_patience,
    process_batch=process_batch
)

# Plot training history
plot_training_history(vae_history, model_name=vae_model.name)

Training basic VAE...


Epoch 1/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 1/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 1/100 - Train Loss: 35.8191, Val Loss: 29.4267


Epoch 2/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 2/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 2/100 - Train Loss: 27.9839, Val Loss: 28.1213


Epoch 3/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 3/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 3/100 - Train Loss: 27.4126, Val Loss: 27.7740


Epoch 4/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 4/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 4/100 - Train Loss: 27.3039, Val Loss: 27.7874


Epoch 5/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 5/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 5/100 - Train Loss: 27.2621, Val Loss: 27.7434


Epoch 6/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 6/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 6/100 - Train Loss: 27.2814, Val Loss: 27.6268


Epoch 7/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 7/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 7/100 - Train Loss: 27.2362, Val Loss: 27.6580


Epoch 8/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 8/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 8/100 - Train Loss: 27.1890, Val Loss: 27.6543


Epoch 9/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 9/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 9/100 - Train Loss: 27.1374, Val Loss: 27.7968


Epoch 10/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 10/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 10/100 - Train Loss: 27.0966, Val Loss: 27.4830


Epoch 11/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 11/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 11/100 - Train Loss: 26.8901, Val Loss: 27.4769


Epoch 12/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 12/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 12/100 - Train Loss: 26.8585, Val Loss: 27.6414


Epoch 13/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 13/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 13/100 - Train Loss: 26.7065, Val Loss: 27.3428


Epoch 14/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 14/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 14/100 - Train Loss: 26.6307, Val Loss: 27.1432


Epoch 15/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 15/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 15/100 - Train Loss: 26.4983, Val Loss: 27.1025


Epoch 16/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 16/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 16/100 - Train Loss: 26.3634, Val Loss: 27.0088


Epoch 17/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 17/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 17/100 - Train Loss: 26.3097, Val Loss: 26.7735


Epoch 18/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 18/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 18/100 - Train Loss: 26.1053, Val Loss: 26.6185


Epoch 19/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 19/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 19/100 - Train Loss: 26.0347, Val Loss: 26.3767


Epoch 20/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 20/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 20/100 - Train Loss: 26.0775, Val Loss: 26.5613


Epoch 21/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 21/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 21/100 - Train Loss: 25.9186, Val Loss: 26.2201


Epoch 22/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 22/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 22/100 - Train Loss: 25.8873, Val Loss: 26.3287


Epoch 23/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 23/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 23/100 - Train Loss: 25.5882, Val Loss: 26.3128


Epoch 24/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 24/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 24/100 - Train Loss: 25.3924, Val Loss: 26.2767


Epoch 25/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 25/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 25/100 - Train Loss: 25.4129, Val Loss: 26.2085


Epoch 26/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 26/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 26/100 - Train Loss: 25.2674, Val Loss: 26.0147


Epoch 27/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 27/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 27/100 - Train Loss: 25.2650, Val Loss: 25.9559


Epoch 28/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 28/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 28/100 - Train Loss: 25.2870, Val Loss: 25.8851


Epoch 29/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 29/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 29/100 - Train Loss: 25.2115, Val Loss: 25.9243


Epoch 30/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 30/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 30/100 - Train Loss: 25.1279, Val Loss: 25.8851


Epoch 31/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 31/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 31/100 - Train Loss: 25.1053, Val Loss: 25.9244


Epoch 32/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 32/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 32/100 - Train Loss: 25.0694, Val Loss: 25.8184


Epoch 33/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 33/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 33/100 - Train Loss: 25.0685, Val Loss: 25.8166


Epoch 34/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 34/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 34/100 - Train Loss: 24.9376, Val Loss: 25.5521


Epoch 35/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 35/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 35/100 - Train Loss: 24.8400, Val Loss: 25.4831


Epoch 36/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 36/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 36/100 - Train Loss: 24.7861, Val Loss: 25.6449


Epoch 37/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 37/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 37/100 - Train Loss: 24.7245, Val Loss: 25.4636


Epoch 38/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 38/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 38/100 - Train Loss: 24.6651, Val Loss: 25.5907


Epoch 39/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 39/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 39/100 - Train Loss: 24.6969, Val Loss: 25.5058


Epoch 40/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 40/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 40/100 - Train Loss: 24.5769, Val Loss: 25.5022


Epoch 41/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 41/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 41/100 - Train Loss: 24.5955, Val Loss: 25.4710


Epoch 42/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 42/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 42/100 - Train Loss: 24.5575, Val Loss: 25.3362


Epoch 43/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 43/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 43/100 - Train Loss: 24.6948, Val Loss: 25.2800


Epoch 44/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 44/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 44/100 - Train Loss: 24.4861, Val Loss: 25.3850


Epoch 45/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 45/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 45/100 - Train Loss: 24.5668, Val Loss: 25.4200


Epoch 46/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 46/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 46/100 - Train Loss: 24.4200, Val Loss: 25.2372


Epoch 47/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 47/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 47/100 - Train Loss: 24.4170, Val Loss: 25.3629


Epoch 48/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 48/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 48/100 - Train Loss: 24.4344, Val Loss: 25.1496


Epoch 49/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 49/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 49/100 - Train Loss: 24.2560, Val Loss: 25.3213


Epoch 50/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 50/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 50/100 - Train Loss: 24.2592, Val Loss: 25.3048


Epoch 51/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 51/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 51/100 - Train Loss: 24.3512, Val Loss: 25.1623


Epoch 52/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 52/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 52/100 - Train Loss: 24.3010, Val Loss: 25.2254


Epoch 53/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 53/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 53/100 - Train Loss: 24.3155, Val Loss: 25.0416


Epoch 54/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 54/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 54/100 - Train Loss: 24.2988, Val Loss: 25.0596


Epoch 55/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 55/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 55/100 - Train Loss: 24.2856, Val Loss: 25.2161


Epoch 56/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 56/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 56/100 - Train Loss: 24.2683, Val Loss: 25.1900


Epoch 57/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 57/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 57/100 - Train Loss: 24.2377, Val Loss: 25.2552


Epoch 58/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 58/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 58/100 - Train Loss: 24.2006, Val Loss: 25.0396


Epoch 59/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 59/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 59/100 - Train Loss: 24.1954, Val Loss: 25.0942


Epoch 60/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 60/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 60/100 - Train Loss: 24.2669, Val Loss: 25.0539


Epoch 61/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 61/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 61/100 - Train Loss: 24.1058, Val Loss: 25.2674


Epoch 62/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 62/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 62/100 - Train Loss: 24.1112, Val Loss: 25.2206


Epoch 63/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 63/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 63/100 - Train Loss: 24.1716, Val Loss: 25.1689


Epoch 64/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 64/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 64/100 - Train Loss: 24.1661, Val Loss: 25.1379


Epoch 65/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 65/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 65/100 - Train Loss: 24.0650, Val Loss: 25.2085


Epoch 66/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 66/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 66/100 - Train Loss: 24.0350, Val Loss: 25.0483


Epoch 67/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 67/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 67/100 - Train Loss: 24.0730, Val Loss: 25.3070


Epoch 68/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 68/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 68/100 - Train Loss: 24.1355, Val Loss: 25.0039


Epoch 69/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 69/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 69/100 - Train Loss: 24.0631, Val Loss: 25.0170


Epoch 70/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 70/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 70/100 - Train Loss: 24.0739, Val Loss: 24.9515


Epoch 71/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 71/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 71/100 - Train Loss: 23.9929, Val Loss: 25.0110


Epoch 72/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 72/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 72/100 - Train Loss: 24.0815, Val Loss: 25.3643


Epoch 73/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 73/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 73/100 - Train Loss: 24.1001, Val Loss: 25.0560


Epoch 74/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 74/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 74/100 - Train Loss: 24.1314, Val Loss: 25.1689


Epoch 75/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 75/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 75/100 - Train Loss: 24.0726, Val Loss: 25.0957


Epoch 76/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 76/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 76/100 - Train Loss: 23.9955, Val Loss: 25.3014


Epoch 77/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 77/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 77/100 - Train Loss: 23.9675, Val Loss: 24.8851


Epoch 78/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 78/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 78/100 - Train Loss: 24.0027, Val Loss: 25.1828


Epoch 79/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 79/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 79/100 - Train Loss: 24.0249, Val Loss: 25.0611


Epoch 80/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 80/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 80/100 - Train Loss: 23.9345, Val Loss: 25.0500


Epoch 81/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 81/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 81/100 - Train Loss: 23.9534, Val Loss: 25.1758


Epoch 82/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 82/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 82/100 - Train Loss: 23.9211, Val Loss: 25.1321


Epoch 83/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 83/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 83/100 - Train Loss: 23.9652, Val Loss: 25.0068


Epoch 84/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 84/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 84/100 - Train Loss: 24.0647, Val Loss: 25.0435


Epoch 85/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 85/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 85/100 - Train Loss: 24.0482, Val Loss: 25.0081


Epoch 86/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 86/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 86/100 - Train Loss: 24.0288, Val Loss: 24.9149


Epoch 87/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 87/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 87/100 - Train Loss: 23.8875, Val Loss: 24.9985


Epoch 88/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 88/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 88/100 - Train Loss: 23.8646, Val Loss: 24.8358


Epoch 89/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 89/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 89/100 - Train Loss: 23.8263, Val Loss: 24.8198


Epoch 90/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 90/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 90/100 - Train Loss: 23.8899, Val Loss: 24.9788


Epoch 91/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 91/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 91/100 - Train Loss: 23.9541, Val Loss: 25.2013


Epoch 92/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 92/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 92/100 - Train Loss: 23.9454, Val Loss: 25.0246


Epoch 93/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 93/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 93/100 - Train Loss: 23.9486, Val Loss: 25.0380


Epoch 94/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 94/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 94/100 - Train Loss: 23.9132, Val Loss: 24.9594


Epoch 95/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 95/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 95/100 - Train Loss: 23.9302, Val Loss: 25.0176


Epoch 96/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 96/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 96/100 - Train Loss: 23.9402, Val Loss: 25.1863


Epoch 97/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 97/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 97/100 - Train Loss: 23.8355, Val Loss: 24.9286


Epoch 98/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 98/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 98/100 - Train Loss: 23.8277, Val Loss: 25.1454


Epoch 99/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 99/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 99/100 - Train Loss: 23.7485, Val Loss: 24.8399


Epoch 100/100 - Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 100/100 - Validation:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 100/100 - Train Loss: 23.8357, Val Loss: 25.1559
