In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd

In [4]:
device = torch.device("cuda")

In [5]:
from torchvision.datasets import ImageFolder

In [6]:
dataset = ImageFolder(
    root="data/shipsnet/foldered",
    transform=transforms.ToTensor()
)

In [7]:
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=1)

In [8]:
shipsnet_mean = [0.4119, 0.4243, 0.3724]
shipsnet_std = [0.1899, 0.1569, 0.1515]

def load_data(batch_size=16):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=shipsnet_mean, 
                             std=shipsnet_std)
    ])

    dataset = ImageFolder(
    root="data/shipsnet/foldered",
    transform=transform
    )

    torch.manual_seed(42)

    #train_size = int(0.8 * len(dataset))
    #test_size = len(dataset) - train_size
    #train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    with open('datasplit/shipsnet_split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
        train_dataset = Subset(dataset, split['train'])
        test_dataset = Subset(dataset, split['test'])

    # Add num_workers and pin_memory for faster data loading
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                            num_workers=4, pin_memory=True, persistent_workers=True)
    return train_loader, test_loader, train_dataset, test_dataset

In [9]:
train_loader, test_loader, train_ds, test_ds = load_data(16)

In [12]:
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesShipsCNN(PyroModule):
    def __init__(self, prior_std=1.0):
        super().__init__()
        # build your conv‐feature extractor
        self.features = PyroModule[nn.Sequential](
            PyroModule[nn.Conv2d](3, 32, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            PyroModule[nn.Conv2d](32, 64, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            PyroModule[nn.Conv2d](64,128, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1,1))
        )

        # attach Gaussian(0, prior_std²) priors to **any** Conv2d you find
        for layer in self.features:
            if isinstance(layer, nn.Conv2d):
                # weight: shape = [out_c, in_c, kH, kW] → 4 event dims
                layer.weight = PyroSample(
                    dist.Normal(0., prior_std)
                       .expand(layer.weight.shape)
                       .to_event(4)
                )
                # bias: shape = [out_c] → 1 event dim
                layer.bias = PyroSample(
                    dist.Normal(0., prior_std)
                       .expand(layer.bias.shape)
                       .to_event(1)
                )

        # likewise for your classifier
        self.classifier = PyroModule[nn.Sequential](
            nn.Flatten(),
            PyroModule[nn.Linear](128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            PyroModule[nn.Linear](256, 2),
        )
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                layer.weight = PyroSample(
                    dist.Normal(0., prior_std)
                       .expand(layer.weight.shape)
                       .to_event(2)
                )
                layer.bias = PyroSample(
                    dist.Normal(0., prior_std)
                       .expand(layer.bias.shape)
                       .to_event(1)
                )

    def forward(self, x, y=None):
        x = self.features(x)
        logits = self.classifier(x)
        with pyro.plate("data", x.size(0)):
            pyro.sample("obs",
                        dist.Categorical(logits=logits),
                        obs=y)
        return logits


In [14]:
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

bcnn = BayesShipsCNN(prior_std=0.1)
optimizer = Adam({"lr": 1e-3})
svi = SVI(bcnn,             # model
          bcnn,             # guide
          optimizer,
          loss=Trace_ELBO())

num_epochs = 10
for epoch in range(1, num_epochs+1):
    bcnn.train()     # ensure dropout, etc., are in train mode
    epoch_loss = 0.0
    correct = 0
    total = 0

    for x_batch, y_batch in train_loader:
        # --- SVI gradient step ---
        epoch_loss += svi.step(x_batch, y_batch)

        # --- instantaneous “predict” for accuracy ---
        # disable grad / pyro tracing for speed & determinism
        with torch.no_grad():
            # skip the obs‐sampling in forward by passing y=None
            logits = bcnn(x_batch, y=None)
            preds = logits.argmax(dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)

    avg_loss = epoch_loss / len(train_loader.dataset)
    train_acc = correct / total

    print(f"Epoch {epoch:3d}  loss = {avg_loss:.4f}  train_acc = {train_acc:.4f}")

{'obs'}


Epoch   1  loss = -0.0025  train_acc = 0.5147


KeyboardInterrupt: 