In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class BayesianCNNSingleFC(PyroModule):
    """Bayesian CNN: two conv -> ReLU -> 2x2 max-pool stages, then one linear layer.

    Every weight and bias is a latent variable with an independent
    Normal(0, 10) prior (declared via ``PyroSample``). When labels are given,
    the likelihood is a Categorical over the ``num_classes`` logits.

    Args:
        num_classes: number of output classes.
        num_data: optional size of the full training set. When set, the
            minibatch log-likelihood is multiplied by ``num_data / batch_size``,
            so each SVI step estimates the full-dataset ELBO. Without it, the
            prior/KL term (over all weights) is weighed against only one
            minibatch of data. ``None`` (default) applies no rescaling, which
            matches the original behavior.

    Raises:
        ValueError: if ``num_data`` is given and is not positive.
    """

    def __init__(self, num_classes, num_data=None):
        super().__init__()
        if num_data is not None and num_data <= 0:
            raise ValueError(f"num_data must be positive, got {num_data}")
        # Full training-set size used to rescale minibatch likelihoods (None = no rescaling).
        self.num_data = num_data

        prior_mu = 0.
        # Prior-scale sweep notes from earlier runs:
        #   0.1 -> 13.203704% (2 epochs), 1. -> 31% (2 epochs),
        #   10. -> 45% (10 epochs),       100 -> 21% (10 epochs)
        # Created directly on the notebook-global `device`. PyroSample
        # distributions are not parameters/buffers, so `.to(device)` presumably
        # would not move them.
        prior_sigma = torch.tensor(10., device=device)

        def normal_prior(*shape):
            # Independent Normal(prior_mu, prior_sigma) over a whole tensor of `shape`.
            # All dims are event dims, giving one joint log-prob per tensor.
            return PyroSample(
                dist.Normal(prior_mu, prior_sigma).expand(list(shape)).to_event(len(shape)))

        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = normal_prior(32, 3, 5, 5)
        self.conv1.bias = normal_prior(32)

        # Initially kernel_size=3, padding=1, without stride.
        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        self.conv2.weight = normal_prior(64, 32, 5, 5)
        self.conv2.bias = normal_prior(64)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 64 channels x 16 x 16: two 2x2 pools on size-preserving convs (assumes 64x64 inputs).
        self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)
        self.fc1.weight = normal_prior(num_classes, 64 * 16 * 16)
        self.fc1.bias = normal_prior(num_classes)

    def forward(self, x, y=None):
        """Compute class logits, and observe the labels `y` when they are given.

        Args:
            x: image batch for conv1 (3 channels; fc1 assumes 64x64 spatial size).
            y: optional integer class labels, one per image in `x`.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)

        # Likelihood: one Categorical observation per example in the batch.
        if y is not None:
            batch_size = x.shape[0]
            # Scaling by num_data / batch_size balances the minibatch likelihood
            # against the full prior term. A factor of 1.0 reproduces the unscaled
            # behavior when num_data is not set.
            scale = 1.0 if self.num_data is None else self.num_data / batch_size
            with pyro.poutine.scale(scale=scale), pyro.plate("data", batch_size):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)

        return logits

In [5]:
def old_load_data(batch_size=54):
    """Build EuroSAT train/test loaders without worker processes.

    This is the earlier variant of `load_data` (the training cell calls
    `load_data`). Images are resized to 64x64, converted to tensors and
    normalized per channel. The train/test split comes from the pickled
    indices in datasplit/split_indices.pkl.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137]),
    ])
    eurosat = datasets.EuroSAT(root='./data', transform=preprocess, download=False)

    # Seed the global torch RNG. The split itself is fixed by the pickled indices
    # (the random_split approach was replaced by them).
    torch.manual_seed(42)

    # Only load split files you trust: pickle.load can execute arbitrary code.
    with open('datasplit/split_indices.pkl', 'rb') as handle:
        indices = pickle.load(handle)
    train_subset = Subset(eurosat, indices['train'])
    test_subset = Subset(eurosat, indices['test'])

    return (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(test_subset, batch_size=batch_size),
    )

