In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle
from pyro.infer.autoguide import AutoDiagonalNormal

from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

In [3]:
# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class BayesianCNNSingleFC(PyroModule):
    """Bayesian CNN: two conv + max-pool stages followed by one linear layer.

    Every weight and bias of ``conv1``, ``conv2`` and ``fc1`` is declared as a
    ``PyroSample`` site with an independent Normal(0, 10) prior. Both
    convolutions preserve the spatial size (5x5 kernel, stride 1, padding 2)
    and each pooling halves it, so the ``64 * 16 * 16`` input width of ``fc1``
    corresponds to 64x64 images.

    Args:
        num_classes: number of output logits produced by ``fc1``.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Accuracies recorded for the prior scales tried so far:
        #   sigma = 0.1 -> 13.203704% (2 epochs)
        #   sigma = 1.  -> 31% (2 epochs)
        #   sigma = 10. -> 45% (10 epochs)   <- value in use
        #   sigma = 100 -> 21% (10 epochs)
        prior_mu = 0.
        prior_sigma = torch.tensor(10., device=device)

        def normal_prior(*shape):
            # Independent Normal prior over a tensor of the given shape, with
            # all of its dimensions declared as event dimensions.
            base = dist.Normal(prior_mu, prior_sigma)
            return PyroSample(base.expand(list(shape)).to_event(len(shape)))

        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = normal_prior(32, 3, 5, 5)
        self.conv1.bias = normal_prior(32)

        # Initially padding=1, kernel_size=3, without stride.
        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        self.conv2.weight = normal_prior(64, 32, 5, 5)
        self.conv2.bias = normal_prior(64)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)
        self.fc1.weight = normal_prior(num_classes, 64 * 16 * 16)
        self.fc1.bias = normal_prior(num_classes)

    def forward(self, x, y=None):
        """Return class logits for ``x``; when ``y`` is given, also observe it.

        Args:
            x: image batch fed to ``conv1``.
            y: optional labels. When provided, an ``"obs"`` Categorical site
                conditioned on ``y`` is sampled inside a ``"data"`` plate whose
                size is the batch size.

        Returns:
            The logits produced by ``fc1``.
        """
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        logits = self.fc1(flat)

        # Without labels there is no likelihood term to register.
        if y is None:
            return logits

        with pyro.plate("data", flat.shape[0]):
            pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

In [5]:
import pyro
import pyro.distributions as dist
import pyro.nn
import torch
import torch.nn as nn
from pyro.nn import PyroModule, PyroParam
from torch.distributions import constraints

class CustomGuide(PyroModule):
    """Fully factorised Normal guide over the six ``<layer>.weight`` / ``<layer>.bias`` sites.

    For each of ``conv1``, ``conv2`` and ``fc1`` the guide holds a ``*_loc``
    parameter (initialised to ``0.1 * randn``) and a positive-constrained
    ``*_scale`` parameter (initialised to ``0.1``) for the weight and the bias.

    Args:
        num_classes: leading dimension of the ``fc1`` weight and bias parameters.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (attribute prefix, tensor shape), listed in registration order.
        param_layout = (
            ("conv1_weight", (32, 3, 5, 5)),
            ("conv1_bias", (32,)),
            ("conv2_weight", (64, 32, 5, 5)),
            ("conv2_bias", (64,)),
            ("fc1_weight", (num_classes, 64 * 16 * 16)),
            ("fc1_bias", (num_classes,)),
        )
        for prefix, shape in param_layout:
            setattr(self, f"{prefix}_loc", PyroParam(torch.randn(*shape) * 0.1))
            setattr(
                self,
                f"{prefix}_scale",
                PyroParam(torch.ones(*shape) * 0.1, constraint=constraints.positive),
            )

    def forward(self, x, y=None):
        """Sample every weight/bias site from its Normal; ``x`` and ``y`` are not used."""
        # (sample-site name, attribute prefix, number of event dimensions).
        site_layout = (
            ("conv1.weight", "conv1_weight", 4),
            ("conv1.bias", "conv1_bias", 1),
            ("conv2.weight", "conv2_weight", 4),
            ("conv2.bias", "conv2_bias", 1),
            ("fc1.weight", "fc1_weight", 2),
            ("fc1.bias", "fc1_bias", 1),
        )
        for site_name, prefix, event_dims in site_layout:
            # Parameters are looked up right before their site is sampled,
            # loc first and then scale.
            loc = getattr(self, f"{prefix}_loc")
            scale = getattr(self, f"{prefix}_scale")
            pyro.sample(site_name, dist.Normal(loc, scale).to_event(event_dims))


