In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd

In [4]:
device = torch.device("cuda")

In [5]:
w_star = torch.load("w_star.pth")

In [6]:
class DeterministicCNNSingleFCRelu(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(64, num_classes)

    def forward(self, x):
        # x: [B, 3, 64, 64]
        x = self.pool(F.relu(self.conv1(x)))  # → [B, 32, 32, 32]
        x = self.pool(F.relu(self.conv2(x)))  # → [B, 64, 16, 16]
        x = self.gap(x)                       # → [B, 64, 1, 1]
        x = x.view(x.size(0), -1)             # → [B, 64]
        logits = self.fc1(x)                  # → [B, num_classes]
        return logits

In [7]:
deterministic_cnn = DeterministicCNNSingleFCRelu(num_classes=10).to(device)

In [8]:
# helper to invert softplus
def inv_softplus(x):
    return x + torch.log(-torch.expm1(-x))

# small initial posterior std
init_eps = 0.05
raw_init = inv_softplus(torch.tensor(init_eps))

# --- 2) Define your guide using warm-started variational params ---
def guide(x, y=None):
    # Plate over parameters
    for name, param in deterministic_cnn.named_parameters():
        # variational mean = pretrained weight
        loc = pyro.param(f"{name}_loc",
                         lambda: w_star[name])
        
        # variational raw scale so that softplus(raw_scale) ≈ init_eps
        scale_unconstrained = pyro.param(f"{name}_scale_unconstrained",
                                         lambda: torch.full_like(w_star[name], raw_init))
        scale = F.softplus(scale_unconstrained)
        
        # sample a weight tensor from Normal(loc, scale)
        # to_event(param.dim()) ensures we treat all dims as batch dims
        pyro.sample(name,
                    dist.Normal(loc, scale)
                        .to_event(param.dim()))

In [9]:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianCNNSingleFC(PyroModule):
    def __init__(
        self,
        num_classes,
        device,
        activation='wg',
        prior_dist='gaussian',
        mu=0,
        b=10.0,
        prior_params=None
    ):
        super().__init__()

        # Store device
        self.device = device

        # Activation setup: accept string or callable
        if isinstance(activation, str):
            act_map = {
                'relu': F.relu,
                'tanh': F.tanh,
                'wg': self.actWG,
                'rwg': self.actRWG,
                'sigmoid': F.sigmoid,
                'sinusoidal': torch.sin,
            }
            try:
                self.activation_fn = act_map[activation]
            except KeyError:
                raise ValueError(f"Unsupported activation: {activation}")
        elif callable(activation):
            self.activation_fn = activation
        else:
            raise ValueError("activation must be a string or callable")

        # Prior distribution setup
        self.prior_dist = prior_dist
        # Set default prior parameters if not provided
        default_params = {'mu': 0.0, 'b': 10.0}
        params = default_params if prior_params is None else prior_params
        self.prior_mu = torch.tensor(params.get('mu', params['mu']), device=device)
        self.prior_b  = torch.tensor(params.get('b', params['b']), device=device)

        print(f"Using prior distribution: {self.prior_dist} with mu={self.prior_mu.item()} and b={self.prior_b.item()}")

        # Layer definitions with priors
        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = PyroSample(self._make_prior([32, 3, 5, 5]))
        self.conv1.bias   = PyroSample(self._make_prior([32]))

        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        self.conv2.weight = PyroSample(self._make_prior([64, 32, 5, 5]))
        self.conv2.bias   = PyroSample(self._make_prior([64]))

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.gap  = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = PyroModule[nn.Linear](64, num_classes)
        self.fc1.weight = PyroSample(self._make_prior([num_classes, 64]))
        self.fc1.bias   = PyroSample(self._make_prior([num_classes]))

    def actWG(self, x, alpha=1.0):
        # Weight-gradient activation
        return x * torch.exp(-alpha * x ** 2)
    
    def actRWG(self, x, alpha=1.0):
        # Weight-gradient activation
        return max(0,x * torch.exp(-alpha * x ** 2))

    def _make_prior(self, shape):
        """
        Construct a prior distribution based on self.prior_dist and parameters.
        """
        if self.prior_dist == 'gaussian':
            base = dist.Normal(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'laplace':
            base = dist.Laplace(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'uniform':
            # uniform over [-b, b]
            base = dist.Uniform(-self.prior_b, self.prior_b)
        else:
            raise ValueError(f"Unsupported prior distribution: {self.prior_dist}")
        return base.expand(shape).to_event(len(shape))

    def forward(self, x, y=None):
        # x: [B, 3, 64, 64]
        x = self.activation_fn(self.conv1(x).to(self.device))
        x = self.pool(x)
        x = self.activation_fn(self.conv2(x))
        x = self.pool(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)

        if y is not None:
            with pyro.plate("data", x.size(0)):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)

        return logits

In [10]:
eurosat_mean = [0.344, 0.380, 0.408]
eurosat_std  = [0.190, 0.137, 0.115]

old_mean = [0.3444, 0.3803, 0.4078]
old_std = [0.0914, 0.0651, 0.0552]


def load_data(batch_size=54):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=eurosat_mean, 
                             std=eurosat_std)
    ])

    dataset = datasets.EuroSAT(root='./data', transform=transform, download=True)

    torch.manual_seed(42)

    #train_size = int(0.8 * len(dataset))
    #test_size = len(dataset) - train_size
    #train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    with open('datasplit/split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
        train_dataset = Subset(dataset, split['train'])
        test_dataset = Subset(dataset, split['test'])

    # Add num_workers and pin_memory for faster data loading
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                            num_workers=4, pin_memory=True, persistent_workers=True)
    return train_loader, test_loader

