In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from tqdm import tqdm

In [4]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
from utils.model import BayesianCNNSingleFCCustomWGBN

In [None]:
def fixed_load_data(batch_size=54):
    """Build EuroSAT train/test loaders from a pre-computed, pickled index split.

    Args:
        batch_size: Number of images per batch, used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])
    eurosat = datasets.EuroSAT(root='./data', transform=preprocessing, download=False)

    # Seeds torch's global RNG; this side effect remains after the function returns.
    torch.manual_seed(42)

    # The pickle is read as a mapping with 'train' and 'test' index sequences.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
    with open('datasplit/split_indices.pkl', 'rb') as split_file:
        split_indices = pickle.load(split_file)
    train_subset = Subset(eurosat, split_indices['train'])
    test_subset = Subset(eurosat, split_indices['test'])

    # Options shared by both loaders: worker processes and pinned host memory.
    loader_options = dict(batch_size=batch_size, num_workers=4,
                          pin_memory=True, persistent_workers=True)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_options)
    test_loader = DataLoader(test_subset, **loader_options)
    return train_loader, test_loader

In [6]:
def load_data(batch_size=54):
    """Build EuroSAT train/test loaders from a seeded random 80/20 split.

    Args:
        batch_size: Number of images per batch, used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])
    eurosat = datasets.EuroSAT(root='./data', transform=preprocessing, download=False)

    # Seeds torch's global RNG right before the split, so random_split (which is
    # given no generator of its own) draws from this seeded state.
    torch.manual_seed(42)

    n_total = len(eurosat)
    n_train = int(0.8 * n_total)
    train_subset, test_subset = random_split(eurosat, [n_train, n_total - n_train])

    # Options shared by both loaders: worker processes and pinned host memory.
    loader_options = dict(batch_size=batch_size, num_workers=4,
                          pin_memory=True, persistent_workers=True)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_options)
    test_loader = DataLoader(test_subset, **loader_options)
    return train_loader, test_loader

In [7]:
# Number of classes the classifier predicts.
num_classes = 10
# NOTE(review): mu/sigma are presumably the mean and std of the weight prior
# (sigma=10. would match the "std10" tag in the saved file names) — confirm in utils.model.
bayesian_model = BayesianCNNSingleFCCustomWGBN(num_classes=num_classes, 
                                             mu=0,
                                             sigma=10.,
                                             device=device)

In [8]:
bayesian_model

BayesianCNNSingleFCCustomWGBN(
  (conv1): PyroConv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): PyroConv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): PyroLinear(in_features=16384, out_features=10, bias=True)
)

In [9]:
from pyro.infer.autoguide import AutoDiagonalNormal
#from pyro.infer.autoguide import AutoLowRankMultivariateNormal
from pyro.optim import Adam

In [10]:
# Variational guide built automatically from the model by AutoDiagonalNormal.
guide = AutoDiagonalNormal(bayesian_model)
#guide = AutoLowRankMultivariateNormal(bayesian_model, rank=10)

# Optimizer and SVI: Adam with lr=1e-3, optimizing the Trace_ELBO loss for the
# model/guide pair above.
optimizer = Adam({"lr": 1e-3})  # Increased from 1e-4 to 1e-3
svi = pyro.infer.SVI(model=bayesian_model,
                     guide=guide,
                     optim=optimizer,
                     loss=pyro.infer.Trace_ELBO())

In [11]:
from tqdm import tqdm

In [None]:
def train_svi(model, guide, svi, train_loader, num_epochs=10):
    """Run SVI over ``train_loader`` for ``num_epochs`` epochs, printing the mean loss.

    NOTE: a later cell defines another ``train_svi`` with the same signature; once
    that cell has run, it replaces this version.

    Args:
        model: Module switched to train mode and moved to the module-level ``device``.
        guide: Not used by this version; present only in the signature.
        svi: Object whose ``step(images, labels)`` runs one update and returns the
            loss for that batch.
        train_loader: Iterable of ``(images, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. The mean per-batch loss of each epoch is printed.
    """
    # The parameter store is wiped once, before the first optimisation step.
    pyro.clear_param_store()
    model.train()

    # Move the model to the module-level device.
    model.to(device)

    for epoch in range(num_epochs):
        step_losses = []
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            step_losses.append(svi.step(images, labels))

        # Mean of the per-batch losses, accumulated in batch order.
        avg_loss = sum(step_losses, 0.0) / len(step_losses)
        print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

In [None]:
def train_svi(model, guide, svi, train_loader, num_epochs=10):
    """Train with SVI and periodically measure accuracy on the training loader.

    The mean per-batch loss is recorded after every epoch. On the first epoch, on
    every 10th epoch and on the last epoch, accuracy over ``train_loader`` is also
    computed, drawing one sample from ``guide`` per batch and replaying ``model``
    against it.

    Args:
        model: Pyro model; the accuracy pass calls it with ``images`` only and
            takes the argmax of its output over dim 1.
        guide: Guide traced on ``images`` to obtain the values replayed into ``model``.
        svi: Object whose ``step(images, labels)`` runs one update and returns the
            loss for that batch.
        train_loader: Iterable of ``(images, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(epoch_losses, epoch_accuracies, accuracy_epochs)``: one mean loss per
        epoch, the measured accuracies, and the 1-based epochs they belong to.
    """
    # Clear parameter store only ONCE at the beginning
    pyro.clear_param_store()
    model.train()

    # Ensure model is on the correct device
    model.to(device)

    # Lists to store losses and accuracies
    epoch_losses = []
    epoch_accuracies = []
    accuracy_epochs = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            loss = svi.step(images, labels)
            epoch_loss += loss
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        epoch_losses.append(avg_loss)

        # Calculate accuracy every 10 epochs (and on the first and last epoch)
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
            # Remember the guide's mode so the accuracy pass does not leave it
            # stuck in eval mode for the remaining epochs and for the caller.
            guide_was_training = guide.training
            model.eval()
            guide.eval()

            correct_predictions = 0
            total_samples = 0

            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Calculating accuracy for epoch {epoch+1}"):
                    images, labels = images.to(device), labels.to(device)

                    # Sample from the guide to get model parameters
                    guide_trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed_model = pyro.poutine.replay(model, trace=guide_trace)

                    # Get predictions
                    logits = replayed_model(images)
                    predictions = torch.argmax(logits, dim=1)

                    # Count correct predictions
                    correct_predictions += (predictions == labels).sum().item()
                    total_samples += labels.size(0)

            epoch_accuracy = correct_predictions / total_samples
            epoch_accuracies.append(epoch_accuracy)
            accuracy_epochs.append(epoch + 1)

            # Set both modules back to their training-time modes
            model.train()
            guide.train(guide_was_training)

            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}%")
        else:
            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

    return epoch_losses, epoch_accuracies, accuracy_epochs

