In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [4]:
init_weight_dict = {'conv1.weight': (-0.0002266301744384691, 0.07788292318582535), 
                    'conv1.bias': (-0.06533964723348618, 0.11781848967075348), 
                    'conv2.weight': (-0.0124916797503829, 0.044476691633462906), 
                    'conv2.bias': (-0.010315056890249252, 0.06136268749833107), 
                    'fc1.weight': (-0.004733316134661436, 0.04221488535404205), 
                    'fc1.bias': (-0.004896007478237152, 0.11825607717037201)}

In [5]:
# change the CNN class to have 1 fc layer only

class CNNSingleFC(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 16 * 16, num_classes)  # Single FC layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # -> (32, 32, 32)
        x = self.pool(F.relu(self.conv2(x)))   # -> (64, 16, 16)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)                         # Single FC layer
        return x

In [6]:
# convert the CNN class to a Bayesian CNN with 1 fc layer only, with the weight initialization from init_weight_dict
class BayesianCNNSingleFC(PyroModule):
    def __init__(self, num_classes):
        super().__init__()
        #self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        #self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        #self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        #self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)

        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=3, padding=1)
        self.conv1.weight = PyroSample(dist.Normal(init_weight_dict['conv1.weight'][0], init_weight_dict['conv1.weight'][1]).expand([32, 3, 3, 3]).to_event(4))
        self.conv1.bias = PyroSample(dist.Normal(init_weight_dict['conv1.bias'][0], init_weight_dict['conv1.bias'][1]).expand([32]).to_event(1))

        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=3, padding=1)
        self.conv2.weight = PyroSample(dist.Normal(init_weight_dict['conv2.weight'][0], init_weight_dict['conv2.weight'][1]).expand([64, 32, 3, 3]).to_event(4))
        self.conv2.bias = PyroSample(dist.Normal(init_weight_dict['conv2.bias'][0], init_weight_dict['conv2.bias'][1]).expand([64]).to_event(1))

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)
        self.fc1.weight = PyroSample(dist.Normal(init_weight_dict['fc1.weight'][0], init_weight_dict['fc1.weight'][1]).expand([num_classes, 64 * 16 * 16]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(init_weight_dict['fc1.bias'][0], init_weight_dict['fc1.bias'][1]).expand([num_classes]).to_event(1))

    #def forward(self, x, y=None):
    #    x = self.pool(F.relu(self.conv1(x)))
    #    x = self.pool(F.relu(self.conv2(x)))
    #    x = x.view(x.size(0), -1)
    #    x = self.fc1(x)

    #    if y is not None:
    #        with pyro.plate("data", x.shape[0]):
    #            pyro.sample("obs", dist.Categorical(logits=logits), obs=y)

    #    return x

    def forward(self, x, y=None):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)
        
        # THIS IS THE MISSING PIECE: Define the likelihood
        if y is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        
        return logits

In [7]:
def load_data(batch_size=64):  # Changed from 54 to 64 to match deterministic CNN
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])

    dataset = datasets.EuroSAT(root='./data', transform=transform, download=False)

    # Use fixed random seed for reproducible splits
    torch.manual_seed(42)
    #train_size = int(0.8 * len(dataset))
    #test_size = len(dataset) - train_size
    #train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    with open('datasplit/split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
        train_dataset = Subset(dataset, split['train'])
        test_dataset = Subset(dataset, split['test'])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    return train_loader, test_loader

In [8]:
import pyro
import pyro.infer
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
from tqdm import tqdm

# 1. Instantiate model and guide - fix device consistency
num_classes = 10
bayesian_model = BayesianCNNSingleFC(num_classes=num_classes).to(device)  # Use .to(device) instead of .cuda()
guide = AutoDiagonalNormal(bayesian_model)

# 2. Optimizer and SVI - increase learning rate for better convergence
optimizer = Adam({"lr": 1e-3})  # Increased from 1e-4 to 1e-3
svi = pyro.infer.SVI(model=bayesian_model,
                     guide=guide,
                     optim=optimizer,
                     loss=pyro.infer.Trace_ELBO())

# 3. Training function
def train_svi(model, guide, svi, train_loader, num_epochs=10):
    # Clear parameter store only ONCE at the beginning
    pyro.clear_param_store()
    model.train()
    
    # Ensure model is on the correct device
    model.to(device)
    guide.to(device)
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            
            loss = svi.step(images, labels)
            epoch_loss += loss
            num_batches += 1
            
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")


In [9]:
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
guide.to(device)

train_loader, test_loader = load_data(batch_size=54)
train_svi(bayesian_model, guide, svi, train_loader, num_epochs=2)

Epoch 1/2: 100%|██████████| 400/400 [02:02<00:00,  3.28it/s]


Epoch 1 - ELBO Loss: 197077.5633


Epoch 2/2: 100%|██████████| 400/400 [00:47<00:00,  8.46it/s]

Epoch 2 - ELBO Loss: 110739.0545





In [10]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def evaluate_svi(model, guide, test_loader, num_samples=10):
    model.eval()
    correct = 0
    total = 0

    #device = next(model.parameters()).device  # Get the device where the model is allocated

    with torch.no_grad():  # Add no_grad for efficiency
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)  # Fix: move to device

            # Accumulate logits from multiple samples
            logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

            for i in range(num_samples):
                # Sample from the guide (posterior) and replay through model
                guide_trace = pyro.poutine.trace(guide).get_trace()
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits = replayed_model(images)
                logits_mc[i] = logits

            # Average the logits across samples
            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"Accuracy over {num_samples} MC samples: {accuracy * 100:.2f}%")
    return accuracy