In [29]:
# training SVI function

import os
import torch
import pyro
from tqdm import tqdm
import numpy as np

def train_svi_with_stats(
    model,
    guide,
    svi,
    train_loader,
    device,
    num_epochs=10,
    save_epochs=None,
    save_dir='results',
    model_filename_pattern='model_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    guide_filename_pattern='guide_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    param_store_filename_pattern='param_store_{activation}_{prior}_epoch_{epoch}_{timestamp}.pkl',
    accuracies_filename_pattern='accuracy_results_{activation}_{prior}_{timestamp}.csv',
    losses_filename_pattern='losses_{activation}_{prior}_{timestamp}.csv',
    model_config_filename_pattern='config_{activation}_{prior}_{timestamp}.json'
):
    """
    Train the SVI model, track losses/accuracies, and
    save artifacts only when accuracy improves, naming files
    like `model_relu_gaussian_epoch_3.pth`.
    """
    
    # Pull names off the model if available, else fall back
    #act_name  = getattr(model, 'activation', getattr(model, 'activation_name', 'act'))
    act_name = model.activation_fn.__name__ if hasattr(model.activation_fn, '__name__') else str(model.activation_fn)
    prior_name = getattr(model, 'prior_dist', 'prior')
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    os.makedirs(save_dir, exist_ok=True)
    save_epochs = set(save_epochs or range(1, num_epochs+1))

    pyro.clear_param_store()
    model.to(device)

    epoch_losses, epoch_accuracies, accuracy_epochs = [], [], []
    weight_stats = {'epochs': [], 'means': [], 'stds': []}
    bias_stats   = {'epochs': [], 'means': [], 'stds': []}
    best_acc = 0.0

    for epoch in range(1, num_epochs+1):
        model.train()
        total_loss = 0.0
        batches = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            total_loss += svi.step(images, labels)
            batches += 1

        avg_loss = total_loss / batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch} - ELBO Loss: {avg_loss:.4f}")

        if epoch == 1 or epoch % 10 == 0 or epoch == num_epochs:
            model.eval(); #guide.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Acc check epoch {epoch}"):
                    images, labels = images.to(device), labels.to(device)
                    trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed = pyro.poutine.replay(model, trace=trace)
                    logits = replayed(images)
                    preds = torch.argmax(logits, dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

            acc = correct/total
            epoch_accuracies.append(acc); accuracy_epochs.append(epoch)
            print(f"Epoch {epoch} - Train Acc: {acc*100:.2f}%")

            # record stats...
            w_means, w_stds, b_means, b_stds = [], [], [], []
            for name, param in pyro.get_param_store().items():
                if 'loc' in name:
                    w_means.append(param.mean().item()); w_stds.append(param.std().item())
                elif 'scale' in name:
                    b_means.append(param.mean().item()); b_stds.append(param.std().item())
            weight_stats['epochs'].append(epoch)
            weight_stats['means'].append(w_means)
            weight_stats['stds'].append(w_stds)
            bias_stats['epochs'].append(epoch)
            bias_stats['means'].append(b_means)
            bias_stats['stds'].append(b_stds)

            # only save when accuracy improves
            if acc > best_acc:
                best_acc = acc
                fname_model = model_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)
                fname_guide = guide_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)
                fname_ps    = param_store_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)

                torch.save(model.state_dict(), os.path.join(save_dir, fname_model))
                #torch.save(guide.state_dict(), os.path.join(save_dir, fname_guide))
                pyro.get_param_store().save(os.path.join(save_dir, fname_ps))
                print(f"  ↳ Saved: {fname_model}, {fname_guide}, {fname_ps}")

    # save losses per epoch in a csv file, with consistent file naming
    accuracies_df = pd.DataFrame({
        'epoch': accuracy_epochs,
        'accuracy': epoch_accuracies
    })
    accuracies_df.to_csv(os.path.join(save_dir,accuracies_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)), index=False)

    loss_df = pd.DataFrame({
        'epoch': list(range(1, epoch + 1)),
        'loss': epoch_losses
    })
    loss_df.to_csv(os.path.join(save_dir,losses_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)), index=False)
            
    # save model configuration in a json file
    config = {
        'activation': act_name,
        'prior': prior_name,
        'num_epochs': num_epochs,
        'best_accuracy_at_epoch': accuracy_epochs[np.argmax(epoch_accuracies)],
        'best_accuracy': best_acc,
        'batch_size': train_loader.batch_size,
        'train_size': len(train_loader.dataset),
        'prior_params': {
            'mu': model.prior_mu.item(),
            'b': model.prior_b.item()
        },
    }
    config_filename = model_config_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)

    with open(os.path.join(save_dir, config_filename), 'w') as f:
        import json
        json.dump(config, f, indent=4)
        print(f"Configuration saved to {config_filename}")

    return epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats, bias_stats, os.path.join(save_dir, fname_model), os.path.join(save_dir, fname_guide), os.path.join(save_dir, fname_ps), timestamp