In [None]:
def train_svi_with_annealing(model, guide, svi, train_loader, num_epochs=10):
    """Train with SVI and periodically measure accuracy on the training loader.

    NOTE(review): KL annealing is NOT applied yet. ``kl_weight`` is computed every
    epoch, but ``svi.step`` is called exactly as in the plain training loop, so the
    weight never reaches the loss. The previous version built a ``weighted_model``
    closure on every batch and never used it; that dead code has been removed.
    Wiring the weight into the objective is still to be done.

    Args:
        model: Pyro model; the accuracy pass calls it with ``images`` only and
            takes the argmax of its output over dim 1.
        guide: Guide traced on ``images`` to obtain the values replayed into ``model``.
        svi: Object whose ``step(images, labels)`` runs one update and returns the
            loss for that batch.
        train_loader: Iterable of ``(images, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(epoch_losses, epoch_accuracies, accuracy_epochs)``: one mean loss per
        epoch, the measured accuracies, and the 1-based epochs they belong to.
    """
    pyro.clear_param_store()
    model.train()
    model.to(device)

    epoch_losses = []
    epoch_accuracies = []
    accuracy_epochs = []

    for epoch in range(num_epochs):
        # Intended KL weight: ramps linearly and reaches 1.0 at 50% of training.
        # NOTE(review): currently unused — see the docstring.
        kl_weight = min(1.0, (epoch + 1) / (num_epochs * 0.5))  # noqa: F841

        epoch_loss = 0.0
        num_batches = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            loss = svi.step(images, labels)
            epoch_loss += loss
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        epoch_losses.append(avg_loss)

        # Calculate accuracy every 10 epochs (and on the first and last epoch)
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
            # Remember the guide's mode so the accuracy pass does not leave it
            # stuck in eval mode for the remaining epochs and for the caller.
            guide_was_training = guide.training
            model.eval()
            guide.eval()

            correct_predictions = 0
            total_samples = 0

            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Calculating accuracy for epoch {epoch+1}"):
                    images, labels = images.to(device), labels.to(device)

                    # Sample from the guide to get model parameters
                    guide_trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed_model = pyro.poutine.replay(model, trace=guide_trace)

                    # Get predictions
                    logits = replayed_model(images)
                    predictions = torch.argmax(logits, dim=1)

                    # Count correct predictions
                    correct_predictions += (predictions == labels).sum().item()
                    total_samples += labels.size(0)

            epoch_accuracy = correct_predictions / total_samples
            epoch_accuracies.append(epoch_accuracy)
            accuracy_epochs.append(epoch + 1)

            # Set both modules back to their training-time modes
            model.train()
            guide.train(guide_was_training)

            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}%")
        else:
            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

    return epoch_losses, epoch_accuracies, accuracy_epochs

In [None]:
"""
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
guide.to(device)

train_loader, test_loader = load_data(batch_size=54)
losses, accuracies, accuracy_epochs = train_svi(bayesian_model, guide, svi, train_loader, num_epochs=100)

# Plot training curves
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(losses) + 1), losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('ELBO Loss')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(accuracy_epochs, accuracies, 'o-')
plt.title('Training Accuracy (Every 10 Epochs)')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.grid(True)

plt.tight_layout()
plt.show()
"""

In [12]:
def train_svi_with_stats(model, guide, svi, train_loader, num_epochs=10):
    """Train with SVI, periodically recording accuracy and guide-parameter statistics.

    The mean per-batch loss is recorded after every epoch. On the first epoch, on
    every 10th epoch and on the last epoch, the function also:

    * computes accuracy over ``train_loader``, drawing one sample from ``guide``
      per batch and replaying ``model`` against it;
    * records the mean and std of every Pyro param-store entry whose name contains
      ``'AutoDiagonalNormal.loc'`` or ``'AutoDiagonalNormal.scale'``.

    Despite their names, ``weight_stats`` holds the statistics of the ``loc``
    parameters and ``bias_stats`` those of the ``scale`` parameters.

    Args:
        model: Pyro model; the accuracy pass calls it with ``images`` only and
            takes the argmax of its output over dim 1.
        guide: Guide traced on ``images`` to obtain the values replayed into ``model``.
        svi: Object whose ``step(images, labels)`` runs one update and returns the
            loss for that batch.
        train_loader: Iterable of ``(images, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats, bias_stats)``.
        Each stats dict has keys ``'epochs'``, ``'means'`` and ``'stds'``; the
        ``'means'``/``'stds'`` entries are lists with one value per matching parameter.
    """
    # Clear parameter store only ONCE at the beginning
    pyro.clear_param_store()
    model.train()

    # Ensure model is on the correct device
    model.to(device)

    # Lists to store losses and accuracies
    epoch_losses = []
    epoch_accuracies = []
    accuracy_epochs = []

    # Per-checkpoint statistics of the guide's loc ("weight") and scale ("bias") parameters
    weight_stats = {'epochs': [], 'means': [], 'stds': []}
    bias_stats = {'epochs': [], 'means': [], 'stds': []}

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            loss = svi.step(images, labels)
            epoch_loss += loss
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        epoch_losses.append(avg_loss)

        # Calculate accuracy every 10 epochs (and on the first and last epoch)
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
            # Remember the guide's mode so the accuracy pass does not leave it
            # stuck in eval mode for the remaining epochs and for the caller.
            guide_was_training = guide.training
            model.eval()
            guide.eval()

            correct_predictions = 0
            total_samples = 0

            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Calculating accuracy for epoch {epoch+1}"):
                    images, labels = images.to(device), labels.to(device)

                    # Sample from the guide to get model parameters
                    guide_trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed_model = pyro.poutine.replay(model, trace=guide_trace)

                    # Get predictions
                    logits = replayed_model(images)
                    predictions = torch.argmax(logits, dim=1)

                    # Count correct predictions
                    correct_predictions += (predictions == labels).sum().item()
                    total_samples += labels.size(0)

            epoch_accuracy = correct_predictions / total_samples
            epoch_accuracies.append(epoch_accuracy)
            accuracy_epochs.append(epoch + 1)

            # Summaries of the guide parameters currently in the param store
            weight_means = []   # mean of each loc parameter
            weight_stds = []    # std of each loc parameter
            bias_means = []     # mean of each scale parameter
            bias_stds = []      # std of each scale parameter

            for name, param in pyro.get_param_store().items():
                if 'AutoDiagonalNormal.loc' in name:
                    weight_means.append(param.mean().item())
                    weight_stds.append(param.std().item())
                elif 'AutoDiagonalNormal.scale' in name:
                    bias_means.append(param.mean().item())
                    bias_stds.append(param.std().item())

            # Store statistics for this epoch
            weight_stats['epochs'].append(epoch + 1)
            weight_stats['means'].append(weight_means)
            weight_stats['stds'].append(weight_stds)

            bias_stats['epochs'].append(epoch + 1)
            bias_stats['means'].append(bias_means)
            bias_stats['stds'].append(bias_stds)

            # Set both modules back to their training-time modes
            model.train()
            guide.train(guide_was_training)

            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}%")
        else:
            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

    return epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats, bias_stats