In [6]:
def load_data(batch_size=54, num_workers=4):
    """Build EuroSAT train/test DataLoaders using the fixed split in datasplit/split_indices.pkl.

    Images are resized to 64x64, converted to tensors and normalized per channel.

    Args:
        batch_size: examples per batch for both loaders.
        num_workers: DataLoader worker processes (default 4, as before). With 0,
            loading happens in the main process and persistent workers are disabled.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])

    dataset = datasets.EuroSAT(root='./data', transform=transform, download=False)

    # Seeds the global torch RNG; the split itself comes from the pickled indices.
    torch.manual_seed(42)

    # Only load split files you trust: pickle.load can execute arbitrary code.
    with open('datasplit/split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
    train_dataset = Subset(dataset, split['train'])
    test_dataset = Subset(dataset, split['test'])

    # Worker processes and pinned host memory speed up loading and host->GPU copies.
    # DataLoader rejects persistent_workers=True when num_workers == 0, and pinning
    # only helps (and otherwise only warns) when CUDA is available.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    return train_loader, test_loader

In [7]:
# Ten output classes; build the Bayesian CNN and move its modules to `device`.
num_classes = 10
bayesian_model = BayesianCNNSingleFC(num_classes=num_classes)
bayesian_model = bayesian_model.to(device)

In [8]:
import pyro
import pyro.distributions as dist
import pyro.nn
from pyro.nn import PyroParam, PyroModule
#from torch.distributions import constraints
from pyro.distributions import constraints
import torch

class CustomGuide(PyroModule):
    """Mean-field Normal guide with one (loc, scale) PyroParam pair per model weight tensor.

    Locs start at 0 and scales at 0.05 (kept positive by ``constraints.positive``).
    NOTE: a later cell redefines ``CustomGuide`` with a ``device`` argument,
    which shadows this definition once that cell has run.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (model site name, tensor shape) in the order the model declares them.
        self._site_shapes = (
            ("conv1.weight", (32, 3, 5, 5)),
            ("conv1.bias", (32,)),
            ("conv2.weight", (64, 32, 5, 5)),
            ("conv2.bias", (64,)),
            ("fc1.weight", (num_classes, 64 * 16 * 16)),
            ("fc1.bias", (num_classes,)),
        )
        for site, shape in self._site_shapes:
            stem = site.replace(".", "_")  # e.g. "conv1.weight" -> "conv1_weight"
            setattr(self, stem + "_loc", PyroParam(torch.zeros(shape)))
            setattr(self, stem + "_scale",
                    PyroParam(torch.ones(shape) * 0.05, constraint=constraints.positive))

    def forward(self, x, y=None):
        # One Normal sample site per model latent; every dim is an event dim.
        for site, shape in self._site_shapes:
            stem = site.replace(".", "_")
            loc = getattr(self, stem + "_loc")
            scale = getattr(self, stem + "_scale")
            pyro.sample(site, dist.Normal(loc, scale).to_event(len(shape)))


In [9]:
from pyro.nn import PyroParam, PyroModule
#from torch.distributions import constraints
from pyro.distributions import constraints
import torch