In [30]:
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
from tqdm import tqdm
import pandas as pd

In [None]:
num_classes = 10

bayesian_model = BayesianCNNSingleFC(num_classes,
        device,
        activation='relu',
        prior_dist='gaussian',
        prior_params={'mu': 0.0, 'b': 10.0})

#act_name  = getattr(bayesian_model, 'activation', getattr(bayesian_model, 'activation', 'act'))
#prior_name = getattr(bayesian_model, 'prior_dist', 'prior')

#guide = AutoDiagonalNormal(bayesian_model)

optimizer = Adam({"lr": 1e-3})  # Increased from 1e-4 to 1e-3
svi = pyro.infer.SVI(model=bayesian_model,
                     guide=guide,
                     optim=optimizer,
                     loss=pyro.infer.Trace_ELBO(num_particles=1,
                                                )) #TODO


Using prior distribution: gaussian with mu=0.0 and b=1


In [33]:
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
#guide.to(device)

train_loader, test_loader = load_data(batch_size=54)

In [34]:
losses, accuracies, accuracy_epochs, weight_stats, bias_stats, best_model_path, best_guide_path, best_param_store_path, experiment_timestamp = train_svi_with_stats(
    bayesian_model,
    guide,
    svi,
    train_loader,
    device,
    num_epochs=100,
    save_epochs=None,
    save_dir='results_GP_eurosat')

Epoch 1/100:   0%|          | 0/360 [00:00<?, ?it/s]

Epoch 1/100: 100%|██████████| 360/360 [00:08<00:00, 40.87it/s]


Epoch 1 - ELBO Loss: 126417.0818


Acc check epoch 1: 100%|██████████| 360/360 [00:04<00:00, 82.87it/s]


Epoch 1 - Train Acc: 62.63%
  ↳ Saved: model_relu_gaussian_epoch_best_20250714_181853.pth, guide_relu_gaussian_epoch_best_20250714_181853.pth, param_store_relu_gaussian_epoch_best_20250714_181853.pkl


Epoch 2/100: 100%|██████████| 360/360 [00:08<00:00, 41.20it/s]


Epoch 2 - ELBO Loss: 107729.7695