def _boxplot_stats_by_epoch(stats, box_color):
    """Draw one box per recorded epoch from ``stats`` on the current axes."""
    per_epoch_values = []
    tick_names = []
    for i, epoch in enumerate(stats['epochs']):
        # Combine means and stds for this epoch into a single sample
        per_epoch_values.append(stats['means'][i] + stats['stds'][i])
        tick_names.append(f'Epoch {epoch}')

    if per_epoch_values:
        # Tick labels are set through plt.xticks instead of boxplot's
        # `labels` / `tick_labels` keyword, whose name depends on the Matplotlib
        # version; boxplot places the boxes at positions 1..N by default.
        boxes = plt.boxplot(per_epoch_values, patch_artist=True)
        for patch in boxes['boxes']:
            patch.set_facecolor(box_color)
        plt.xticks(list(range(1, len(tick_names) + 1)), tick_names)


def plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats):
    """Plot loss, accuracy and the recorded loc/scale statistics in a 2x2 figure.

    Args:
        losses: One loss value per epoch; plotted against epochs 1..len(losses).
        accuracies: Accuracy values measured at ``accuracy_epochs``.
        accuracy_epochs: Epoch numbers matching ``accuracies``.
        weight_stats: Dict with keys ``'epochs'``, ``'means'``, ``'stds'``; shown
            in the panel titled 'LOC Statistics Distribution'.
        bias_stats: Dict with the same keys; shown in the panel titled
            'SCALE Statistics Distribution'.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(16, 12))

    # Plot 1: Training Loss
    plt.subplot(2, 2, 1)
    plt.plot(range(1, len(losses) + 1), losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('ELBO Loss')
    plt.grid(True)

    # Plot 2: Training Accuracy
    plt.subplot(2, 2, 2)
    plt.plot(accuracy_epochs, accuracies, 'o-')
    plt.title('Training Accuracy (Every 10 Epochs)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)

    # Plot 3: loc statistics boxplot
    plt.subplot(2, 2, 3)
    _boxplot_stats_by_epoch(weight_stats, 'lightblue')
    plt.title('LOC Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('LOC Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    # Plot 4: scale statistics boxplot
    plt.subplot(2, 2, 4)
    _boxplot_stats_by_epoch(bias_stats, 'lightcoral')
    plt.title('SCALE Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('SCALE Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [13]:
# NOTE: train_svi_with_stats clears the param store again on entry, so this call
# is redundant but harmless.
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
guide.to(device)

# Seeded random 80/20 split (see load_data); test_loader is used by the
# evaluation cells further down.
train_loader, test_loader = load_data(batch_size=54)

# Train with statistics recording
losses, accuracies, accuracy_epochs, weight_stats, bias_stats = train_svi_with_stats(
    bayesian_model, guide, svi, train_loader, num_epochs=100
)

Epoch 1/100: 100%|██████████| 400/400 [01:24<00:00,  4.73it/s]
Calculating accuracy for epoch 1: 100%|██████████| 400/400 [00:06<00:00, 66.57it/s]


Epoch 1 - ELBO Loss: 868147.0214, Train Accuracy: 43.39%


Epoch 2/100: 100%|██████████| 400/400 [00:19<00:00, 20.36it/s]


Epoch 2 - ELBO Loss: 788472.6365


Epoch 3/100: 100%|██████████| 400/400 [00:19<00:00, 20.38it/s]


Epoch 3 - ELBO Loss: 713378.3067


Epoch 4/100: 100%|██████████| 400/400 [00:19<00:00, 20.37it/s]


Epoch 4 - ELBO Loss: 643343.1743


Epoch 5/100: 100%|██████████| 400/400 [00:24<00:00, 16.38it/s]


Epoch 5 - ELBO Loss: 579188.1954


Epoch 6/100: 100%|██████████| 400/400 [00:16<00:00, 24.54it/s]


Epoch 6 - ELBO Loss: 521688.0792


Epoch 7/100: 100%|██████████| 400/400 [00:18<00:00, 21.27it/s]


Epoch 7 - ELBO Loss: 470959.7399


Epoch 8/100: 100%|██████████| 400/400 [00:15<00:00, 26.32it/s]


Epoch 8 - ELBO Loss: 427056.4175


Epoch 9/100: 100%|██████████| 400/400 [00:15<00:00, 26.04it/s]


Epoch 9 - ELBO Loss: 389429.0197


Epoch 10/100: 100%|██████████| 400/400 [00:15<00:00, 26.38it/s]
Calculating accuracy for epoch 10: 100%|██████████| 400/400 [00:04<00:00, 87.02it/s]


Epoch 10 - ELBO Loss: 357414.2469, Train Accuracy: 52.19%


Epoch 11/100: 100%|██████████| 400/400 [00:14<00:00, 28.06it/s]


Epoch 11 - ELBO Loss: 330198.8583


Epoch 12/100: 100%|██████████| 400/400 [00:14<00:00, 28.08it/s]


Epoch 12 - ELBO Loss: 306974.1289


Epoch 13/100: 100%|██████████| 400/400 [00:14<00:00, 27.88it/s]


Epoch 13 - ELBO Loss: 286958.5872


Epoch 14/100: 100%|██████████| 400/400 [00:14<00:00, 27.94it/s]


Epoch 14 - ELBO Loss: 269477.4280


Epoch 15/100: 100%|██████████| 400/400 [00:14<00:00, 27.98it/s]


Epoch 15 - ELBO Loss: 254036.2723


Epoch 16/100: 100%|██████████| 400/400 [00:14<00:00, 27.75it/s]


Epoch 16 - ELBO Loss: 240245.7856


Epoch 17/100: 100%|██████████| 400/400 [00:14<00:00, 27.78it/s]


Epoch 17 - ELBO Loss: 227770.3209


Epoch 18/100: 100%|██████████| 400/400 [00:14<00:00, 27.70it/s]


Epoch 18 - ELBO Loss: 216450.4898


Epoch 19/100: 100%|██████████| 400/400 [00:14<00:00, 27.65it/s]


Epoch 19 - ELBO Loss: 205947.3571


Epoch 20/100: 100%|██████████| 400/400 [00:14<00:00, 27.50it/s]
Calculating accuracy for epoch 20: 100%|██████████| 400/400 [00:04<00:00, 89.68it/s]


Epoch 20 - ELBO Loss: 196335.6437, Train Accuracy: 44.84%


Epoch 21/100: 100%|██████████| 400/400 [00:14<00:00, 27.49it/s]


Epoch 21 - ELBO Loss: 187375.3931


Epoch 22/100: 100%|██████████| 400/400 [00:14<00:00, 27.55it/s]


Epoch 22 - ELBO Loss: 179043.7936


Epoch 23/100: 100%|██████████| 400/400 [00:14<00:00, 27.61it/s]


Epoch 23 - ELBO Loss: 171363.4892


Epoch 24/100: 100%|██████████| 400/400 [00:14<00:00, 27.07it/s]


Epoch 24 - ELBO Loss: 164058.7753


Epoch 25/100: 100%|██████████| 400/400 [00:14<00:00, 27.13it/s]


Epoch 25 - ELBO Loss: 157260.7679


Epoch 26/100: 100%|██████████| 400/400 [00:14<00:00, 26.80it/s]


Epoch 26 - ELBO Loss: 150829.8054


Epoch 27/100: 100%|██████████| 400/400 [00:15<00:00, 25.96it/s]


Epoch 27 - ELBO Loss: 144695.0840


Epoch 28/100: 100%|██████████| 400/400 [00:15<00:00, 26.46it/s]


Epoch 28 - ELBO Loss: 139063.4146


Epoch 29/100: 100%|██████████| 400/400 [00:15<00:00, 25.75it/s]


Epoch 29 - ELBO Loss: 133691.4889


Epoch 30/100: 100%|██████████| 400/400 [00:15<00:00, 25.69it/s]
Calculating accuracy for epoch 30: 100%|██████████| 400/400 [00:04<00:00, 81.87it/s]


Epoch 30 - ELBO Loss: 128603.3019, Train Accuracy: 39.55%


Epoch 31/100: 100%|██████████| 400/400 [00:15<00:00, 26.02it/s]


Epoch 31 - ELBO Loss: 123863.4473


Epoch 32/100: 100%|██████████| 400/400 [00:15<00:00, 26.32it/s]


Epoch 32 - ELBO Loss: 119309.4188


Epoch 33/100: 100%|██████████| 400/400 [00:15<00:00, 26.27it/s]


Epoch 33 - ELBO Loss: 114938.7490


Epoch 34/100: 100%|██████████| 400/400 [00:15<00:00, 26.28it/s]


Epoch 34 - ELBO Loss: 110916.0608


Epoch 35/100: 100%|██████████| 400/400 [00:15<00:00, 25.62it/s]


Epoch 35 - ELBO Loss: 107159.8606


Epoch 36/100: 100%|██████████| 400/400 [00:15<00:00, 25.82it/s]


Epoch 36 - ELBO Loss: 103561.3871


Epoch 37/100: 100%|██████████| 400/400 [00:15<00:00, 26.32it/s]


Epoch 37 - ELBO Loss: 100107.3193


Epoch 38/100: 100%|██████████| 400/400 [00:15<00:00, 26.36it/s]


Epoch 38 - ELBO Loss: 96900.5514


Epoch 39/100: 100%|██████████| 400/400 [00:15<00:00, 26.03it/s]


Epoch 39 - ELBO Loss: 93821.8060


Epoch 40/100: 100%|██████████| 400/400 [00:15<00:00, 25.96it/s]
Calculating accuracy for epoch 40: 100%|██████████| 400/400 [00:04<00:00, 84.36it/s]


Epoch 40 - ELBO Loss: 90875.4116, Train Accuracy: 40.18%


Epoch 41/100: 100%|██████████| 400/400 [00:15<00:00, 26.03it/s]


Epoch 41 - ELBO Loss: 88109.0007


Epoch 42/100: 100%|██████████| 400/400 [00:15<00:00, 26.03it/s]


Epoch 42 - ELBO Loss: 85385.7669


Epoch 43/100: 100%|██████████| 400/400 [00:15<00:00, 26.06it/s]


Epoch 43 - ELBO Loss: 83043.1793


Epoch 44/100: 100%|██████████| 400/400 [00:16<00:00, 24.66it/s]


Epoch 44 - ELBO Loss: 80681.5722


Epoch 45/100: 100%|██████████| 400/400 [00:15<00:00, 25.57it/s]


Epoch 45 - ELBO Loss: 78399.7622


Epoch 46/100: 100%|██████████| 400/400 [00:15<00:00, 25.87it/s]


Epoch 46 - ELBO Loss: 76174.0023


Epoch 47/100: 100%|██████████| 400/400 [00:15<00:00, 25.54it/s]


Epoch 47 - ELBO Loss: 74285.0968


Epoch 48/100: 100%|██████████| 400/400 [00:18<00:00, 22.18it/s]


Epoch 48 - ELBO Loss: 72245.7851


Epoch 49/100: 100%|██████████| 400/400 [00:20<00:00, 19.18it/s]


Epoch 49 - ELBO Loss: 70297.2317


Epoch 50/100: 100%|██████████| 400/400 [00:15<00:00, 25.81it/s]
Calculating accuracy for epoch 50: 100%|██████████| 400/400 [00:04<00:00, 86.46it/s]


Epoch 50 - ELBO Loss: 68603.9202, Train Accuracy: 39.44%


Epoch 51/100: 100%|██████████| 400/400 [00:15<00:00, 25.99it/s]


Epoch 51 - ELBO Loss: 67001.9233


Epoch 52/100: 100%|██████████| 400/400 [00:15<00:00, 26.55it/s]


Epoch 52 - ELBO Loss: 65248.5050


Epoch 53/100: 100%|██████████| 400/400 [00:15<00:00, 25.41it/s]


Epoch 53 - ELBO Loss: 63662.8159


Epoch 54/100: 100%|██████████| 400/400 [00:15<00:00, 25.61it/s]


Epoch 54 - ELBO Loss: 62298.0877


Epoch 55/100: 100%|██████████| 400/400 [00:15<00:00, 25.65it/s]


Epoch 55 - ELBO Loss: 60811.8690


Epoch 56/100: 100%|██████████| 400/400 [00:15<00:00, 26.46it/s]


Epoch 56 - ELBO Loss: 59422.1528


Epoch 57/100: 100%|██████████| 400/400 [00:15<00:00, 26.21it/s]


Epoch 57 - ELBO Loss: 58161.6532


Epoch 58/100: 100%|██████████| 400/400 [00:15<00:00, 26.14it/s]


Epoch 58 - ELBO Loss: 56773.0168


Epoch 59/100: 100%|██████████| 400/400 [00:15<00:00, 26.26it/s]


Epoch 59 - ELBO Loss: 55604.9566


Epoch 60/100: 100%|██████████| 400/400 [00:15<00:00, 26.19it/s]
Calculating accuracy for epoch 60: 100%|██████████| 400/400 [00:04<00:00, 86.27it/s]


Epoch 60 - ELBO Loss: 54480.8025, Train Accuracy: 39.73%


Epoch 61/100: 100%|██████████| 400/400 [00:15<00:00, 26.13it/s]


Epoch 61 - ELBO Loss: 53428.3244


Epoch 62/100: 100%|██████████| 400/400 [00:15<00:00, 26.28it/s]


Epoch 62 - ELBO Loss: 52320.5982


Epoch 63/100: 100%|██████████| 400/400 [00:15<00:00, 26.30it/s]


Epoch 63 - ELBO Loss: 51219.3077


Epoch 64/100: 100%|██████████| 400/400 [00:15<00:00, 26.30it/s]


Epoch 64 - ELBO Loss: 50175.2287


Epoch 65/100: 100%|██████████| 400/400 [00:15<00:00, 26.20it/s]


Epoch 65 - ELBO Loss: 49174.0993


Epoch 66/100: 100%|██████████| 400/400 [00:15<00:00, 26.04it/s]


Epoch 66 - ELBO Loss: 48221.9737


Epoch 67/100: 100%|██████████| 400/400 [00:15<00:00, 26.16it/s]


Epoch 67 - ELBO Loss: 47332.4573


Epoch 68/100: 100%|██████████| 400/400 [00:15<00:00, 26.21it/s]


Epoch 68 - ELBO Loss: 46439.2832


Epoch 69/100: 100%|██████████| 400/400 [00:15<00:00, 26.23it/s]


Epoch 69 - ELBO Loss: 45540.5897


Epoch 70/100: 100%|██████████| 400/400 [00:15<00:00, 26.26it/s]
Calculating accuracy for epoch 70: 100%|██████████| 400/400 [00:04<00:00, 85.42it/s]


Epoch 70 - ELBO Loss: 44800.3528, Train Accuracy: 38.40%


Epoch 71/100: 100%|██████████| 400/400 [00:15<00:00, 26.20it/s]


Epoch 71 - ELBO Loss: 43975.7027


Epoch 72/100: 100%|██████████| 400/400 [00:15<00:00, 26.07it/s]


Epoch 72 - ELBO Loss: 43226.6712


Epoch 73/100: 100%|██████████| 400/400 [00:15<00:00, 25.88it/s]


Epoch 73 - ELBO Loss: 42498.1186


Epoch 74/100: 100%|██████████| 400/400 [00:15<00:00, 25.84it/s]


Epoch 74 - ELBO Loss: 41709.6990


Epoch 75/100: 100%|██████████| 400/400 [00:15<00:00, 26.08it/s]


Epoch 75 - ELBO Loss: 40882.0571


Epoch 76/100: 100%|██████████| 400/400 [00:15<00:00, 25.40it/s]


Epoch 76 - ELBO Loss: 40189.0458


Epoch 77/100: 100%|██████████| 400/400 [00:16<00:00, 24.98it/s]


Epoch 77 - ELBO Loss: 39566.4321


Epoch 78/100: 100%|██████████| 400/400 [00:15<00:00, 25.62it/s]


Epoch 78 - ELBO Loss: 38933.9014


Epoch 79/100: 100%|██████████| 400/400 [00:15<00:00, 25.95it/s]


Epoch 79 - ELBO Loss: 38173.5468


Epoch 80/100: 100%|██████████| 400/400 [00:15<00:00, 26.13it/s]
Calculating accuracy for epoch 80: 100%|██████████| 400/400 [00:04<00:00, 85.17it/s]


Epoch 80 - ELBO Loss: 37698.0800, Train Accuracy: 38.30%


Epoch 81/100: 100%|██████████| 400/400 [00:15<00:00, 26.10it/s]


Epoch 81 - ELBO Loss: 37000.1174


Epoch 82/100: 100%|██████████| 400/400 [00:15<00:00, 26.08it/s]


Epoch 82 - ELBO Loss: 36407.0813


Epoch 83/100: 100%|██████████| 400/400 [00:38<00:00, 10.28it/s]


Epoch 83 - ELBO Loss: 35796.3997


Epoch 84/100: 100%|██████████| 400/400 [00:42<00:00,  9.38it/s]


Epoch 84 - ELBO Loss: 35298.6446


Epoch 85/100: 100%|██████████| 400/400 [00:36<00:00, 10.87it/s]


Epoch 85 - ELBO Loss: 34876.1795


Epoch 86/100: 100%|██████████| 400/400 [00:34<00:00, 11.60it/s]


Epoch 86 - ELBO Loss: 34299.9979


Epoch 87/100: 100%|██████████| 400/400 [00:21<00:00, 18.75it/s]


Epoch 87 - ELBO Loss: 33779.3854


Epoch 88/100: 100%|██████████| 400/400 [00:18<00:00, 21.22it/s]


Epoch 88 - ELBO Loss: 33226.1927


Epoch 89/100: 100%|██████████| 400/400 [00:18<00:00, 21.21it/s]


Epoch 89 - ELBO Loss: 32786.5404


Epoch 90/100: 100%|██████████| 400/400 [00:18<00:00, 21.20it/s]
Calculating accuracy for epoch 90: 100%|██████████| 400/400 [00:05<00:00, 70.25it/s]


Epoch 90 - ELBO Loss: 32319.0790, Train Accuracy: 36.69%


Epoch 91/100: 100%|██████████| 400/400 [00:34<00:00, 11.67it/s]


Epoch 91 - ELBO Loss: 31760.2471


Epoch 92/100: 100%|██████████| 400/400 [00:34<00:00, 11.62it/s]


Epoch 92 - ELBO Loss: 31365.4672


Epoch 93/100: 100%|██████████| 400/400 [00:33<00:00, 11.81it/s]


Epoch 93 - ELBO Loss: 31071.4980


Epoch 94/100: 100%|██████████| 400/400 [00:33<00:00, 11.92it/s]


Epoch 94 - ELBO Loss: 30556.8151


Epoch 95/100: 100%|██████████| 400/400 [00:33<00:00, 11.79it/s]


Epoch 95 - ELBO Loss: 30033.3592


Epoch 96/100: 100%|██████████| 400/400 [00:33<00:00, 11.95it/s]


Epoch 96 - ELBO Loss: 29791.8375


Epoch 97/100: 100%|██████████| 400/400 [00:33<00:00, 11.96it/s]


Epoch 97 - ELBO Loss: 29346.1499


Epoch 98/100: 100%|██████████| 400/400 [00:33<00:00, 11.92it/s]


Epoch 98 - ELBO Loss: 28851.7287


Epoch 99/100: 100%|██████████| 400/400 [00:33<00:00, 11.88it/s]


Epoch 99 - ELBO Loss: 28536.5759


Epoch 100/100: 100%|██████████| 400/400 [00:34<00:00, 11.68it/s]
Calculating accuracy for epoch 100: 100%|██████████| 400/400 [00:10<00:00, 39.72it/s]

Epoch 100 - ELBO Loss: 28198.1724, Train Accuracy: 37.06%





In [None]:
# Plot loss, training accuracy, and the guide loc/scale statistics recorded by
# train_svi_with_stats (returned under the names weight_stats / bias_stats).
plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats)

In [None]:
# NOTE(review): `stopdeh` is not defined anywhere in the visible notebook, so this
# cell raises NameError — presumably a deliberate barrier to stop "Run All" here.
# Confirm, and consider deleting it or replacing it with an explicit raise.
stopdeh

In [None]:
# Collect the mean/std of the guide parameters currently in the Pyro param store.
# "weight_*" lists hold statistics of parameters whose name contains
# 'AutoDiagonalNormal.loc'; "bias_*" lists hold those of 'AutoDiagonalNormal.scale'.
weight_means = []
weight_stds = []
bias_means = []
bias_stds = []

for name, param in pyro.get_param_store().items():
    # Every parameter is printed, including ones that match neither pattern.
    print(f"Parameter: {name}, Mean: {param.mean().item()}, Std: {param.std().item()}")
    if 'AutoDiagonalNormal.loc' in name:
        weight_means.append(param.mean().item())
        weight_stds.append(param.std().item())
    elif 'AutoDiagonalNormal.scale' in name:
        bias_means.append(param.mean().item())
        bias_stds.append(param.std().item())

In [None]:
accuracies

In [None]:
weight_stats

In [None]:
bias_stats

In [None]:
# Box plots of the statistics collected from the param store: the "weight" lists
# hold the guide's loc statistics and the "bias" lists its scale statistics.
# Tick labels are set with plt.xticks rather than boxplot's `labels` keyword
# (renamed to `tick_labels` in newer Matplotlib), so the cell works on either
# side of the rename; boxplot places the boxes at positions 1..N by default.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.boxplot([weight_means, weight_stds])
plt.xticks([1, 2], ['Weight Means', 'Weight Stds'])
plt.title('Weight Statistics')
plt.xlabel('Statistic')
plt.ylabel('Value')
plt.grid(True)
plt.subplot(1, 2, 2)
plt.boxplot([bias_means, bias_stds])
plt.xticks([1, 2], ['Bias Means', 'Bias Stds'])
plt.title('Bias Statistics')
plt.xlabel('Statistic')
plt.ylabel('Value')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
weight_means

In [None]:
#pyro.clear_param_store()

# Ensure model and guide are on the correct device
#bayesian_model.to(device)
#guide.to(device)

#train_loader, test_loader = load_data(batch_size=54)
#train_svi(bayesian_model, guide, svi, train_loader, num_epochs=0)

In [14]:
# save the model
model_path = 'results_eurosat/bayesian_cnn_model_std10_100_epoch_WGBN.pth'
# Create the output directory first so the saves below do not fail when it is
# missing (e.g. on a fresh checkout); all three files go into the same directory.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(bayesian_model.state_dict(), model_path)

# save the guide
guide_path = 'results_eurosat/bayesian_cnn_guide_std10_100_epoch_guide_WGBN.pth'
torch.save(guide.state_dict(), guide_path)

# save the pyro parameter store
pyro_param_store_path = 'results_eurosat/pyro_param_store_std10_100_epoch_WGBN.pkl'
pyro.get_param_store().save(pyro_param_store_path)

In [None]:
import numpy as np

In [15]:
# print confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, loader_of_interest, num_samples=10):
    """Predict a label for every example by averaging logits over guide samples.

    For each batch, ``num_samples`` traces are drawn from the module-level
    ``guide``; ``model`` is replayed against each trace, the resulting logits are
    averaged, and the argmax over dim 1 is taken as the prediction. Tensors are
    moved to the module-level ``device``.

    Args:
        model: Model to evaluate; must expose ``fc1.out_features`` (used as the
            number of logits per example).
        loader_of_interest: Iterable of ``(images, labels)`` batches.
        num_samples: Number of guide samples averaged per batch.

    Returns:
        ``(all_labels, all_predictions)``: two lists with one entry per example,
        in loader order.
    """
    model.eval()
    guide.eval()

    true_labels, predicted_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(loader_of_interest, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            # Buffer holding the logits of every posterior draw for this batch.
            n_examples = images.size(0)
            sampled_logits = torch.zeros(num_samples, n_examples, model.fc1.out_features).to(device)

            for draw in range(num_samples):
                posterior_draw = pyro.poutine.trace(guide).get_trace(images)
                sampled_logits[draw] = pyro.poutine.replay(model, trace=posterior_draw)(images)

            predictions = torch.argmax(sampled_logits.mean(dim=0), dim=1)

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predictions.cpu().numpy())

    return true_labels, predicted_labels

In [16]:
# Predict on the training loader, averaging logits over 10 guide samples per batch.
train_labels, train_predictions = predict_data(bayesian_model, train_loader, num_samples=10)

Evaluating: 100%|██████████| 400/400 [01:29<00:00,  4.46it/s]


In [17]:
# Confusion matrix of the training-loader labels against the predictions above.
train_cm = confusion_matrix(train_labels, train_predictions)

In [18]:
# Overall accuracy: diagonal sum of the confusion matrix (correct predictions)
# divided by the sum of all its entries.
train_accuracy = np.trace(train_cm) / np.sum(train_cm)
print(f"Train accuracy from confusion matrix: {train_accuracy * 100:.6f}%")

Train accuracy from confusion matrix: 59.620370%


In [None]:
# Same prediction pass on the test loader returned by load_data.
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Confusion matrix of the test-loader labels against the predictions above.
cm = confusion_matrix(all_labels, all_predictions)

In [None]:
# Accuracy = diagonal of the confusion matrix (correct predictions) / all entries
accuracy = cm.diagonal().sum() / cm.sum()
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

In [None]:
# Print one summary line per entry of the Pyro parameter store:
# "<name>: <shape> - <mean> ± <std>" with mean/std taken over all elements of the tensor.
for name, value in pyro.get_param_store().items():
    print(f"{name}: {value.shape} - {value.mean().item():.4f} ± {value.std().item():.4f}")

Accuracy was 60.092593% for the 10-epoch run.

In [None]:
# plot the confusion matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes):
    """Draw a confusion matrix as an annotated heat map and show it.

    Args:
        cm: 2-D array of counts, indexed as ``cm[true, predicted]`` according
            to the axis labels set below.
        classes: sequence of class names used as tick labels on both axes.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    # One tick per class, same labels on both axes
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    # Cells darker than half the maximum get white text so the count stays readable
    half_max = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row, col in np.ndindex(n_rows, n_cols):
        count = cm[row, col]
        if count > half_max:
            text_color = "white"
        else:
            text_color = "black"
        plt.text(col, row, count,
                 horizontalalignment="center",
                 color=text_color)

    # Dashed red line along the main diagonal (the correctly classified cells)
    plt.plot([0, n_cols - 1], [0, n_rows - 1], color='red', linestyle='--', linewidth=2)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

