In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd

In [4]:
# Compute device shared by the model, the guide and the data batches.
# Fall back to the CPU so the notebook still runs on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesEuroSATCNN(PyroModule):
    """Bayesian CNN classifier for 3-channel 64x64 images (EuroSAT RGB).

    Architecture: conv(3->32, 3x3) -> activation -> 2x2 max-pool ->
    conv(32->64, 3x3) -> activation -> 2x2 max-pool -> linear(64*16*16 -> classes).
    Every weight and bias is a ``PyroSample`` site that shares one i.i.d. prior,
    so the module is meant to be trained with Pyro inference (e.g. SVI).

    Args:
        num_classes: Number of output logits.
        device: Device on which the prior parameter tensors are created.
        activation: Name of a built-in activation ('relu', 'tanh', 'sigmoid',
            'sinusoidal', 'relu6', 'leaky_relu', 'selu', 'wg', 'rwg') or any
            callable applied element-wise to a tensor.
        prior_dist: 'gaussian', 'laplace' or 'uniform'.
        mu: Prior location (centre of the interval for the uniform prior).
        b: Prior scale (half-width of the interval for the uniform prior).
        prior_params: Optional dict overriding ``mu`` and/or ``b``; keys that
            are absent fall back to the ``mu`` / ``b`` arguments.
        dataset_size: Optional total number of training examples. When given,
            the log-likelihood of each mini-batch is multiplied by
            ``dataset_size / batch_size`` so that it is weighted against the
            prior as a full pass over the data would be. ``None`` (default)
            keeps the unscaled mini-batch likelihood.

    Raises:
        ValueError: For an unknown activation name, a non-string/non-callable
            activation, an unsupported prior or a non-positive ``dataset_size``.
    """

    def __init__(
        self,
        num_classes: int = 10,
        device: torch.device = torch.device("cuda"),
        activation: str = 'relu',
        prior_dist: str = 'gaussian',
        mu: float = 0.0,
        b: float = 1.0,
        prior_params: dict = None,
        dataset_size: int = None,
    ):
        super().__init__()
        self.device = device

        if dataset_size is not None and dataset_size <= 0:
            raise ValueError(f"dataset_size must be positive, got {dataset_size}")
        self.dataset_size = dataset_size

        # --- Activation setup ---
        if isinstance(activation, str):
            act_map = {
                'relu': F.relu,
                'tanh': torch.tanh,
                'sigmoid': torch.sigmoid,
                'sinusoidal': torch.sin,
                'relu6': F.relu6,
                'leaky_relu': F.leaky_relu,
                'selu': F.selu,
                'wg': self.actWG,
                'rwg': self.actRWG,
            }
            try:
                self.activation_fn = act_map[activation]
            except KeyError:
                # Same exception type as the other argument checks in this class.
                raise ValueError(
                    f"Unsupported activation: {activation!r}; "
                    f"expected one of {sorted(act_map)} or a callable"
                ) from None
        elif callable(activation):
            self.activation_fn = activation
        else:
            raise ValueError("activation must be a string or callable")

        # --- Prior setup ---
        self.prior_dist = prior_dist
        # Start from the keyword arguments and let prior_params override them,
        # so a partial dict such as {'b': 0.1} is accepted and the caller's
        # dict is never mutated.
        params = {'mu': mu, 'b': b}
        if prior_params is not None:
            params.update(prior_params)
        # Force a floating dtype: an int such as b=1 would otherwise create an
        # integer tensor, which is not a valid distribution parameter.
        float_dtype = torch.get_default_dtype()
        self.prior_mu = torch.as_tensor(params['mu'], dtype=float_dtype, device=device)
        self.prior_b = torch.as_tensor(params['b'], dtype=float_dtype, device=device)
        print(f"[INFO] Using prior: {self.prior_dist} (μ={self.prior_mu.item()}, b={self.prior_b.item()})")

        # --- Convolutional feature extractor ---
        # EuroSAT RGB -> 3-channel 64x64 inputs
        self.conv1 = PyroModule[nn.Conv2d](in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv1.weight = PyroSample(self._make_prior([32, 3, 3, 3]))
        self.conv1.bias = PyroSample(self._make_prior([32]))

        self.conv2 = PyroModule[nn.Conv2d](in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv2.weight = PyroSample(self._make_prior([64, 32, 3, 3]))
        self.conv2.bias = PyroSample(self._make_prior([64]))

        self.pool = nn.MaxPool2d(2, 2)

        # --- Final linear layer ---
        # After two pools: 64x64 -> 32x32 -> 16x16 feature maps with 64 channels.
        feat_dim = 64 * 16 * 16
        self.fc1 = PyroModule[nn.Linear](feat_dim, num_classes)
        self.fc1.weight = PyroSample(self._make_prior([num_classes, feat_dim]))
        self.fc1.bias = PyroSample(self._make_prior([num_classes]))

    def actWG(self, x, alpha: float = 1.0):
        """Return ``x * exp(-alpha * x**2)`` element-wise."""
        return x * torch.exp(-alpha * x ** 2)

    def actRWG(self, x, alpha: float = 1.0):
        """Return ``actWG(x, alpha)`` with negative values clamped to zero."""
        return torch.clamp(self.actWG(x, alpha), min=0.0)

    def _make_prior(self, shape):
        """Build the prior for a parameter tensor of the given shape.

        The base distribution is expanded to ``shape`` and all of its
        dimensions are declared event dimensions.

        Raises:
            ValueError: If ``self.prior_dist`` is not a supported name.
        """
        if self.prior_dist == 'gaussian':
            base = dist.Normal(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'laplace':
            base = dist.Laplace(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'uniform':
            # Centre the interval on mu like the other priors; with the default
            # mu = 0 this is the previous Uniform(-b, b).
            base = dist.Uniform(self.prior_mu - self.prior_b, self.prior_mu + self.prior_b)
        else:
            raise ValueError(f"Unsupported prior: {self.prior_dist}")
        return base.expand(shape).to_event(len(shape))

    def forward(self, x, y=None):
        """Compute class logits and, when ``y`` is given, observe the labels.

        Args:
            x: Image batch; the linear layer expects 3x64x64 inputs.
            y: Optional class-index labels. When provided, an ``"obs"``
                Categorical site conditioned on ``y`` is registered inside the
                ``"data"`` plate.

        Returns:
            The logits tensor with one row per input image.
        """
        # Feature extraction
        x = self.pool(self.activation_fn(self.conv1(x)))
        x = self.pool(self.activation_fn(self.conv2(x)))

        # Class logits
        features = x.flatten(start_dim=1)
        logits = self.fc1(features)

        # Observation plate
        if y is not None:
            batch = features.size(0)
            with pyro.plate("data", batch):
                if self.dataset_size is None:
                    pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
                else:
                    # Re-weight the mini-batch likelihood to the full dataset.
                    with pyro.poutine.scale(scale=self.dataset_size / batch):
                        pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits


In [6]:
# Per-channel mean / std handed to transforms.Normalize in load_data.
# NOTE(review): presumably measured on the EuroSAT RGB images after ToTensor
# scaling — confirm where these numbers come from.
eurosat_mean = [0.344, 0.380, 0.408]
eurosat_std  = [0.190, 0.137, 0.115]

In [7]:
def load_data(batch_size=54, data_root='./data',
              split_path='datasplit/split_indices.pkl', num_workers=4, seed=42):
    """Build the EuroSAT train / test ``DataLoader`` pair from a stored split.

    Images are resized to 64x64, converted to tensors and normalised with
    ``eurosat_mean`` / ``eurosat_std``. The train/test membership is read from
    a pickle file holding a mapping with ``'train'`` and ``'test'`` index lists.

    Args:
        batch_size: Batch size of both loaders.
        data_root: Directory where torchvision stores / downloads EuroSAT.
        split_path: Pickle file with the ``'train'`` / ``'test'`` indices.
        num_workers: Worker processes per loader (0 loads in the main process).
        seed: Value passed to ``torch.manual_seed`` (global side effect);
            ``None`` leaves the RNG untouched.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=eurosat_mean, std=eurosat_std),
    ])

    dataset = datasets.EuroSAT(root=data_root, transform=transform, download=True)

    if seed is not None:
        torch.manual_seed(seed)

    # SECURITY: pickle.load can execute arbitrary code — only point split_path
    # at a file you created yourself.
    with open(split_path, 'rb') as f:
        split = pickle.load(f)

    train_dataset = Subset(dataset, split['train'])
    test_dataset = Subset(dataset, split['test'])

    # Options shared by both loaders. persistent_workers is rejected by
    # DataLoader when there are no workers, and pinned memory only helps when
    # batches are copied to a CUDA device.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    return train_loader, test_loader

In [None]:
def _posterior_sample_accuracy(model, guide, loader, device, desc):
    """Return the fraction of samples in ``loader`` that are classified correctly.

    For every batch one set of weights is drawn from ``guide``, the model is
    replayed with that draw, and the arg-max of the logits is the prediction.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            trace = pyro.poutine.trace(guide).get_trace(images)
            logits = pyro.poutine.replay(model, trace=trace)(images)
            correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total


def _loc_scale_stats():
    """Summarise the Pyro param store.

    Returns four lists ``(loc_means, loc_stds, scale_means, scale_stds)`` with
    one entry per stored parameter whose name contains ``'loc'`` (first two
    lists) or, failing that, ``'scale'`` (last two lists).
    """
    loc_means, loc_stds, scale_means, scale_stds = [], [], [], []
    for name, param in pyro.get_param_store().items():
        if 'loc' in name:
            loc_means.append(param.mean().item())
            loc_stds.append(param.std().item())
        elif 'scale' in name:
            scale_means.append(param.mean().item())
            scale_stds.append(param.std().item())
    return loc_means, loc_stds, scale_means, scale_stds


def train_svi_with_stats(
    model,
    guide,
    svi,
    train_loader,
    device,
    num_epochs=10,
    save_epochs=None,
    save_dir='results',
    model_filename_pattern='model_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    guide_filename_pattern='guide_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    param_store_filename_pattern='param_store_{activation}_{prior}_epoch_{epoch}_{timestamp}.pkl',
    accuracies_filename_pattern='accuracy_results_{activation}_{prior}_{timestamp}.csv',
    losses_filename_pattern='losses_{activation}_{prior}_{timestamp}.csv',
    model_config_filename_pattern='config_{activation}_{prior}_{timestamp}.json'
):
    """
    Train with SVI, track losses / accuracies and checkpoint the best state.

    The Pyro param store is cleared first. Every epoch runs ``svi.step`` over
    ``train_loader`` and records the mean loss per batch. On epoch 1, every
    10th epoch and the last epoch, the training accuracy is measured with one
    guide sample per batch and loc/scale statistics of the param store are
    recorded. Whenever that accuracy beats the best so far, the model state
    dict, guide state dict and param store are written to ``save_dir`` with
    ``epoch="best"`` in the file name (each improvement overwrites the last).
    After training, CSV files with accuracies and losses and a JSON config are
    written to ``save_dir`` as well.

    Args:
        model: Pyro model exposing ``activation_fn``, ``prior_mu``, ``prior_b``
            and (optionally) ``prior_dist``.
        guide: Guide module used for posterior draws during evaluation.
        svi: ``pyro.infer.SVI`` object wired to ``model`` and ``guide``.
        train_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the model and batches are moved to.
        num_epochs: Number of training epochs (at least 1).
        save_epochs: Accepted for backward compatibility; currently unused.
        save_dir: Output directory (created if missing).
        *_filename_pattern: ``str.format`` templates for the output files.

    Returns:
        Tuple ``(epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats,
        bias_stats, best_model_path, best_guide_path, best_param_store_path,
        timestamp)``. ``weight_stats`` holds statistics of the ``'loc'``
        parameters and ``bias_stats`` those of the ``'scale'`` parameters.

    Raises:
        ValueError: If ``num_epochs < 1`` or ``train_loader`` yields no batches.
    """
    import json

    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    # Names used in every output file name.
    activation_fn = model.activation_fn
    act_name = getattr(activation_fn, '__name__', str(activation_fn))
    prior_name = getattr(model, 'prior_dist', 'prior')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    name_fields = dict(activation=act_name, prior=prior_name, timestamp=timestamp)

    os.makedirs(save_dir, exist_ok=True)

    pyro.clear_param_store()
    model.to(device)

    epoch_losses, epoch_accuracies, accuracy_epochs = [], [], []
    weight_stats = {'epochs': [], 'means': [], 'stds': []}
    bias_stats   = {'epochs': [], 'means': [], 'stds': []}

    # Start below any reachable accuracy so the first evaluation always writes
    # a checkpoint; the returned paths are then defined even if accuracy is 0.
    best_acc = float('-inf')
    best_epoch = None
    best_model_path = best_guide_path = best_ps_path = None

    for epoch in range(1, num_epochs + 1):
        # The previous accuracy check left both modules in eval mode.
        model.train()
        guide.train()
        total_loss = 0.0
        batches = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            total_loss += svi.step(images, labels)
            batches += 1

        if batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg_loss = total_loss / batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch} - ELBO Loss: {avg_loss:.4f}")

        if not (epoch == 1 or epoch % 10 == 0 or epoch == num_epochs):
            continue

        model.eval(); guide.eval()
        acc = _posterior_sample_accuracy(
            model, guide, train_loader, device, desc=f"Acc check epoch {epoch}"
        )
        epoch_accuracies.append(acc); accuracy_epochs.append(epoch)
        print(f"Epoch {epoch} - Train Acc: {acc*100:.2f}%")

        # Record loc / scale statistics of the variational parameters.
        loc_means, loc_stds, scale_means, scale_stds = _loc_scale_stats()
        weight_stats['epochs'].append(epoch)
        weight_stats['means'].append(loc_means)
        weight_stats['stds'].append(loc_stds)
        bias_stats['epochs'].append(epoch)
        bias_stats['means'].append(scale_means)
        bias_stats['stds'].append(scale_stds)

        # Only save when accuracy improves.
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
            fname_model = model_filename_pattern.format(epoch="best", **name_fields)
            fname_guide = guide_filename_pattern.format(epoch="best", **name_fields)
            fname_ps    = param_store_filename_pattern.format(epoch="best", **name_fields)
            best_model_path = os.path.join(save_dir, fname_model)
            best_guide_path = os.path.join(save_dir, fname_guide)
            best_ps_path = os.path.join(save_dir, fname_ps)

            torch.save(model.state_dict(), best_model_path)
            torch.save(guide.state_dict(), best_guide_path)
            pyro.get_param_store().save(best_ps_path)
            print(f"  ↳ Saved: {fname_model}, {fname_guide}, {fname_ps}")

    # Accuracies (evaluated epochs only) and per-epoch losses as CSV files.
    accuracies_df = pd.DataFrame({
        'epoch': accuracy_epochs,
        'accuracy': epoch_accuracies
    })
    accuracies_df.to_csv(os.path.join(save_dir, accuracies_filename_pattern.format(**name_fields)), index=False)

    loss_df = pd.DataFrame({
        'epoch': list(range(1, len(epoch_losses) + 1)),
        'loss': epoch_losses
    })
    loss_df.to_csv(os.path.join(save_dir, losses_filename_pattern.format(**name_fields)), index=False)

    # Run configuration as JSON.
    config = {
        'activation': act_name,
        'prior': prior_name,
        'num_epochs': num_epochs,
        'best_accuracy_at_epoch': best_epoch,
        'best_accuracy': best_acc,
        'batch_size': train_loader.batch_size,
        'train_size': len(train_loader.dataset),
        'prior_params': {
            'mu': model.prior_mu.item(),
            'b': model.prior_b.item()
        },
    }
    config_filename = model_config_filename_pattern.format(**name_fields)

    with open(os.path.join(save_dir, config_filename), 'w') as f:
        json.dump(config, f, indent=4)
    print(f"Configuration saved to {config_filename}")

    return (epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats, bias_stats,
            best_model_path, best_guide_path, best_ps_path, timestamp)


In [9]:
def _draw_stat_boxplots(ax, stats, facecolor, title, ylabel):
    """Draw one box per recorded epoch on ``ax``.

    Each box pools that epoch's ``stats['means']`` and ``stats['stds']`` lists.
    """
    groups = [stats['means'][i] + stats['stds'][i] for i in range(len(stats['epochs']))]
    tick_text = [f'Epoch {epoch}' for epoch in stats['epochs']]

    if groups:
        boxes = ax.boxplot(groups, patch_artist=True)
        for patch in boxes['boxes']:
            patch.set_facecolor(facecolor)
        # Label the ticks explicitly instead of through boxplot's `labels` /
        # `tick_labels` keyword, whose name depends on the Matplotlib version.
        ax.set_xticks(list(range(1, len(groups) + 1)))
        ax.set_xticklabels(tick_text)

    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(True, alpha=0.3)


def plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats,
                                     act_name, prior_name, timestamp,
                                     save_dir='results_GP_eurosat_newslate'):
    """Plot a 2x2 training summary and save it as a PNG in ``save_dir``.

    Panels: ELBO loss per epoch, training accuracy at the evaluated epochs,
    box plots of the 'loc' statistics (``weight_stats``) and box plots of the
    'scale' statistics (``bias_stats``).

    Args:
        losses: One loss value per epoch.
        accuracies: Accuracy at each epoch listed in ``accuracy_epochs``.
        accuracy_epochs: Epoch numbers at which accuracy was measured.
        weight_stats, bias_stats: Dicts with ``'epochs'``, ``'means'`` and
            ``'stds'`` lists of equal length.
        act_name, prior_name, timestamp: Used to build the output file name.
        save_dir: Output directory (created if missing).

    The figure is closed after saving so that repeated calls in a loop do not
    accumulate open figures; as a result it is not shown inline.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Plot 1: Training Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(range(1, len(losses) + 1), losses)
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('ELBO Loss')
    loss_ax.grid(True)

    # Plot 2: Training Accuracy
    acc_ax = axes[0, 1]
    acc_ax.plot(accuracy_epochs, accuracies, 'o-')
    acc_ax.set_title('Training Accuracy (Every 10 Epochs)')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.grid(True)

    # Plots 3 and 4: loc / scale statistics per evaluated epoch
    _draw_stat_boxplots(axes[1, 0], weight_stats, 'lightblue',
                        'LOC Statistics Distribution', 'LOC Values')
    _draw_stat_boxplots(axes[1, 1], bias_stats, 'lightcoral',
                        'SCALE Statistics Distribution', 'SCALE Values')

    fig.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(
        save_dir,
        f'bayesian_cnn_training_results_{act_name}_{prior_name}_{timestamp}.png'
    ))
    plt.close(fig)

In [10]:
from tqdm import tqdm

In [11]:
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, loader_of_interest, num_samples=10, guide=None, device=None):
    """Predict labels by averaging logits over several posterior draws.

    For every batch, ``num_samples`` weight sets are drawn from the guide, the
    model is replayed with each draw, the logits are averaged and the arg-max
    class is taken as the prediction.

    Args:
        model: Pyro model called as ``model(images)`` and returning logits.
        loader_of_interest: Loader yielding ``(images, labels)`` batches.
        num_samples: Number of posterior draws per batch (at least 1).
        guide: Guide to draw from. Defaults to the notebook-level ``guide``.
        device: Device the images are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(all_labels, all_predictions)`` as flat lists of NumPy scalars.

    Raises:
        ValueError: If ``num_samples < 1`` or no guide / device is available.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    # Keep the old call signature working: fall back to the notebook globals.
    if guide is None:
        guide = globals().get('guide')
    if device is None:
        device = globals().get('device')
    if guide is None or device is None:
        raise ValueError("predict_data needs a guide and a device (pass them or define them globally)")

    model.eval()
    guide.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(loader_of_interest, desc="Evaluating"):
            images = images.to(device)

            # Collect the draws in a list instead of pre-allocating from
            # `model.fc2.out_features`; the model may not have an `fc2` layer.
            draws = []
            for _ in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                draws.append(replayed_model(images))

            avg_logits = torch.stack(draws).mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

def save_predictions_to_csv(labels, predictions, filename='predictions.csv'):
    """Write true and predicted labels side by side to a CSV file.

    The file has the columns ``True Label`` and ``Predicted Label`` and no
    index column; a confirmation line is printed afterwards.
    """
    columns = {'True Label': labels, 'Predicted Label': predictions}
    pd.DataFrame(columns).to_csv(filename, index=False)
    print(f"Predictions saved to {filename}")

def send_telegram_message(title, message):
    """Send ``title`` and ``message`` to a Telegram chat (best effort).

    The bot token and chat id are read from ``TELEGRAM_BOT_TOKEN`` and
    ``TELEGRAM_CHAT_ID`` after loading ``.env``. Failures never raise: they are
    printed and ``None`` is returned, so a notification problem cannot abort
    an experiment.

    Returns:
        The ``requests`` response on success, otherwise ``None``.
    """
    load_dotenv('.env')
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHAT_ID')

    if not token or not chat_id:
        print("Error sending message: TELEGRAM_BOT_TOKEN / TELEGRAM_CHAT_ID is not set")
        return None

    try:
        response = requests.post(
            f'https://api.telegram.org/bot{token}/sendMessage',
            data={
                'chat_id': chat_id,
                'text': f'{title}\n{message}',
                #'parse_mode': 'Markdown'
            },
            # Without a timeout a stalled connection would block the run forever.
            timeout=10,
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # The exception text contains the request URL, which embeds the bot
        # token — redact it before it ends up in the notebook output.
        print(f"Error sending message: {str(e).replace(token, '<redacted>')}")
        return None
    return response

In [12]:
from pyro.infer.autoguide import AutoNormal
from pyro.optim import ClippedAdam
from tqdm import tqdm
import pandas as pd

In [13]:
from dotenv import load_dotenv
import requests
import os

In [14]:
# Number of output classes handed to BayesEuroSATCNN when the experiments are built.
num_classes = 10

In [None]:
from itertools import product

# --- Experiment grid -------------------------------------------------------
# Full grid:
#   priors:      ['gaussian', 'laplace', 'uniform']
#   activations: ['tanh', 'sigmoid', 'relu', 'sinusoidal', 'relu6', 'wg', 'rwg']
#   b values:    [10.0, 1.0, 0.1]
prior_names = ['gaussian']
activation_names = ['tanh']
b_list = [1.0]

# Legacy aliases: these two names historically held the *other* quantity
# (act_fn_list -> priors, prior_list -> activations). Kept so cells that still
# read them see the same values as before.
act_fn_list = prior_names
prior_list = activation_names

RESULTS_DIR = 'results_GP_eurosat_newslate'
NUM_EPOCHS = 100
BATCH_SIZE = 54

experiment_grid = list(product(prior_names, activation_names, b_list))
total_combinations = len(experiment_grid)
print(f"Total combinations to run: {total_combinations}")

# The data do not depend on the experiment: build the loaders once instead of
# starting a fresh set of worker processes for every combination.
train_loader, test_loader = load_data(batch_size=BATCH_SIZE)

experiment_number = 0
for experiment_number, (prior_choice, activation_choice, b_choice) in enumerate(experiment_grid, start=1):
    experiment_time_start = time.time()

    # Announce the start of the experiment (x / total combinations).
    send_telegram_message(
        title=f"Experiment {experiment_number}/{total_combinations}",
        message=f"Running with activation={activation_choice}, prior={prior_choice}, b={b_choice}"
    )

    pyro.clear_param_store()

    print("==========================================")
    print(f"Running experiment with activation={activation_choice}, prior={prior_choice}, b={b_choice}")
    print("==========================================")
    bayesian_model = BayesEuroSATCNN(
        num_classes,
        device,
        activation=activation_choice,
        prior_dist=prior_choice,
        mu=0.0,
        b=b_choice,
    )

    guide = AutoNormal(bayesian_model, init_scale=0.05).to(device)
    optimizer = ClippedAdam({"lr": 1e-3})
    svi = pyro.infer.SVI(
        model=bayesian_model,
        guide=guide,
        optim=optimizer,
        loss=pyro.infer.Trace_ELBO(num_particles=1),  # TODO
    )

    # Ensure model and guide are on the correct device
    bayesian_model.to(device)
    guide.to(device)

    # Re-seed right before training so every experiment starts from the same
    # RNG state, as it did when the loaders were rebuilt per experiment.
    torch.manual_seed(42)

    (losses, accuracies, accuracy_epochs, weight_stats, bias_stats,
     best_model_path, best_guide_path, best_param_store_path,
     experiment_timestamp) = train_svi_with_stats(
        bayesian_model,
        guide,
        svi,
        train_loader,
        device,
        num_epochs=NUM_EPOCHS,
        save_epochs=None,
        save_dir=RESULTS_DIR,
    )

    act_name = bayesian_model.activation_fn.__name__ if hasattr(bayesian_model.activation_fn, '__name__') else str(bayesian_model.activation_fn)
    prior_name = getattr(bayesian_model, 'prior_dist', 'prior')

    plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats, act_name, prior_name, experiment_timestamp)

    # NOTE: this evaluates the guide as it is in memory after the last epoch;
    # the "best" checkpoint written during training is not reloaded here.
    all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)
    cm = confusion_matrix(all_labels, all_predictions)
    # Accuracy = correctly classified (diagonal) / all samples.
    accuracy = np.trace(cm) / np.sum(cm)
    print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

    experiment_time_finish = time.time()

    save_predictions_to_csv(
        all_labels,
        all_predictions,
        os.path.join(RESULTS_DIR, f'predictions_{act_name}_{prior_name}_{experiment_timestamp}_{accuracy * 100:.0f}.csv')
    )

    send_telegram_message(
        title=f"Experiment {experiment_number}/{total_combinations} Finished",
        message=f"Activation: {activation_choice}, Prior: {prior_choice}, b: {b_choice}\n"
                f"Test Accuracy (final-epoch guide): {accuracy * 100:.2f}%\n"
                f"Time taken: {experiment_time_finish - experiment_time_start:.2f} seconds"
    )

Total combinations to run: 1
Running experiment with activation=tanh, prior=gaussian, b=1.0
[INFO] Using prior: gaussian (μ=0.0, b=1.0)


Epoch 1/100: 100%|██████████| 360/360 [01:10<00:00,  5.13it/s]


Epoch 1 - ELBO Loss: 426239.6237


Acc check epoch 1: 100%|██████████| 360/360 [00:04<00:00, 79.46it/s]


Epoch 1 - Train Acc: 28.37%
AutoNormal.locs.conv1.weight: [[[[-3.15040583e-03 -4.60303668e-03 -4.38435748e-03]
   [-4.77947108e-03 -7.55924778e-03 -6.01017335e-03]
   [-7.50679930e-04 -1.25478266e-03 -2.45365640e-03]]

  [[ 6.36941846e-03  3.74078751e-03  1.10086438e-03]
   [ 1.67038245e-03 -2.65217409e-03 -2.81457324e-03]
   [ 2.98527488e-03  2.14964594e-03  3.03898286e-03]]

  [[ 8.50361027e-03  7.22028315e-03  6.97024819e-03]
   [ 4.70737647e-03  7.08209176e-04  2.87896162e-03]
   [ 8.84383544e-03  9.31064971e-03  4.01764363e-03]]]


 [[[-1.56069500e-02 -1.83950663e-02 -1.93530265e-02]
   [-1.55060356e-02 -1.95964072e-02 -2.05669124e-02]
   [-2.13207863e-02 -2.20213160e-02 -2.09621210e-02]]

  [[-3.38724791e-03 -4.74546338e-03 -6.40644087e-03]
   [-1.15953409e-03 -6.56486861e-03 -9.13854595e-03]
   [-7.70606752e-03 -7.14864396e-03 -9.10198130e-03]]

  [[ 1.34308571e-02  8.20759218e-03  4.60528117e-03]
   [ 1.46677485e-02  6.64946856e-03  2.26100310e-04]
   [ 6.99142739e-03  2.004304

Epoch 2/100: 100%|██████████| 360/360 [00:16<00:00, 22.46it/s]


Epoch 2 - ELBO Loss: 364563.0573


Epoch 3/100: 100%|██████████| 360/360 [00:15<00:00, 22.65it/s]


Epoch 3 - ELBO Loss: 305658.5925


Epoch 4/100: 100%|██████████| 360/360 [00:15<00:00, 22.98it/s]


Epoch 4 - ELBO Loss: 250734.5808


Epoch 5/100: 100%|██████████| 360/360 [00:15<00:00, 22.76it/s]


Epoch 5 - ELBO Loss: 201214.5640


Epoch 6/100: 100%|██████████| 360/360 [00:16<00:00, 22.33it/s]


Epoch 6 - ELBO Loss: 158644.5363


Epoch 7/100: 100%|██████████| 360/360 [00:15<00:00, 22.57it/s]


Epoch 7 - ELBO Loss: 123664.2812


Epoch 8/100: 100%|██████████| 360/360 [00:15<00:00, 22.94it/s]


Epoch 8 - ELBO Loss: 96374.8320


Epoch 9/100: 100%|██████████| 360/360 [00:15<00:00, 22.59it/s]


Epoch 9 - ELBO Loss: 75854.2445


Epoch 10/100: 100%|██████████| 360/360 [00:16<00:00, 22.32it/s]


Epoch 10 - ELBO Loss: 60699.1015


Acc check epoch 10: 100%|██████████| 360/360 [00:04<00:00, 78.64it/s]


Epoch 10 - Train Acc: 34.99%
AutoNormal.locs.conv1.weight: [[[[ 1.22538373e-01  5.00336811e-02  4.94576320e-02]
   [ 4.26790640e-02  2.70796567e-02  6.52807578e-02]
   [ 9.59905386e-02  5.16397543e-02  6.92695156e-02]]

  [[ 1.78863816e-02 -5.64005561e-02 -4.99288589e-02]
   [-6.70776665e-02 -1.04964003e-01 -4.20924872e-02]
   [-3.70309316e-02 -5.04002795e-02 -3.38680074e-02]]

  [[ 7.92362839e-02  2.03445069e-02 -7.80166686e-03]
   [-1.03947818e-02 -3.03144306e-02 -2.68779881e-02]
   [-2.59881141e-03  2.12667678e-02 -4.32608509e-03]]]


 [[[-1.36423372e-02 -2.93229129e-02 -3.62126194e-02]
   [-2.74845064e-02 -6.55072778e-02 -4.00895178e-02]
   [-1.84187461e-02 -7.02033490e-02 -3.24217565e-02]]

  [[ 6.20910227e-02  5.33972904e-02  3.29987220e-02]
   [ 7.47554526e-02  2.91529521e-02  4.44396138e-02]
   [ 4.53019813e-02  8.21200665e-03  1.09870359e-02]]

  [[ 7.05756843e-02  1.76668633e-02  1.73106592e-03]
   [ 7.03904852e-02  2.69684270e-02  4.29222137e-02]
   [ 5.65579347e-02  2.30550

Epoch 11/100: 100%|██████████| 360/360 [00:16<00:00, 22.41it/s]


Epoch 11 - ELBO Loss: 49834.0204


Epoch 12/100: 100%|██████████| 360/360 [00:15<00:00, 22.70it/s]


Epoch 12 - ELBO Loss: 41748.2254


Epoch 13/100: 100%|██████████| 360/360 [00:15<00:00, 22.93it/s]


Epoch 13 - ELBO Loss: 35795.4365


Epoch 14/100: 100%|██████████| 360/360 [00:15<00:00, 22.75it/s]


Epoch 14 - ELBO Loss: 31275.1866


Epoch 15/100: 100%|██████████| 360/360 [00:16<00:00, 22.35it/s]


Epoch 15 - ELBO Loss: 27648.7102


Epoch 16/100: 100%|██████████| 360/360 [00:15<00:00, 22.65it/s]


Epoch 16 - ELBO Loss: 24866.8276


Epoch 17/100: 100%|██████████| 360/360 [00:15<00:00, 23.36it/s]


Epoch 17 - ELBO Loss: 22675.6624


Epoch 18/100: 100%|██████████| 360/360 [00:15<00:00, 22.82it/s]


Epoch 18 - ELBO Loss: 20861.6378


Epoch 19/100: 100%|██████████| 360/360 [00:15<00:00, 22.57it/s]


Epoch 19 - ELBO Loss: 19232.1917


Epoch 20/100: 100%|██████████| 360/360 [00:15<00:00, 22.58it/s]


Epoch 20 - ELBO Loss: 17863.2582


Acc check epoch 20: 100%|██████████| 360/360 [00:04<00:00, 77.40it/s]


Epoch 20 - Train Acc: 36.49%
AutoNormal.locs.conv1.weight: [[[[ 0.15984355 -0.00111925  0.00830844]
   [ 0.01694545 -0.04709719 -0.01032929]
   [ 0.07380927  0.04806075  0.01822126]]

  [[-0.00670626 -0.10116804 -0.11727893]
   [-0.11944399 -0.19928496 -0.14646262]
   [-0.05918744 -0.10154036 -0.14022031]]

  [[ 0.17336665  0.05070613 -0.01951512]
   [-0.00615739 -0.04736277 -0.06455131]
   [ 0.03810386  0.0148179  -0.05410821]]]


 [[[-0.0320844  -0.04402483 -0.08508433]
   [-0.02615764 -0.0968451  -0.07687359]
   [ 0.00818887 -0.11787707 -0.1167007 ]]

  [[ 0.13902476  0.14142253  0.07663164]
   [ 0.14425921  0.09484088  0.06784606]
   [ 0.14263804  0.05906207  0.02879169]]

  [[ 0.0880913   0.06100316 -0.05603813]
   [ 0.13845842  0.0448976  -0.0128053 ]
   [ 0.08295153 -0.01462264 -0.02106589]]]


 [[[ 0.1077979   0.09779193  0.12914835]
   [ 0.05305525  0.0455419   0.04714285]
   [-0.03250789 -0.00994402  0.01387992]]

  [[ 0.01205577  0.03368433  0.01165396]
   [-0.03471426 -0.02

Epoch 21/100: 100%|██████████| 360/360 [00:15<00:00, 23.04it/s]


Epoch 21 - ELBO Loss: 16840.4472


Epoch 22/100: 100%|██████████| 360/360 [00:15<00:00, 23.26it/s]


Epoch 22 - ELBO Loss: 15903.1912


Epoch 23/100: 100%|██████████| 360/360 [00:15<00:00, 22.75it/s]


Epoch 23 - ELBO Loss: 15199.8260


Epoch 24/100: 100%|██████████| 360/360 [00:15<00:00, 22.77it/s]


Epoch 24 - ELBO Loss: 14343.7332


Epoch 25/100: 100%|██████████| 360/360 [00:15<00:00, 22.69it/s]


Epoch 25 - ELBO Loss: 13746.8213


Epoch 26/100: 100%|██████████| 360/360 [00:15<00:00, 22.82it/s]


Epoch 26 - ELBO Loss: 13179.0840


Epoch 27/100: 100%|██████████| 360/360 [00:16<00:00, 22.49it/s]


Epoch 27 - ELBO Loss: 12795.5539


Epoch 28/100: 100%|██████████| 360/360 [00:15<00:00, 22.72it/s]


Epoch 28 - ELBO Loss: 12189.9589


Epoch 29/100: 100%|██████████| 360/360 [00:16<00:00, 22.42it/s]


Epoch 29 - ELBO Loss: 11935.4698


Epoch 30/100: 100%|██████████| 360/360 [00:16<00:00, 22.25it/s]


Epoch 30 - ELBO Loss: 11571.1755


Acc check epoch 30: 100%|██████████| 360/360 [00:04<00:00, 79.09it/s]


Epoch 30 - Train Acc: 40.13%
AutoNormal.locs.conv1.weight: [[[[ 2.27704391e-01  4.47178110e-02  3.23417559e-02]
   [ 1.11633381e-02 -3.93557847e-02 -1.36062922e-02]
   [ 7.54822418e-02  5.60774505e-02  4.07940447e-02]]

  [[ 1.04737487e-02 -1.32735610e-01 -9.93182883e-02]
   [-1.77983791e-01 -2.53475070e-01 -2.15092093e-01]
   [-1.01954512e-01 -5.91899306e-02 -1.31367385e-01]]

  [[ 2.40120843e-01  9.56284404e-02 -3.80140706e-03]
   [ 2.21452098e-02 -4.64022458e-02 -7.86338300e-02]
   [ 5.47332577e-02  7.28482902e-02 -5.04626743e-02]]]


 [[[-1.56384259e-01 -1.54970914e-01 -1.74544305e-01]
   [-1.28111258e-01 -1.98545545e-01 -1.30280897e-01]
   [-3.23654935e-02 -1.83515519e-01 -1.65327653e-01]]

  [[ 1.51827112e-01  1.25195906e-01  1.28606096e-01]
   [ 1.87963158e-01  1.40012100e-01  1.01172328e-01]
   [ 1.97782248e-01  1.26350984e-01  1.02763064e-01]]

  [[ 4.44270633e-02  3.71177564e-03 -7.31325522e-02]
   [ 1.29638553e-01  3.77507396e-02 -1.68908983e-02]
   [ 1.17605641e-01  2.75206

Epoch 31/100: 100%|██████████| 360/360 [00:15<00:00, 22.57it/s]


Epoch 31 - ELBO Loss: 11200.9802


Epoch 32/100: 100%|██████████| 360/360 [00:15<00:00, 22.83it/s]


Epoch 32 - ELBO Loss: 10869.5001


Epoch 33/100: 100%|██████████| 360/360 [00:15<00:00, 22.58it/s]


Epoch 33 - ELBO Loss: 10673.0537


Epoch 34/100: 100%|██████████| 360/360 [00:16<00:00, 22.29it/s]


Epoch 34 - ELBO Loss: 10395.6926


Epoch 35/100: 100%|██████████| 360/360 [00:15<00:00, 22.71it/s]


Epoch 35 - ELBO Loss: 10067.5760


Epoch 36/100: 100%|██████████| 360/360 [00:16<00:00, 22.50it/s]


Epoch 36 - ELBO Loss: 10055.6541


Epoch 37/100: 100%|██████████| 360/360 [00:15<00:00, 22.76it/s]


Epoch 37 - ELBO Loss: 9611.1448


Epoch 38/100: 100%|██████████| 360/360 [00:15<00:00, 23.07it/s]


Epoch 38 - ELBO Loss: 9568.3739


Epoch 39/100: 100%|██████████| 360/360 [00:16<00:00, 22.36it/s]


Epoch 39 - ELBO Loss: 9428.3834


Epoch 40/100: 100%|██████████| 360/360 [00:15<00:00, 22.86it/s]


Epoch 40 - ELBO Loss: 9213.3796


Acc check epoch 40: 100%|██████████| 360/360 [00:04<00:00, 81.65it/s]


Epoch 40 - Train Acc: 41.58%
AutoNormal.locs.conv1.weight: [[[[ 2.70865083e-01  6.31359145e-02  5.60733117e-02]
   [ 4.00298908e-02 -1.00253060e-01  2.29485072e-02]
   [ 8.28337818e-02  5.46718165e-02  9.59727243e-02]]

  [[ 8.90051201e-02 -7.95306414e-02 -8.33498985e-02]
   [-1.86262473e-01 -2.54393488e-01 -2.21819237e-01]
   [-7.47048855e-02 -7.18032056e-03 -9.65306461e-02]]

  [[ 3.66560578e-01  1.82248071e-01  1.25178443e-02]
   [ 2.36832518e-02 -9.25949365e-02 -7.48957917e-02]
   [ 8.87788907e-02  8.30519423e-02 -4.79974672e-02]]]


 [[[-2.00187072e-01 -2.08092764e-01 -2.12889344e-01]
   [-1.46392837e-01 -2.33288810e-01 -1.66558519e-01]
   [-6.85415491e-02 -2.14651972e-01 -1.47434041e-01]]

  [[ 1.99683934e-01  2.25133017e-01  1.61699146e-01]
   [ 2.41500556e-01  1.84360445e-01  1.22931808e-01]
   [ 2.12525755e-01  1.88064799e-01  1.78892627e-01]]

  [[ 9.99856591e-02 -9.89380758e-03 -6.83429316e-02]
   [ 9.26114842e-02 -3.81682417e-03 -3.76260057e-02]
   [ 1.13594167e-01  9.68137

Epoch 41/100: 100%|██████████| 360/360 [00:15<00:00, 22.92it/s]


Epoch 41 - ELBO Loss: 8917.0368


Epoch 42/100: 100%|██████████| 360/360 [00:15<00:00, 22.71it/s]


Epoch 42 - ELBO Loss: 8929.9099


Epoch 43/100: 100%|██████████| 360/360 [00:15<00:00, 22.71it/s]


Epoch 43 - ELBO Loss: 8850.1840


Epoch 44/100: 100%|██████████| 360/360 [00:15<00:00, 22.67it/s]


Epoch 44 - ELBO Loss: 8685.9848


Epoch 45/100: 100%|██████████| 360/360 [00:15<00:00, 22.76it/s]


Epoch 45 - ELBO Loss: 8560.8159


Epoch 46/100: 100%|██████████| 360/360 [00:15<00:00, 22.71it/s]


Epoch 46 - ELBO Loss: 8390.2265


Epoch 47/100: 100%|██████████| 360/360 [00:15<00:00, 22.94it/s]


Epoch 47 - ELBO Loss: 8369.9890


Epoch 48/100: 100%|██████████| 360/360 [00:15<00:00, 22.67it/s]


Epoch 48 - ELBO Loss: 8467.2337


Epoch 49/100: 100%|██████████| 360/360 [00:15<00:00, 22.80it/s]


Epoch 49 - ELBO Loss: 8180.8182


Epoch 50/100: 100%|██████████| 360/360 [00:15<00:00, 22.87it/s]


Epoch 50 - ELBO Loss: 8141.6852


Acc check epoch 50: 100%|██████████| 360/360 [00:04<00:00, 80.30it/s]


Epoch 50 - Train Acc: 44.41%
AutoNormal.locs.conv1.weight: [[[[ 3.11416596e-01  5.78624830e-02  5.60062611e-03]
   [ 8.54225308e-02 -1.48804128e-01 -1.37575334e-02]
   [ 1.11967534e-01  7.24489689e-02  6.70661554e-02]]

  [[ 1.15984544e-01 -1.02353469e-01 -1.58473164e-01]
   [-1.89307734e-01 -3.07606369e-01 -3.13296348e-01]
   [-9.22455490e-02 -3.67851183e-02 -1.54019669e-01]]

  [[ 4.45301116e-01  1.82498619e-01 -2.13915017e-02]
   [-2.59054918e-03 -1.26612350e-01 -1.29139677e-01]
   [ 6.29156828e-02  5.95658980e-02 -8.92552435e-02]]]


 [[[-2.43256763e-01 -2.19543368e-01 -2.31299132e-01]
   [-2.23807395e-01 -2.78887153e-01 -2.65452057e-01]
   [-1.37344092e-01 -2.84439534e-01 -2.55973548e-01]]

  [[ 1.82436451e-01  2.49814719e-01  1.67482108e-01]
   [ 2.03001097e-01  2.34485000e-01  1.34474382e-01]
   [ 2.08052859e-01  1.87122434e-01  1.91547781e-01]]

  [[ 1.14358760e-01 -1.89862074e-03 -4.68311124e-02]
   [ 7.87129775e-02  2.15573758e-02 -4.51750457e-02]
   [ 1.39396757e-01  2.48771

Epoch 51/100: 100%|██████████| 360/360 [00:15<00:00, 22.58it/s]


Epoch 51 - ELBO Loss: 8080.7246


Epoch 52/100: 100%|██████████| 360/360 [00:15<00:00, 23.00it/s]


Epoch 52 - ELBO Loss: 8152.2794


Epoch 53/100: 100%|██████████| 360/360 [00:15<00:00, 22.81it/s]


Epoch 53 - ELBO Loss: 7914.7891


Epoch 54/100: 100%|██████████| 360/360 [00:25<00:00, 14.33it/s]


Epoch 54 - ELBO Loss: 7838.0704


Epoch 55/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 55 - ELBO Loss: 7748.2546


Epoch 56/100: 100%|██████████| 360/360 [00:35<00:00, 10.21it/s]


Epoch 56 - ELBO Loss: 7754.8995


Epoch 57/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 57 - ELBO Loss: 7713.6336


Epoch 58/100: 100%|██████████| 360/360 [00:35<00:00, 10.26it/s]


Epoch 58 - ELBO Loss: 7652.4907


Epoch 59/100: 100%|██████████| 360/360 [00:35<00:00, 10.24it/s]


Epoch 59 - ELBO Loss: 7565.9817


Epoch 60/100: 100%|██████████| 360/360 [00:35<00:00, 10.21it/s]


Epoch 60 - ELBO Loss: 7636.3323


Acc check epoch 60: 100%|██████████| 360/360 [00:12<00:00, 29.99it/s]


Epoch 60 - Train Acc: 43.90%
AutoNormal.locs.conv1.weight: [[[[ 2.77286768e-01  2.61826832e-02 -6.13144673e-02]
   [ 4.92131114e-02 -1.68546677e-01 -9.05454457e-02]
   [ 1.34218723e-01  9.24603790e-02  2.86343377e-02]]

  [[ 8.44602734e-02 -1.39303684e-01 -1.99598506e-01]
   [-2.01559380e-01 -3.02711338e-01 -3.98626208e-01]
   [-7.53270239e-02 -2.61433329e-02 -1.47371054e-01]]

  [[ 5.03057599e-01  2.24507883e-01 -1.79576054e-02]
   [ 4.66886275e-02 -1.50711626e-01 -1.47908300e-01]
   [ 9.68865678e-02  6.03277385e-02 -8.16721767e-02]]]


 [[[-2.75323629e-01 -2.87313581e-01 -2.56544143e-01]
   [-2.85371304e-01 -3.40805680e-01 -3.21976185e-01]
   [-1.65993527e-01 -3.29519987e-01 -2.41806850e-01]]

  [[ 2.43451834e-01  2.89894611e-01  2.48038486e-01]
   [ 2.45939761e-01  3.16391557e-01  1.92441121e-01]
   [ 2.52431542e-01  2.28074104e-01  2.37201765e-01]]

  [[ 1.48974136e-01 -6.42064214e-02 -2.66927518e-02]
   [ 8.97155553e-02  1.23831090e-02 -2.10826402e-03]
   [ 1.28007099e-01  2.02174

Epoch 61/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 61 - ELBO Loss: 7606.2616


Epoch 62/100: 100%|██████████| 360/360 [00:34<00:00, 10.30it/s]


Epoch 62 - ELBO Loss: 7655.4635


Epoch 63/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 63 - ELBO Loss: 7549.4525


Epoch 64/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 64 - ELBO Loss: 7534.7931


Epoch 65/100: 100%|██████████| 360/360 [00:34<00:00, 10.29it/s]


Epoch 65 - ELBO Loss: 7445.3891


Epoch 66/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 66 - ELBO Loss: 7479.9352


Epoch 67/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 67 - ELBO Loss: 7436.1365


Epoch 68/100: 100%|██████████| 360/360 [00:35<00:00, 10.27it/s]


Epoch 68 - ELBO Loss: 7543.2520


Epoch 69/100: 100%|██████████| 360/360 [00:34<00:00, 10.29it/s]


Epoch 69 - ELBO Loss: 7527.7921


Epoch 70/100: 100%|██████████| 360/360 [00:34<00:00, 10.30it/s]


Epoch 70 - ELBO Loss: 7312.2797


Acc check epoch 70: 100%|██████████| 360/360 [00:12<00:00, 29.55it/s]


Epoch 70 - Train Acc: 43.96%
AutoNormal.locs.conv1.weight: [[[[ 2.92983890e-01  1.91803761e-02 -4.54647690e-02]
   [ 7.63930306e-02 -1.66320041e-01 -1.07170746e-01]
   [ 1.57371581e-01  1.01670139e-01  4.25681360e-02]]

  [[ 9.09862518e-02 -1.19153656e-01 -1.92921013e-01]
   [-1.81793660e-01 -2.84339339e-01 -3.60011637e-01]
   [-6.64672405e-02  2.97077000e-03 -1.46614373e-01]]

  [[ 5.76857567e-01  2.75959134e-01  1.41587518e-02]
   [ 8.25259164e-02 -1.91076428e-01 -1.47116885e-01]
   [ 9.41896662e-02  9.24264193e-02  1.48049765e-03]]]


 [[[-2.90827334e-01 -3.00577313e-01 -2.05406770e-01]
   [-3.13359231e-01 -3.76629680e-01 -3.04398179e-01]
   [-1.79686040e-01 -3.67836565e-01 -2.35621929e-01]]

  [[ 2.72909373e-01  3.29404354e-01  3.42770875e-01]
   [ 2.65614659e-01  3.33690286e-01  2.81939805e-01]
   [ 3.00472081e-01  2.65341789e-01  2.85944641e-01]]

  [[ 1.69261113e-01 -6.63651451e-02  9.42674873e-04]
   [ 9.42107961e-02  3.35277431e-02  3.42075538e-04]
   [ 1.36747703e-01  2.96401

Epoch 71/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 71 - ELBO Loss: 7299.2702


Epoch 72/100: 100%|██████████| 360/360 [00:35<00:00, 10.27it/s]


Epoch 72 - ELBO Loss: 7298.3140


Epoch 73/100: 100%|██████████| 360/360 [00:35<00:00, 10.23it/s]


Epoch 73 - ELBO Loss: 7429.7168


Epoch 74/100: 100%|██████████| 360/360 [00:35<00:00, 10.27it/s]


Epoch 74 - ELBO Loss: 7220.4759


Epoch 75/100: 100%|██████████| 360/360 [00:34<00:00, 10.29it/s]


Epoch 75 - ELBO Loss: 7342.7733


Epoch 76/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 76 - ELBO Loss: 7357.7769


Epoch 77/100: 100%|██████████| 360/360 [00:35<00:00, 10.24it/s]


Epoch 77 - ELBO Loss: 7372.3968


Epoch 78/100: 100%|██████████| 360/360 [00:35<00:00, 10.22it/s]


Epoch 78 - ELBO Loss: 7363.2901


Epoch 79/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 79 - ELBO Loss: 7424.8474


Epoch 80/100: 100%|██████████| 360/360 [00:35<00:00, 10.28it/s]


Epoch 80 - ELBO Loss: 7353.7243


Acc check epoch 80: 100%|██████████| 360/360 [00:12<00:00, 29.49it/s]


Epoch 80 - Train Acc: 44.73%
AutoNormal.locs.conv1.weight: [[[[ 3.30273271e-01  4.14470397e-02 -6.98502809e-02]
   [ 4.90526706e-02 -1.85313627e-01 -1.48011163e-01]
   [ 1.52937561e-01  1.14776947e-01  7.49906078e-02]]

  [[ 6.77817911e-02 -1.59453705e-01 -2.43680522e-01]
   [-2.55362958e-01 -3.42126548e-01 -3.59763473e-01]
   [-1.43626168e-01 -3.50504704e-02 -1.49106696e-01]]

  [[ 5.78804910e-01  2.65747190e-01 -1.29209850e-02]
   [ 5.03849052e-02 -2.09954470e-01 -1.67557418e-01]
   [ 4.15119231e-02  6.66053221e-02 -6.00890722e-03]]]


 [[[-3.33551526e-01 -3.11150670e-01 -2.16817990e-01]
   [-3.38986963e-01 -4.52479571e-01 -3.08817118e-01]
   [-2.22650811e-01 -4.03290778e-01 -2.65730470e-01]]

  [[ 3.14382911e-01  3.18810880e-01  3.48406702e-01]
   [ 3.27318281e-01  3.46518874e-01  3.19575518e-01]
   [ 3.56499583e-01  2.87236899e-01  3.32367331e-01]]

  [[ 1.48715988e-01 -9.15033072e-02 -5.50923264e-03]
   [ 4.11932096e-02  5.26468316e-03 -1.29091023e-02]
   [ 8.10768455e-02  7.81991

Epoch 81/100: 100%|██████████| 360/360 [00:34<00:00, 10.30it/s]


Epoch 81 - ELBO Loss: 7428.3484


Epoch 82/100: 100%|██████████| 360/360 [00:35<00:00, 10.26it/s]


Epoch 82 - ELBO Loss: 7367.0592


Epoch 83/100: 100%|██████████| 360/360 [00:34<00:00, 10.31it/s]


Epoch 83 - ELBO Loss: 7209.2962


Epoch 84/100: 100%|██████████| 360/360 [00:34<00:00, 10.29it/s]


Epoch 84 - ELBO Loss: 7417.0035


Epoch 85/100: 100%|██████████| 360/360 [1:04:11<00:00, 10.70s/it]    


Epoch 85 - ELBO Loss: 7253.6508


Epoch 86/100: 100%|██████████| 360/360 [00:13<00:00, 25.81it/s]


Epoch 86 - ELBO Loss: 7442.4430


Epoch 87/100: 100%|██████████| 360/360 [00:09<00:00, 37.67it/s]


Epoch 87 - ELBO Loss: 7362.7684


Epoch 88/100: 100%|██████████| 360/360 [00:09<00:00, 37.95it/s]


Epoch 88 - ELBO Loss: 7387.7770


Epoch 89/100: 100%|██████████| 360/360 [00:10<00:00, 35.02it/s]


Epoch 89 - ELBO Loss: 7315.0952


Epoch 90/100: 100%|██████████| 360/360 [00:11<00:00, 30.86it/s]


Epoch 90 - ELBO Loss: 7377.5229


Acc check epoch 90:  76%|███████▌  | 274/360 [00:05<00:01, 68.12it/s]