Epoch 3/100: 100%|██████████| 360/360 [00:08<00:00, 42.35it/s]


Epoch 3 - ELBO Loss: 89683.3380


Epoch 4/100: 100%|██████████| 360/360 [00:08<00:00, 41.81it/s]


Epoch 4 - ELBO Loss: 72428.2660


Epoch 5/100: 100%|██████████| 360/360 [00:08<00:00, 41.56it/s]


Epoch 5 - ELBO Loss: 56381.6261


Epoch 6/100: 100%|██████████| 360/360 [00:08<00:00, 41.44it/s]


Epoch 6 - ELBO Loss: 42017.6050


Epoch 7/100: 100%|██████████| 360/360 [00:08<00:00, 41.41it/s]


Epoch 7 - ELBO Loss: 29938.2612


Epoch 8/100: 100%|██████████| 360/360 [00:08<00:00, 41.69it/s]


Epoch 8 - ELBO Loss: 20575.4683


Epoch 9/100: 100%|██████████| 360/360 [00:08<00:00, 41.93it/s]


Epoch 9 - ELBO Loss: 13937.2326


Epoch 10/100: 100%|██████████| 360/360 [00:08<00:00, 42.49it/s]


Epoch 10 - ELBO Loss: 9635.0922


Acc check epoch 10: 100%|██████████| 360/360 [00:04<00:00, 82.88it/s]


Epoch 10 - Train Acc: 13.60%


Epoch 11/100: 100%|██████████| 360/360 [00:08<00:00, 41.46it/s]


Epoch 11 - ELBO Loss: 6947.3098


Epoch 12/100: 100%|██████████| 360/360 [00:08<00:00, 41.97it/s]


Epoch 12 - ELBO Loss: 5265.9073


Epoch 13/100: 100%|██████████| 360/360 [00:08<00:00, 41.47it/s]


Epoch 13 - ELBO Loss: 4193.3404


Epoch 14/100: 100%|██████████| 360/360 [00:08<00:00, 41.38it/s]


Epoch 14 - ELBO Loss: 3516.0140


Epoch 15/100: 100%|██████████| 360/360 [00:08<00:00, 41.61it/s]


Epoch 15 - ELBO Loss: 3041.4008


Epoch 16/100: 100%|██████████| 360/360 [00:08<00:00, 42.48it/s]


Epoch 16 - ELBO Loss: 2692.3384


Epoch 17/100: 100%|██████████| 360/360 [00:08<00:00, 42.42it/s]


Epoch 17 - ELBO Loss: 2484.7237


Epoch 18/100: 100%|██████████| 360/360 [00:08<00:00, 41.51it/s]


Epoch 18 - ELBO Loss: 2300.8506


Epoch 19/100: 100%|██████████| 360/360 [00:08<00:00, 41.92it/s]


Epoch 19 - ELBO Loss: 2159.9345


Epoch 20/100: 100%|██████████| 360/360 [00:08<00:00, 41.30it/s]


Epoch 20 - ELBO Loss: 2045.9729


Acc check epoch 20: 100%|██████████| 360/360 [00:04<00:00, 83.32it/s]


Epoch 20 - Train Acc: 13.16%


Epoch 21/100: 100%|██████████| 360/360 [00:08<00:00, 42.03it/s]


Epoch 21 - ELBO Loss: 2019.7566


Epoch 22/100: 100%|██████████| 360/360 [00:08<00:00, 42.27it/s]


Epoch 22 - ELBO Loss: 1923.2202


Epoch 23/100: 100%|██████████| 360/360 [00:08<00:00, 41.34it/s]


Epoch 23 - ELBO Loss: 1885.1080


Epoch 24/100: 100%|██████████| 360/360 [00:08<00:00, 42.08it/s]


Epoch 24 - ELBO Loss: 1846.9597


Epoch 25/100: 100%|██████████| 360/360 [00:08<00:00, 41.98it/s]


Epoch 25 - ELBO Loss: 1795.3861


Epoch 26/100: 100%|██████████| 360/360 [00:08<00:00, 42.25it/s]


Epoch 26 - ELBO Loss: 1769.7429


Epoch 27/100: 100%|██████████| 360/360 [00:08<00:00, 41.86it/s]


Epoch 27 - ELBO Loss: 1729.6698