# Plot the confusion matrix of the test predictions.
# NOTE(review): the names are hard-coded and must be listed in the dataset's
# class-index order for the tick labels to be correct -- confirm against the dataset.
class_names = ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial',
               'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
plot_confusion_matrix(cm, class_names)

In [None]:
# save the model
#model_path = 'results_eurosat/bayesian_cnn_model_std10_100_epoch.pth'
#torch.save(bayesian_model.state_dict(), model_path)

# save the guide
#guide_path = 'results_eurosat/bayesian_cnn_guide_std10_100_epoch_guide.pth'
#torch.save(guide.state_dict(), guide_path)

# save the pyro parameter store
#pyro_param_store_path = 'results_eurosat/pyro_param_store_std10_100_epoch.pkl'
#pyro.get_param_store().save(pyro_param_store_path)

In [None]:
# NOTE(review): this is an undefined name, so executing the cell raises NameError.
# Presumably a deliberate "stop here" barrier for Run All -- confirm, or remove the cell.
kataguediemdeh

In [None]:
def train_svi_early_save(model, guide, svi, train_loader, num_epochs=10, patience=3, min_delta=0.001):
    """Train with SVI, early-stop on training accuracy and restore the best snapshot.

    Every epoch runs ``svi.step`` over all batches of ``train_loader``.  On the
    first epoch, every 10th epoch and the last epoch, accuracy is measured on
    ``train_loader`` using a single guide sample per batch.  When an evaluation
    beats the best accuracy so far by more than ``min_delta``, the model state,
    the guide state and the Pyro parameter store are snapshotted.  Training
    stops once ``patience`` consecutive evaluations bring no improvement.  The
    best snapshot, if any, is restored before returning.

    Args:
        model: Pyro model (a module supporting ``train``/``eval``/``state_dict``).
        guide: Pyro guide (a module supporting ``eval``/``state_dict``).
        svi: object whose ``step(images, labels)`` performs one optimisation
            step and returns the loss for that batch.
        train_loader: iterable yielding ``(images, labels)`` batches.
        num_epochs: maximum number of epochs.
        patience: number of non-improving evaluations tolerated before stopping.
        min_delta: minimum accuracy gain (as a fraction) that counts as an improvement.

    Returns:
        tuple: ``(epoch_losses, epoch_accuracies, accuracy_epochs, best_epoch,
        best_accuracy)`` where ``accuracy_epochs`` lists the 1-based epochs at
        which ``epoch_accuracies`` were measured.

    Note:
        Clears the global Pyro parameter store on entry and relies on the
        global ``device``.
    """
    import copy  # local import keeps this cell runnable on its own

    # Clear parameter store only ONCE at the beginning
    pyro.clear_param_store()
    model.train()

    # Ensure model is on the correct device
    model.to(device)

    # Lists to store losses and accuracies
    epoch_losses = []
    epoch_accuracies = []
    accuracy_epochs = []

    # Early stopping variables
    best_accuracy = 0.0
    best_epoch = 0
    patience_counter = 0
    best_model_state = None
    best_guide_state = None
    best_pyro_params = None
    stopped_early = False

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            loss = svi.step(images, labels)
            epoch_loss += loss
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        epoch_losses.append(avg_loss)

        # Calculate accuracy every 10 epochs (and on the first and last epoch)
        if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
            model.eval()
            guide.eval()

            correct_predictions = 0
            total_samples = 0

            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Calculating accuracy for epoch {epoch+1}"):
                    images, labels = images.to(device), labels.to(device)

                    # Sample from the guide to get model parameters
                    guide_trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed_model = pyro.poutine.replay(model, trace=guide_trace)

                    # Get predictions
                    logits = replayed_model(images)
                    predictions = torch.argmax(logits, dim=1)

                    # Count correct predictions
                    correct_predictions += (predictions == labels).sum().item()
                    total_samples += labels.size(0)

            epoch_accuracy = correct_predictions / total_samples
            epoch_accuracies.append(epoch_accuracy)
            accuracy_epochs.append(epoch + 1)

            # Check for improvement
            if epoch_accuracy > best_accuracy + min_delta:
                best_accuracy = epoch_accuracy
                best_epoch = epoch + 1
                patience_counter = 0

                # Snapshot the best states.  state_dict()/get_state() return
                # references to the live tensors and dict.copy() is shallow, so a
                # deep copy is required; otherwise later optimisation steps would
                # silently overwrite the "best" snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                best_guide_state = copy.deepcopy(guide.state_dict())
                best_pyro_params = copy.deepcopy(pyro.get_param_store().get_state())

                print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}% *** NEW BEST ***")
            else:
                patience_counter += 1
                print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}% (Best: {best_accuracy*100:.2f}% at epoch {best_epoch})")

                # Early stopping check
                if patience_counter >= patience:
                    print(f"\nEarly stopping triggered! No improvement for {patience} evaluations.")
                    print(f"Best accuracy: {best_accuracy*100:.2f}% at epoch {best_epoch}")
                    stopped_early = True
                    break

            # Set both modules back to training mode for the next epochs
            model.train()
            guide.train()
        else:
            print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

    # Restore the best snapshot (single place for both exits).  Skipped when no
    # evaluation ever improved, in which case there is nothing to restore.
    if best_model_state is not None:
        if not stopped_early:
            print(f"\nTraining completed. Restoring best model from epoch {best_epoch} (accuracy: {best_accuracy*100:.2f}%)")
        model.load_state_dict(best_model_state)
        guide.load_state_dict(best_guide_state)
        pyro.get_param_store().set_state(best_pyro_params)

    return epoch_losses, epoch_accuracies, accuracy_epochs, best_epoch, best_accuracy