In [6]:
def load_data(batch_size=54, data_root='./data',
              split_path='datasplit/split_indices.pkl', num_workers=4):
    """Build the EuroSAT train/test loaders from a stored index split.

    Images are resized to 64x64, converted to tensors and normalised with
    fixed per-channel mean/std constants. As a side effect, torch's global
    RNG is re-seeded with 42.

    Args:
        batch_size: batch size used by both loaders.
        data_root: directory holding the already-downloaded EuroSAT data
            (``download=False``).
        split_path: pickle file containing a mapping with ``'train'`` and
            ``'test'`` index collections.
        num_workers: worker processes per loader. ``0`` loads in the main
            process.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])

    dataset = datasets.EuroSAT(root=data_root, transform=transform, download=False)

    torch.manual_seed(42)

    # SECURITY: pickle.load can execute arbitrary code; only point split_path
    # at files you created yourself.
    with open(split_path, 'rb') as f:
        split = pickle.load(f)
    train_dataset = Subset(dataset, split['train'])
    test_dataset = Subset(dataset, split['test'])

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    return train_loader, test_loader

In [7]:
# Number of output classes passed to the model (and, further down, the guide).
num_classes = 10
# Instantiate the Bayesian CNN and move it to the selected device.
bayesian_model = BayesianCNNSingleFC(num_classes=num_classes).to(device)

In [8]:
# Artifacts written by the training run (prior std 10, custom guide, 10 epochs).
model_path = 'results_eurosat/bayesian_cnn_model_std10_cust10_epoch.pth'
guide_path = 'results_eurosat/bayesian_cnn_guide_std10_cust10_epoch_guide.pth'
pyro_param_store_path = 'results_eurosat/pyro_param_store_std10_cust10_epoch.pkl'

# Alternatives kept for reference:
#guide = AutoDiagonalNormal(bayesian_model).to(device)
guide = CustomGuide(num_classes=num_classes).to(device)
#guide.load_state_dict(torch.load(guide_path))

# Replace the global Pyro param store with the saved state.
# map_location moves the stored tensors onto `device`, so a store saved on a
# GPU can also be restored on a machine with a different (or no) GPU.
# SECURITY: weights_only=False unpickles arbitrary objects; only load files
# produced by a trusted training run.
pyro.get_param_store().set_state(
    torch.load(pyro_param_store_path, map_location=device, weights_only=False)
)

In [9]:
def predict_data_probs(model, test_loader, num_samples=10, posterior_guide=None):
    """Monte Carlo prediction over ``test_loader`` using guide samples.

    For every batch, ``num_samples`` traces are drawn from the guide, the
    model is replayed against each trace, and the resulting logits are
    averaged over the samples.

    NOTE(review): the returned probabilities are the softmax of the *averaged
    logits*, not the average of per-sample softmax outputs. Confirm this is
    the intended predictive distribution.

    Args:
        model: Pyro model called as ``model(images)`` and returning logits.
        test_loader: iterable yielding ``(images, labels)`` batches.
        num_samples: number of guide samples per batch; must be >= 1.
        posterior_guide: guide to sample from. Defaults to the notebook-level
            ``guide`` when omitted.

    Returns:
        Four lists with one entry per example: true labels, predicted class
        indices, averaged logits and the softmax of those logits.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1 (averaging over zero
            samples would silently yield NaN logits).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Resolved after validation so bad arguments fail before touching globals.
    active_guide = guide if posterior_guide is None else posterior_guide

    model.eval()

    all_labels = []
    all_predictions = []
    all_logits = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            # Labels are only collected, never used on the device, so they
            # are not transferred.
            images = images.to(device)

            sampled_logits = []
            for _ in range(num_samples):
                # Draw one set of weights from the guide and run the model
                # with exactly those values.
                guide_trace = pyro.poutine.trace(active_guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                sampled_logits.append(replayed_model(images))

            avg_logits = torch.stack(sampled_logits).mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_logits.extend(avg_logits.cpu().numpy())
            all_probs.extend(F.softmax(avg_logits, dim=1).cpu().numpy())

    return all_labels, all_predictions, all_logits, all_probs

In [10]:
# Collect the variational mean ('loc') parameters of the weight tensors.
# The guide's parameters are named '<layer>_weight_loc', so both substrings
# must be present in the name.
param_store = pyro.get_param_store()
weight_mean_params = {
    name: param
    for name, param in param_store.items()
    if 'weight' in name and 'loc' in name
}

In [11]:
# Print every stored parameter with its shape and the mean of its values.
for name, param in pyro.get_param_store().items():
    print(f"{name}: {param.shape} - {param.mean().item()}")

conv1_weight_loc: torch.Size([32, 3, 5, 5]) - 0.00356109905987978
conv1_weight_scale: torch.Size([32, 3, 5, 5]) - 0.1724870204925537
conv1_bias_loc: torch.Size([32]) - 0.7791784405708313
conv1_bias_scale: torch.Size([32]) - 0.17147037386894226
conv2_weight_loc: torch.Size([64, 32, 5, 5]) - -0.24098685383796692
conv2_weight_scale: torch.Size([64, 32, 5, 5]) - 0.9782638549804688
conv2_bias_loc: torch.Size([64]) - -0.38410523533821106
conv2_bias_scale: torch.Size([64]) - 0.8057948350906372
fc1_weight_loc: torch.Size([10, 16384]) - 0.00266614044085145
fc1_weight_scale: torch.Size([10, 16384]) - 4.2342424392700195
fc1_bias_loc: torch.Size([10]) - 0.04563991725444794
fc1_bias_scale: torch.Size([10]) - 0.49544063210487366


In [12]:
# List the name and shape of every entry in the Pyro param store.
for name, value in pyro.get_param_store().items():
    print(f"{name} {value.shape}")

conv1_weight_loc torch.Size([32, 3, 5, 5])
conv1_weight_scale torch.Size([32, 3, 5, 5])
conv1_bias_loc torch.Size([32])
conv1_bias_scale torch.Size([32])
conv2_weight_loc torch.Size([64, 32, 5, 5])
conv2_weight_scale torch.Size([64, 32, 5, 5])
conv2_bias_loc torch.Size([64])
conv2_bias_scale torch.Size([64])
fc1_weight_loc torch.Size([10, 16384])
fc1_weight_scale torch.Size([10, 16384])
fc1_bias_loc torch.Size([10])
fc1_bias_scale torch.Size([10])


In [14]:
# Build the train/test DataLoaders with 54 examples per batch.
train_loader, test_loader = load_data(batch_size=54)

In [15]:
# Predict over the test loader, averaging logits from 10 guide samples per batch.
all_labels, all_predictions, all_logits, all_probs = predict_data_probs(bayesian_model, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s]


In [16]:
import numpy as np
from sklearn.metrics import confusion_matrix

In [17]:
# Confusion matrix with all_labels as the true labels and all_predictions as the predictions.
cm = confusion_matrix(all_labels, all_predictions)


In [18]:
# Accuracy = sum of the confusion-matrix diagonal divided by the sum of all entries.
accuracy = cm.diagonal().sum() / cm.sum()
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

Accuracy from confusion matrix: 10.240741%
