In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# can still be executed on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class BayesianCNNSingleFC(PyroModule):
    """Bayesian CNN: two conv + ReLU + max-pool stages and one linear classifier.

    Every convolution/linear weight and bias is a ``PyroSample`` site with an
    independent Normal(prior_mu, prior_sigma) prior. The classifier input width
    is fixed at ``64 * 16 * 16`` (64 channels on a 16x16 map), which is what two
    2x2 poolings of size-preserving convolutions produce from a 64x64 image.

    Args:
        num_classes: Number of output classes (width of the returned logits).
        prior_mu: Mean shared by all weight/bias priors.
        prior_sigma: Standard deviation shared by all weight/bias priors.
            Results recorded in this notebook while tuning it:
            0.1 -> 13.203704% (2 epochs), 1 -> 31% (2 epochs),
            10 -> 45% (10 epochs), 100 -> 21% (10 epochs).
    """

    def __init__(self, num_classes, prior_mu=0., prior_sigma=10.):
        super().__init__()

        # Build the prior parameters on the module-level `device`, so the
        # weights sampled from the priors are created on that device.
        prior_mu = torch.as_tensor(prior_mu, dtype=torch.get_default_dtype(), device=device)
        prior_sigma = torch.as_tensor(prior_sigma, dtype=torch.get_default_dtype(), device=device)

        def normal_prior(*shape):
            # All dimensions of the parameter are event dimensions, so each
            # parameter tensor is a single sample site.
            return PyroSample(
                dist.Normal(prior_mu, prior_sigma).expand(list(shape)).to_event(len(shape))
            )

        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = normal_prior(32, 3, 5, 5)
        self.conv1.bias = normal_prior(32)

        # Initially padding=1, kernel_size=3, without stride.
        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        self.conv2.weight = normal_prior(64, 32, 5, 5)
        self.conv2.bias = normal_prior(64)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)
        self.fc1.weight = normal_prior(num_classes, 64 * 16 * 16)
        self.fc1.bias = normal_prior(num_classes)

    def forward(self, x, y=None):
        """Compute class logits and, when labels are given, register the likelihood.

        Args:
            x: Batch of images; the flattened feature map must have
                ``64 * 16 * 16`` elements per sample to match ``fc1``.
            y: Optional class indices. When provided, an observed Categorical
                site named "obs" is added inside a "data" plate.

        Returns:
            The logits tensor produced by ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)

        # Likelihood term: only registered when labels are supplied, so calling
        # the model without `y` simply returns logits for prediction.
        # NOTE(review): the plate size is the batch size, so a minibatch
        # likelihood is not rescaled to the full dataset size — confirm intended.
        if y is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)

        return logits

In [5]:
def load_data(batch_size=54, root='./data', split_path='datasplit/split_indices.pkl', num_workers=4):
    """Build the EuroSAT train/test loaders from a precomputed index split.

    Args:
        batch_size: Batch size used by both loaders.
        root: Directory of the already-downloaded EuroSAT data (``download=False``).
        split_path: Pickle file holding a mapping with 'train' and 'test'
            entries, each used as the index sequence of a ``Subset``.
        num_workers: Worker processes per loader; 0 loads in the main process.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.

    Side effects:
        Seeds torch's global RNG with 42.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # presumably per-channel dataset statistics — TODO confirm their origin
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])

    dataset = datasets.EuroSAT(root=root, transform=transform, download=False)

    torch.manual_seed(42)

    # SECURITY: pickle.load can execute arbitrary code; only point `split_path`
    # at a file you created yourself.
    with open(split_path, 'rb') as f:
        split = pickle.load(f)

    train_dataset = Subset(dataset, split['train'])
    test_dataset = Subset(dataset, split['test'])

    # Worker processes and pinned memory speed up loading. persistent_workers
    # is only valid with at least one worker, so it follows num_workers.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    return train_loader, test_loader

In [6]:
from pyro.infer.autoguide import AutoDiagonalNormal