In [None]:
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
guide.to(device)

train_loader, test_loader = load_data(batch_size=54)

# Train with early stopping.  `train_svi_early_save` is the variant that accepts
# `patience` / `min_delta` and returns the five values unpacked here.
losses, accuracies, accuracy_epochs, best_epoch, best_accuracy = train_svi_early_save(
    bayesian_model, guide, svi, train_loader,
    num_epochs=100,
    patience=3,  # Stop after 3 accuracy evaluations without improvement (evaluations run every 10 epochs)
    min_delta=0.001  # Minimum improvement threshold (0.1%)
)

print(f"\nFinal Results:")
print(f"Best training accuracy: {best_accuracy*100:.2f}% at epoch {best_epoch}")

# Plot training curves
plt.figure(figsize=(12, 4))

# Left panel: average ELBO loss per epoch, with the best epoch marked
plt.subplot(1, 2, 1)
plt.plot(range(1, len(losses) + 1), losses)
plt.axvline(x=best_epoch, color='red', linestyle='--', label=f'Best Model (Epoch {best_epoch})')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('ELBO Loss')
plt.legend()
plt.grid(True)

# Right panel: training accuracy at the evaluated epochs, with the best epoch/accuracy marked
plt.subplot(1, 2, 2)
plt.plot(accuracy_epochs, accuracies, 'o-')
plt.axvline(x=best_epoch, color='red', linestyle='--', label=f'Best Model (Epoch {best_epoch})')
plt.axhline(y=best_accuracy, color='red', linestyle=':', alpha=0.7)
plt.title('Training Accuracy (Every 10 Epochs)')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## TensorBoard

