In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split

import pickle
from tqdm import tqdm
import copy

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np
from sklearn.metrics import confusion_matrix

## Defining Model and Loading Model Training Result

In [4]:
device = torch.device("cuda")

In [5]:
class BayesianCNNSingleFC(PyroModule):
    def __init__(self, num_classes):
        super().__init__()

        prior_mu = 0.
        prior_sigma = torch.tensor(10., device=device)

        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([32, 3, 5, 5]).to_event(4))
        self.conv1.bias = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([32]).to_event(1))

        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2) #initially padding=1 kernel_size=3, without stride
        self.conv2.weight = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([64, 32, 5, 5]).to_event(4))
        self.conv2.bias = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([64]).to_event(1))

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = PyroModule[nn.Linear](64 * 16 * 16, num_classes)
        self.fc1.weight = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([num_classes, 64 * 16 * 16]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(prior_mu, prior_sigma).expand([num_classes]).to_event(1))

    def forward(self, x, y=None):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)
        
        if y is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        
        return logits

In [6]:
def load_data(batch_size=54):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1809, 0.1331, 0.1137])
    ])

    dataset = datasets.EuroSAT(root='./data', transform=transform, download=False)

    torch.manual_seed(42)
    
    with open('datasplit/split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
        train_dataset = Subset(dataset, split['train'])
        test_dataset = Subset(dataset, split['test'])

    # Add num_workers and pin_memory for faster data loading
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                            num_workers=4, pin_memory=True, persistent_workers=True)
    return train_loader, test_loader

In [7]:
def predict_data_probs(model, test_loader, num_samples=10):
    model.eval()

    all_labels = []
    all_predictions = []
    all_logits = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

            for i in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits = replayed_model(images)
                logits_mc[i] = logits

            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_logits.extend(avg_logits.cpu().numpy())
            all_probs.extend(F.softmax(avg_logits, dim=1).cpu().numpy())

    return all_labels, all_predictions, all_logits, all_probs

In [8]:
#instantiate the model
num_classes = 10
bayesian_model = BayesianCNNSingleFC(num_classes=num_classes).to(device)

## Load Trained Model (Before Bitflip)

In [9]:
model_path = 'results_eurosat/bayesian_cnn_model_std10_100_epoch.pth'
guide_path = 'results_eurosat/bayesian_cnn_guide_std10_100_epoch_guide.pth'
pyro_param_store_path = 'results_eurosat/pyro_param_store_std10_100_epoch.pkl'

guide = AutoDiagonalNormal(bayesian_model).to(device)

pyro.get_param_store().set_state(torch.load(pyro_param_store_path,weights_only=False))

original_param_store = {}

for name, value in pyro.get_param_store().items():
    print(f"{name}: {value.shape}")
    original_param_store[name] = torch.tensor(value.data, requires_grad=value.requires_grad)

AutoDiagonalNormal.loc: torch.Size([217546])
AutoDiagonalNormal.scale: torch.Size([217546])


  original_param_store[name] = torch.tensor(value.data, requires_grad=value.requires_grad)


In [10]:
train_loader, test_loader = load_data(batch_size=54)

In [11]:
for name, value in pyro.get_param_store().items():
    print(f"{name}: {value.shape}")
    print(value)

AutoDiagonalNormal.loc: torch.Size([217546])
Parameter containing:
tensor([ 3.1483, -2.4763, -1.0711,  ..., -2.4452,  4.6454,  1.5156],
       device='cuda:0', requires_grad=True)
AutoDiagonalNormal.scale: torch.Size([217546])
tensor([0.0454, 0.0385, 0.0440,  ..., 7.7091, 6.1614, 6.6950], device='cuda:0',
       grad_fn=<SoftplusBackward0>)


In [12]:
all_labels, all_predictions, all_logits, all_probs = predict_data_probs(bayesian_model, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:35<00:00,  2.83it/s]


In [13]:
cm = confusion_matrix(all_labels, all_predictions)

In [14]:
#print accuracy from confusion matrix
accuracy = np.trace(cm) / np.sum(cm)
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

Accuracy from confusion matrix: 74.907407%


## Bitflip Process

In [15]:
from bitflip import bitflip_float32

In [16]:
param_store = pyro.get_param_store()

In [17]:
def change_item(param_store, location_index, new_value):
    pyro.get_param_store()[param_store][location_index] = new_value

    return pyro.get_param_store()[param_store]