class CustomGuide(PyroModule):
    """Mean-field Normal guide over every model weight tensor, created on ``device``.

    Each model latent gets a ``<site>_loc`` PyroParam, initialized to 0.0, and a
    ``<site>_scale`` PyroParam, initialized to 0.1 and softplus-constrained to be
    positive. The 0.1 scale follows the original note that AutoDiagonalNormal
    usually initializes its scale around 0.1.
    """

    def __init__(self, num_classes, device='cpu'):
        super().__init__()

        init_loc = 0.0
        init_scale = 0.1

        # (model site name, tensor shape) in the model's declaration order.
        self._site_shapes = (
            ("conv1.weight", (32, 3, 5, 5)),
            ("conv1.bias", (32,)),
            ("conv2.weight", (64, 32, 5, 5)),
            ("conv2.bias", (64,)),
            ("fc1.weight", (num_classes, 64 * 16 * 16)),
            ("fc1.bias", (num_classes,)),
        )
        for site, shape in self._site_shapes:
            stem = site.replace(".", "_")
            loc_init = torch.full(shape, init_loc, device=device)
            scale_init = torch.full(shape, init_scale, device=device)
            setattr(self, stem + "_loc", PyroParam(loc_init))
            setattr(self, stem + "_scale",
                    PyroParam(scale_init, constraint=constraints.softplus_positive))

    def forward(self, x=None, y=None):
        # Emit the guide sites in the same order as the model's latent variables.
        for site, shape in self._site_shapes:
            stem = site.replace(".", "_")
            posterior = dist.Normal(getattr(self, stem + "_loc"), getattr(self, stem + "_scale"))
            pyro.sample(site, posterior.to_event(len(shape)))


In [10]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroParam
from torch.distributions import constraints
import torch
import torch.nn.functional as F

class CustomVectorizedGuide(PyroModule):
    """Mean-field Normal guide over all model weights, stored as one flat vector.

    A single latent ``_auto_latent ~ Normal(loc, softplus(scale_unconstrained))``
    is drawn and split into per-layer chunks. Each chunk is then exposed to the
    model through a ``Delta`` site carrying the model's site name
    (``conv1.weight`` ... ``fc1.bias``).

    Args:
        num_classes: number of classes; determines the ``fc1`` shapes.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Model latent shapes, in the order they are packed into the flat vector.
        self.param_shapes = {
            "conv1.weight": (32, 3, 5, 5),
            "conv1.bias": (32,),
            "conv2.weight": (64, 32, 5, 5),
            "conv2.bias": (64,),
            "fc1.weight": (num_classes, 64 * 16 * 16),
            "fc1.bias": (num_classes,)
        }
        # Element count of each chunk, computed once instead of on every unpack.
        self._sizes = [torch.Size(shape).numel() for shape in self.param_shapes.values()]
        self.total_size = sum(self._sizes)

        # Vectorized variational parameters.
        self.loc = PyroParam(torch.zeros(self.total_size))
        self.scale_unconstrained = PyroParam(torch.full((self.total_size,), -2.0))  # softplus(-2) ≈ 0.127

    def _unpack(self, vector):
        """Split a flat vector of length ``total_size`` into a dict of named, shaped tensors."""
        chunks = torch.split(vector, self._sizes)
        return {name: chunk.reshape(shape)
                for (name, shape), chunk in zip(self.param_shapes.items(), chunks)}

    def forward(self, x=None, y=None):
        """Sample every model weight from the guide.

        `x` and `y` are accepted and ignored, so SVI can call the guide with the
        model's arguments and evaluation code can call it with none.
        """
        scale = F.softplus(self.scale_unconstrained)
        # The model has no `_auto_latent` site. Marking it auxiliary (as Pyro's
        # AutoContinuous guides do) silences the guide/model mismatch warning seen
        # in the training output; the site stays in the guide trace.
        latent = pyro.sample("_auto_latent",
                             dist.Normal(self.loc, scale).to_event(1),
                             infer={"is_auxiliary": True})

        # Hand each chunk to the model under its own site name (every dim is an event dim).
        unpacked = self._unpack(latent)
        for name, shape in self.param_shapes.items():
            pyro.sample(name, dist.Delta(unpacked[name]).to_event(len(shape)))


In [11]:
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

In [12]:
# Guide choice: the flat-vector mean-field guide. Alternatives tried were
# AutoDiagonalNormal(bayesian_model) and CustomGuide(num_classes=num_classes).
guide = CustomVectorizedGuide(num_classes=num_classes)

# Adam at lr=1e-3 (raised from 1e-4 for faster convergence), optimizing the Trace_ELBO.
optimizer = Adam({"lr": 1e-3})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model=bayesian_model, guide=guide, optim=optimizer, loss=elbo)

In [13]:
from tqdm import tqdm

In [14]:
def train_svi(model, guide, svi, train_loader, num_epochs=10, clear_params=True):
    """Run SVI for `num_epochs` passes over `train_loader`, printing the mean loss per epoch.

    Args:
        model: the Pyro model; it is put in train mode and moved to the global `device`.
        guide: the guide. It is not used here because `svi` already holds it; the
            caller is responsible for moving it to the right device.
        svi: a ``pyro.infer.SVI`` whose ``step(images, labels)`` returns the batch loss.
        train_loader: iterable of ``(images, labels)`` batches.
        num_epochs: number of passes over `train_loader`.
        clear_params: if True (default, the original behavior), clear the global
            Pyro param store first. Pass False to continue training from the
            current parameters.

    Returns:
        list[float]: the mean per-batch ELBO loss of each epoch.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    if clear_params:
        pyro.clear_param_store()
    model.train()
    model.to(device)

    epoch_losses = []
    for epoch in range(num_epochs):
        batch_losses = []
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            batch_losses.append(svi.step(images, labels))

        if not batch_losses:
            raise ValueError("train_loader yielded no batches")
        avg_loss = sum(batch_losses) / len(batch_losses)
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} - ELBO Loss: {avg_loss:.4f}")

    return epoch_losses