In [7]:
# Number of target classes; it fixes the size of fc1 and therefore the length
# of the guide's latent vector, so it must match the checkpoint restored later.
num_classes = 10
# Instantiate the Bayesian CNN and move its (non-Pyro) submodules to `device`.
bayesian_model = BayesianCNNSingleFC(num_classes=num_classes).to(device)

In [9]:
# Artifacts of the run trained with prior std 10 for 10 epochs (per the file names).
model_path = 'results_eurosat/bayesian_cnn_model_std10_10_epoch.pth'
guide_path = 'results_eurosat/bayesian_cnn_guide_std10_10_epoch_guide.pth'
pyro_param_store_path = 'results_eurosat/pyro_param_store_std10_10_epoch.pkl'

# Mean-field Normal guide over all latent weights of the model.
guide = AutoDiagonalNormal(bayesian_model).to(device)
#guide.load_state_dict(torch.load(guide_path))

# Restore the variational parameters into Pyro's global param store.
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on one device can be restored on another.
# SECURITY: weights_only=False unpickles arbitrary objects; only load
# checkpoints you produced yourself.
pyro.get_param_store().set_state(
    torch.load(pyro_param_store_path, map_location=device, weights_only=False)
)

In [10]:
from tqdm import tqdm

In [17]:
# Monte Carlo prediction helper; its outputs are used to build a confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, test_loader, num_samples=10):
    """Predict a class for every sample in `test_loader` by Monte Carlo averaging.

    For each batch, `num_samples` sets of weights are drawn from the
    module-level `guide`, the model is replayed with each draw, and the
    resulting logits are averaged before taking the argmax.

    Args:
        model: Bayesian model exposing `fc1.out_features`; called as `model(images)`.
        test_loader: Iterable yielding `(images, labels)` batches.
        num_samples: Number of posterior draws per batch (must be >= 1).

    Returns:
        `(all_labels, all_predictions)`: two flat lists with one entry per sample.

    Raises:
        ValueError: If `num_samples` is smaller than 1.
    """
    if num_samples < 1:
        # Averaging over zero draws would yield NaN logits and meaningless predictions.
        raise ValueError("num_samples must be at least 1")

    model.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            # Only the images are needed on the device; labels are merely
            # collected, so they are not moved there.
            images = images.to(device)

            # Allocate the buffer directly on the device instead of creating it
            # on the CPU and copying it over.
            logits_mc = torch.zeros(
                num_samples, images.size(0), model.fc1.out_features, device=device
            )

            for i in range(num_samples):
                # Draw one set of weights from the guide and run the model with them.
                guide_trace = pyro.poutine.trace(guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits_mc[i] = replayed_model(images)

            # NOTE(review): this averages logits rather than softmax
            # probabilities; kept as-is so reported results stay comparable.
            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [18]:
# Build the train/test loaders; only `test_loader` is used for the evaluation below.
train_loader, test_loader = load_data(batch_size=54)

In [19]:
# Sanity check of the restored checkpoint: show each variational parameter's
# name and shape, followed by its (abbreviated) values.
for name, value in pyro.get_param_store().items():
    print(f"{name}: {value.shape}", value, sep="\n")

AutoDiagonalNormal.loc: torch.Size([217546])
Parameter containing:
tensor([ 3.2559, -2.3985, -1.1288,  ..., -3.3111,  6.0583,  1.8354],
       device='cuda:0', requires_grad=True)
AutoDiagonalNormal.scale: torch.Size([217546])
tensor([0.1000, 0.0922, 0.0985,  ..., 1.0809, 1.0090, 1.1148], device='cuda:0',
       grad_fn=<SoftplusBackward0>)


In [20]:
# Run the Monte Carlo evaluation over the test set (10 posterior draws per batch).
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:35<00:00,  2.80it/s]


In [21]:
# Confusion matrix with true labels passed first and predictions second.
cm = confusion_matrix(all_labels, all_predictions)

In [22]:
# Overall accuracy: correctly classified samples (the diagonal of the
# confusion matrix) divided by the total number of samples.
accuracy = cm.diagonal().sum() / cm.sum()
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

Accuracy from confusion matrix: 60.166667%