def run_seu_autodiagonal_normal(location_index: int, bit_i: int, parameter_name: str="loc"):
    """Perform a bitflip at index i across every variable in the AutoDiagonalNormal guide"""

    assert bit_i in range(0, 33)
    assert parameter_name in ["loc", "scale"]
    assert location_index in range(0, len(pyro.get_param_store()[f"AutoDiagonalNormal.{parameter_name}"]))

    if parameter_name == "loc":
        param_store_name = "AutoDiagonalNormal.loc"
    elif parameter_name == "scale":
        param_store_name = "AutoDiagonalNormal.scale"

    bayesian_model.to(device)
    bayesian_model.eval()

    with torch.no_grad():
        param_dict = {}

        for name, value in pyro.get_param_store().items():
            #print(f"{name}: {value.shape}")
            #print(value)
            param_dict[name] = value.cpu().detach().numpy()

        tensor_cpu = param_dict[param_store_name]

        original_val = tensor_cpu[0]
        seu_val = bitflip_float32(original_val, bit_i)


        print(f"Original value: {original_val}, SEU value: {seu_val}")

        # Get the parameter
        param = pyro.get_param_store().get_param(param_store_name)

        # Modify it safely by creating a new tensor
        new_param = param.clone()
        new_param[location_index] = seu_val  # New Value

        # Update the parameter store
        if parameter_name == "loc":
            pyro.get_param_store().__setitem__(param_store_name, new_param) # 74%
            #param_store[param_store_name].data.copy_(change_item(param_store_name, location_index, seu_val)) #25%
            #pyro.get_param_store()[param_store_name].data[location_index] = seu_val # 25%
        elif parameter_name == "scale":
            pyro.get_param_store().__setitem__(param_store_name, new_param) #10%
            #pyro.get_param_store()[param_store_name].data[location_index] = seu_val

In [18]:
param_store["AutoDiagonalNormal.scale"]

tensor([0.0454, 0.0385, 0.0440,  ..., 7.7091, 6.1614, 6.6950], device='cuda:0',
       grad_fn=<SoftplusBackward0>)

In [19]:
param_store["AutoDiagonalNormal.loc"]

Parameter containing:
tensor([ 3.1483, -2.4763, -1.0711,  ..., -2.4452,  4.6454,  1.5156],
       device='cuda:0', requires_grad=True)

In [20]:
#run_seu_autodiagonal_normal(location_index= 0, bit_i=2, parameter_name="loc")
run_seu_autodiagonal_normal(location_index= 0, bit_i=1, parameter_name="scale")

Original value: 0.04540996998548508, SEU value: 1.5452212068469636e+37


In [21]:
#param_store["AutoDiagonalNormal.loc"]

In [22]:
param_store["AutoDiagonalNormal.loc"]

Parameter containing:
tensor([ 3.1483, -2.4763, -1.0711,  ..., -2.4452,  4.6454,  1.5156],
       device='cuda:0', requires_grad=True)

In [23]:
param_store["AutoDiagonalNormal.scale"]

tensor([1.5452e+37, 3.8508e-02, 4.3955e-02,  ..., 7.7091e+00, 6.1614e+00,
        6.6950e+00], device='cuda:0', grad_fn=<SoftplusBackward0>)

## After Bitflip

In [24]:
guide = AutoDiagonalNormal(bayesian_model).to(device)

In [25]:
changed = not torch.equal(pyro.get_param_store()["AutoDiagonalNormal.loc"], #AFTER 
                          original_param_store["AutoDiagonalNormal.loc"], #BEFORE
                          )
print("Weights changed:", changed)

Weights changed: False


In [26]:
after_all_labels, after_all_predictions, after_all_logits, after_all_probs = predict_data_probs(bayesian_model, test_loader, num_samples=10)

Evaluating: 100%|██████████| 100/100 [00:08<00:00, 11.75it/s]


In [27]:
after_cm = confusion_matrix(after_all_labels, after_all_predictions)

In [28]:
after_accuracy = np.trace(after_cm) / np.sum(after_cm)
print(f"Accuracy from confusion matrix: {after_accuracy * 100:.6f}%")

Accuracy from confusion matrix: 11.111111%


In [29]:
#print the difference in accuracy
print(f"Accuracy difference: {(after_accuracy - accuracy)*100:.6f}%")

Accuracy difference: -63.796296%