In [None]:
from torch.utils.tensorboard import SummaryWriter
import os

In [None]:
def train_svi_with_tensorboard(model, guide, svi, train_loader, num_epochs=10, log_dir='runs/bayesian_cnn'):
    """Train with SVI and log loss and training accuracy to TensorBoard.

    Every epoch runs ``svi.step`` over all batches of ``train_loader`` and logs
    the average loss under ``Loss/ELBO``.  On the first epoch, every 10th epoch
    and the last epoch, accuracy is measured on ``train_loader`` with a single
    guide sample per batch and logged under ``Accuracy/Train``.

    Args:
        model: Pyro model (a module supporting ``train``/``eval``).
        guide: Pyro guide (a module supporting ``train``/``eval``).
        svi: object whose ``step(images, labels)`` performs one optimisation
            step and returns the loss for that batch.
        train_loader: iterable yielding ``(images, labels)`` batches.
        num_epochs: number of epochs to train.
        log_dir: directory handed to ``SummaryWriter``.

    Returns:
        tuple: ``(epoch_losses, epoch_accuracies, accuracy_epochs)`` where
        ``accuracy_epochs`` lists the 1-based epochs at which
        ``epoch_accuracies`` were measured.

    Note:
        Clears the global Pyro parameter store on entry and relies on the
        global ``device``.
    """
    # Clear parameter store only ONCE at the beginning
    pyro.clear_param_store()
    model.train()

    # Ensure model is on the correct device
    model.to(device)

    # Lists to store losses and accuracies
    epoch_losses = []
    epoch_accuracies = []
    accuracy_epochs = []

    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir)
    try:
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            num_batches = 0
            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                images, labels = images.to(device), labels.to(device)

                loss = svi.step(images, labels)
                epoch_loss += loss
                num_batches += 1

            avg_loss = epoch_loss / num_batches
            epoch_losses.append(avg_loss)

            # Log loss to TensorBoard every epoch
            writer.add_scalar('Loss/ELBO', avg_loss, epoch + 1)

            # Calculate accuracy every 10 epochs (and on the first and last epoch)
            if (epoch + 1) % 10 == 0 or epoch == 0 or epoch == num_epochs - 1:
                model.eval()
                guide.eval()

                correct_predictions = 0
                total_samples = 0

                with torch.no_grad():
                    for images, labels in tqdm(train_loader, desc=f"Calculating accuracy for epoch {epoch+1}"):
                        images, labels = images.to(device), labels.to(device)

                        # Sample from the guide to get model parameters
                        guide_trace = pyro.poutine.trace(guide).get_trace(images)
                        replayed_model = pyro.poutine.replay(model, trace=guide_trace)

                        # Get predictions
                        logits = replayed_model(images)
                        predictions = torch.argmax(logits, dim=1)

                        # Count correct predictions
                        correct_predictions += (predictions == labels).sum().item()
                        total_samples += labels.size(0)

                epoch_accuracy = correct_predictions / total_samples
                epoch_accuracies.append(epoch_accuracy)
                accuracy_epochs.append(epoch + 1)

                # Log accuracy to TensorBoard
                writer.add_scalar('Accuracy/Train', epoch_accuracy, epoch + 1)

                # Set both modules back to training mode for the next epochs
                model.train()
                guide.train()

                print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}, Train Accuracy: {epoch_accuracy*100:.2f}%")
            else:
                print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")
    finally:
        # Close the writer even when training raises or is interrupted from the
        # notebook, so the event file is flushed and the handle released.
        writer.close()

    return epoch_losses, epoch_accuracies, accuracy_epochs

In [None]:
# Train for 100 epochs with TensorBoard logging; the event files are written
# under runs/eurosat_bayesian_cnn_experiment.
losses, accuracies, accuracy_epochs = train_svi_with_tensorboard(
    bayesian_model, guide, svi, train_loader, 
    num_epochs=100,
    log_dir='runs/eurosat_bayesian_cnn_experiment'
)

In [None]:
# Train with TensorBoard logging
# NOTE(review): this call repeats an identical training cell found earlier in the
# notebook. Running it again restarts training (the function clears the Pyro
# parameter store on entry) and writes to the same log_dir -- presumably a
# leftover duplicate; confirm and remove.
losses, accuracies, accuracy_epochs = train_svi_with_tensorboard(
    bayesian_model, guide, svi, train_loader, 
    num_epochs=100,
    log_dir='runs/eurosat_bayesian_cnn_experiment'
)

Feature TODO:
1. Record loss after each epoch
2. Send result to GPU