In [15]:
# Start from an empty param store, put model and guide on the target device,
# build the EuroSAT loaders, then run 10 epochs of SVI.
pyro.clear_param_store()

for module in (bayesian_model, guide):
    module.to(device)

train_loader, test_loader = load_data(batch_size=54)
train_svi(bayesian_model, guide, svi, train_loader, num_epochs=10)

{'_auto_latent'}
Epoch 1/10: 100%|██████████| 400/400 [00:36<00:00, 11.01it/s]


Epoch 1 - ELBO Loss: 816830.5121


Epoch 2/10: 100%|██████████| 400/400 [00:14<00:00, 27.29it/s]


Epoch 2 - ELBO Loss: 755250.8649


Epoch 3/10: 100%|██████████| 400/400 [00:13<00:00, 30.22it/s]


Epoch 3 - ELBO Loss: 689854.6885


Epoch 4/10: 100%|██████████| 400/400 [00:14<00:00, 28.47it/s]


Epoch 4 - ELBO Loss: 625978.2338


Epoch 5/10: 100%|██████████| 400/400 [00:13<00:00, 30.62it/s]


Epoch 5 - ELBO Loss: 565972.3778


Epoch 6/10: 100%|██████████| 400/400 [00:15<00:00, 25.13it/s]


Epoch 6 - ELBO Loss: 511394.9917


Epoch 7/10: 100%|██████████| 400/400 [00:13<00:00, 29.45it/s]


Epoch 7 - ELBO Loss: 462949.3297


Epoch 8/10: 100%|██████████| 400/400 [00:13<00:00, 29.84it/s]


Epoch 8 - ELBO Loss: 420367.4260


Epoch 9/10: 100%|██████████| 400/400 [00:13<00:00, 30.16it/s]


Epoch 9 - ELBO Loss: 383153.0426


Epoch 10/10: 100%|██████████| 400/400 [00:13<00:00, 30.66it/s]

Epoch 10 - ELBO Loss: 350525.5477





In [16]:
# Persist the trained artifacts under results_eurosat/.
# NOTE(review): the model's weights and biases are PyroSample latents, so its
# state_dict presumably holds few or no tensors. The learned posterior lives in
# the guide and the Pyro param store; verify before relying on the model .pth file.
results_dir = 'results_eurosat'
# torch.save does not create missing parent directories.
os.makedirs(results_dir, exist_ok=True)

# save the model
model_path = os.path.join(results_dir, 'bayesian_cnn_model_std10_cust10_epoch.pth')
torch.save(bayesian_model.state_dict(), model_path)

# save the guide
guide_path = os.path.join(results_dir, 'bayesian_cnn_guide_std10_cust10_epoch_guide.pth')
torch.save(guide.state_dict(), guide_path)