Epoch 28/100: 100%|██████████| 360/360 [00:08<00:00, 42.46it/s]


Epoch 28 - ELBO Loss: 1708.6824


Epoch 29/100: 100%|██████████| 360/360 [00:08<00:00, 42.32it/s]


Epoch 29 - ELBO Loss: 1678.1144


Epoch 30/100: 100%|██████████| 360/360 [00:08<00:00, 41.95it/s]


Epoch 30 - ELBO Loss: 1692.9943


Acc check epoch 30: 100%|██████████| 360/360 [00:04<00:00, 83.18it/s]


Epoch 30 - Train Acc: 14.08%


Epoch 31/100:  94%|█████████▎| 337/360 [00:08<00:00, 40.35it/s]


KeyboardInterrupt: 

In [None]:
def plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats):
    """Plot training results with weight and bias statistics"""
    plt.figure(figsize=(16, 12))
    
    # Plot 1: Training Loss
    plt.subplot(2, 2, 1)
    plt.plot(range(1, len(losses) + 1), losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('ELBO Loss')
    plt.grid(True)
    
    # Plot 2: Training Accuracy
    plt.subplot(2, 2, 2)
    plt.plot(accuracy_epochs, accuracies, 'o-')
    plt.title('Training Accuracy (Every 10 Epochs)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)
    
    # Plot 3: Weight Statistics Boxplot
    plt.subplot(2, 2, 3)
    weight_data = []
    weight_labels = []
    
    for i, epoch in enumerate(weight_stats['epochs']):
        # Combine means and stds for this epoch
        epoch_data = weight_stats['means'][i] + weight_stats['stds'][i]
        weight_data.append(epoch_data)
        weight_labels.append(f'Epoch {epoch}')
    
    if weight_data:
        bp1 = plt.boxplot(weight_data, labels=weight_labels, patch_artist=True)
        for patch in bp1['boxes']:
            patch.set_facecolor('lightblue')
    
    plt.title('LOC Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('LOC Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Plot 4: Bias Statistics Boxplot
    plt.subplot(2, 2, 4)
    bias_data = []
    bias_labels = []
    
    for i, epoch in enumerate(bias_stats['epochs']):
        # Combine means and stds for this epoch
        epoch_data = bias_stats['means'][i] + bias_stats['stds'][i]
        bias_data.append(epoch_data)
        bias_labels.append(f'Epoch {epoch}')
    
    if bias_data:
        bp2 = plt.boxplot(bias_data, tick_labels=bias_labels, patch_artist=True)
        for patch in bp2['boxes']:
            patch.set_facecolor('lightcoral')
    
    plt.title('SCALE Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('SCALE Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

In [None]:
plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats)

In [None]:
# clear the parameter store and reload the parameter store from the best result
pyro.clear_param_store()

bayesian_model.load_state_dict(torch.load(best_model_path))
guide.load_state_dict(torch.load(best_guide_path))
pyro.get_param_store().set_state(torch.load(best_param_store_path,weights_only=False))

In [None]:
# print confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, loader_of_interest, num_samples=10):
    model.eval()
    #guide.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(loader_of_interest, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

            for i in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits = replayed_model(images)
                logits_mc[i] = logits

            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [None]:
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)
cm = confusion_matrix(all_labels, all_predictions)
#print accuracy from confusion matrix
accuracy = np.trace(cm) / np.sum(cm)
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

In [None]:
df = pd.DataFrame({'True Label': all_labels, 'Predicted Label': all_predictions})

In [None]:
#save the prediction label and true label in a csv file

def save_predictions_to_csv(labels, predictions, filename='predictions.csv'):
    df = pd.DataFrame({'True Label': labels, 'Predicted Label': predictions})
    df.to_csv(filename, index=False)
    print(f"Predictions saved to {filename}")

In [None]:
act_name = bayesian_model.activation_fn.__name__ if hasattr(bayesian_model.activation_fn, '__name__') else str(bayesian_model.activation_fn)
prior_name = getattr(bayesian_model, 'prior_dist', 'prior')

save_predictions_to_csv(all_labels, all_predictions, os.path.join('results_GP_eurosat', f'predictions_{act_name}_{prior_name}_{experiment_timestamp}_{accuracy * 100:.0f}.csv'))