In [11]:
evaluate_svi(bayesian_model, guide, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:39<00:00,  2.54it/s]

Accuracy over 10 MC samples: 10.65%





0.10648148148148148

In [None]:
for name, value in pyro.get_param_store().items():
    print(f"{name}: {value.shape}")
    print(value)


In [None]:
# print confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, test_loader, num_samples=10):
    model.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

            for i in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace()
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits = replayed_model(images)
                logits_mc[i] = logits

            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [None]:
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)

In [None]:
cm = confusion_matrix(all_labels, all_predictions)

In [None]:
#print accuracy from confusion matrix
accuracy = np.trace(cm) / np.sum(cm)
print(f"Accuracy from confusion matrix: {accuracy * 100:.2f}%")

In [None]:
# plot the confusion matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes):
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, cm[i, j],
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
            
    # make a mark to the diagonal
    plt.plot([0, cm.shape[1]-1], [0, cm.shape[0]-1], color='red', linestyle='--', linewidth=2)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

# Plot the confusion matrix
class_names = ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial',
               'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
plot_confusion_matrix(cm, class_names)

In [None]:
def comprehensive_prediction_check_batch(model, guide, test_loader, batch_num=0, num_samples=10):
    model.eval()
    
    # Check multiple samples from posterior
    all_preds_single = []
    all_preds_mc = []
    
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            if i == batch_num:
                images = images.to(device)
                
                # Single sample prediction (your current method)
                guide_trace = pyro.poutine.trace(guide).get_trace()
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits_single = replayed_model(images)
                preds_single = torch.argmax(logits_single, dim=1)
                all_preds_single.extend(preds_single.cpu().numpy())
                
                # Multiple samples (Monte Carlo)
                logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)
                for j in range(num_samples):
                    guide_trace = pyro.poutine.trace(guide).get_trace()
                    replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                    logits_mc[j] = replayed_model(images)
                
                avg_logits = logits_mc.mean(dim=0)
                preds_mc = torch.argmax(avg_logits, dim=1)
                all_preds_mc.extend(preds_mc.cpu().numpy())


                
                break  # Just check the specified batch
    
    print(f"Batch {batch_num} - Single sample prediction distribution:", np.bincount(all_preds_single, minlength=num_classes))
    print(f"Batch {batch_num} - MC average prediction distribution:\t", np.bincount(all_preds_mc, minlength=num_classes))
    true_labels = labels.cpu().numpy()
    print("True labels distribution:\t\t\t", np.bincount(true_labels, minlength=num_classes))

    #print the confusion matrix for this batch
    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(true_labels, all_preds_mc)
    print("Confusion Matrix for Batch", batch_num)
    print(cm)
    
    # Check if guide has learned meaningful parameters
    print("\nGuide parameter statistics:")
    for name, param in pyro.get_param_store().items():
        if 'loc' in name:
            print(f"{name}: mean={param.mean().item():.4f}, std={param.std().item():.4f}")
        elif 'scale' in name:
            print(f"{name}: mean={param.mean().item():.4f}, min={param.min().item():.4f}, max={param.max().item():.4f}")
    
    #plot the confusion matrix
    plot_confusion_matrix(cm, class_names)

    return None

In [None]:
comprehensive_prediction_check_batch(bayesian_model, 
                                     guide, 
                                     test_loader, 
                                     batch_num=4, 
                                     num_samples=100)

In [None]:
# show 1 sample prediction and put the label on the image in matplotlib
import matplotlib.pyplot as plt

def show_sample_prediction(model, test_loader, num_samples=10, image_idx=0):
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)

    logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

    for i in range(num_samples):
        guide_trace = pyro.poutine.trace(guide).get_trace()
        replayed_model = pyro.poutine.replay(model, trace=guide_trace)
        logits = replayed_model(images)
        logits_mc[i] = logits

    avg_logits = logits_mc.mean(dim=0)
    predictions = torch.argmax(avg_logits, dim=1)

    #print the logits and its label in descending order
    sorted_logits, sorted_indices = torch.sort(avg_logits[0], descending=True)

    # for the correct label, print in which rank is the label in the sorted logits
    correct_label = labels[0].item()
    correct_rank = (sorted_indices == correct_label).nonzero(as_tuple=True)[0].item()
    #print(f"Correct label {correct_label} is at rank {correct_rank + 1} in the sorted logits.")

    # Show the first image and its prediction
    plt.imshow(images[image_idx].cpu().permute(1, 2, 0)* 0.1137 + 0.4082)
    #plt.title(f"Predicted: {predictions[0]}, True: {labels[0]} \nRank: {correct_rank + 1}")
    plt.title(f"Rank: {correct_rank + 1}")
    plt.axis('off')
    plt.show()

    #converts the logits to probabilities
    probabilities = F.softmax(avg_logits, dim=1)

    print("Logits (sorted):")
    for i in range(len(sorted_logits)):
        print(f"Class {sorted_indices[i].item()}: {sorted_logits[i].item()} ({probabilities[0][sorted_indices[i]].item()*100:.2f}%)")

    return None

In [None]:
show_sample_prediction(bayesian_model, test_loader, num_samples=10, image_idx=1)