# save the pyro parameter store
pyro_param_store_path = os.path.join(results_dir, 'pyro_param_store_std10_cust10_epoch.pkl')
pyro.get_param_store().save(pyro_param_store_path)

In [17]:
# List every entry in the global Pyro param store together with its shape.
param_store = pyro.get_param_store()
for name, value in param_store.items():
    print(name, value.shape)

# These are the guide's variational parameters (e.g. a loc and a scale entry
# covering the model weights), not the number of weights in the model itself.
total_params = sum(value.numel() for value in param_store.values())
print(f"Total number of variational parameters in the Pyro param store: {total_params}")

scale_unconstrained torch.Size([217546])
loc torch.Size([217546])
Total number of parameters in the model: 435092


In [18]:
# Show "name: shape" for each param-store entry.
for param_name, param_value in pyro.get_param_store().items():
    print(f"{param_name}: {param_value.shape}")

scale_unconstrained: torch.Size([217546])
loc: torch.Size([217546])


In [19]:
# print confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, test_loader, num_samples=10, guide=None):
    """Classify every batch in `test_loader` with the Monte-Carlo posterior predictive.

    For each batch, `num_samples` weight samples are drawn from the guide and the
    model is replayed with each one. The resulting class *probabilities* are
    averaged (Bayesian model averaging), and the argmax is the prediction.

    Args:
        model: the Pyro model; it is called as ``model(images)`` and must expose
            ``model.fc1.out_features``.
        test_loader: iterable of ``(images, labels)`` batches.
        num_samples: posterior samples per batch; must be >= 1.
        guide: guide to draw weights from. It defaults to the notebook-global
            ``guide``, which was the original behavior.

    Returns:
        (all_labels, all_predictions): per-example true labels and predicted
        class indices, in loader order.

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if guide is None:
        guide = globals()["guide"]

    model.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            # Running sum of per-sample class probabilities; argmax(sum) == argmax(mean).
            probs_sum = torch.zeros(images.size(0), model.fc1.out_features, device=device)
            for _ in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace()
                logits = pyro.poutine.replay(model, trace=guide_trace)(images)
                probs_sum += torch.softmax(logits, dim=-1)

            predictions = torch.argmax(probs_sum, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [20]:
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:40<00:00,  2.50it/s]


In [21]:
# Fix the label set to every class index, so `cm` is always num_classes x num_classes
# and lines up with the class-name ticks, even if a class appears in neither the
# labels nor the predictions.
cm = confusion_matrix(all_labels, all_predictions, labels=np.arange(num_classes))

In [22]:
# Overall accuracy: the confusion-matrix diagonal counts the correct predictions.
n_correct = np.trace(cm)
accuracy = n_correct / np.sum(cm)
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

Accuracy from confusion matrix: 8.981481%


In [None]:
# plot the confusion matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: square 2-D array of counts; rows are true labels and columns are
            predicted labels.
        classes: class names, one per row/column of `cm`.
        title: figure title.

    Raises:
        ValueError: if `cm` is not square, or if `len(classes)` differs from its size.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D array, got shape {cm.shape}")
    n = cm.shape[0]
    if len(classes) != n:
        raise ValueError(f"got {len(classes)} class names for a {n}x{n} confusion matrix")

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    tick_marks = np.arange(n)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # Annotate each cell: white text on dark (high-count) cells, black on light ones.
    thresh = cm.max() / 2.
    for i in range(n):
        for j in range(n):
            ax.text(j, i, cm[i, j],
                    horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

    # Dashed red line along the diagonal (the correct predictions).
    ax.plot([0, n - 1], [0, n - 1], color='red', linestyle='--', linewidth=2)

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    plt.show()

# Tick labels for both axes; presumably in the dataset's label-index order
# (verify against datasets.EuroSAT(...).classes).
class_names = [
    'AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial',
    'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake',
]
plot_confusion_matrix(cm, classes=class_names)

Feature TODO:
1. Record the loss after each epoch.
2. Send results to the GPU.