In [1]:
import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

import os

# third-party library
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

from torchvision.datasets import ImageFolder
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision import datasets

import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

import numpy as np

In [2]:
def load_images(image_size=(28,28), batch_size=128, root="../datasets/MainImageFolder",
                shuffle=True, num_workers=2):
    """Build a DataLoader over an ``ImageFolder`` directory tree.

    Every image is resized to ``image_size``, converted to a tensor and then
    normalised per channel with mean 0.5 and std 0.5 (three channels).

    Args:
        image_size: target size handed to ``transforms.Resize``.
        batch_size: number of samples per batch.
        root: directory handed to ``datasets.ImageFolder``.
        shuffle: whether the loader shuffles the data; the default keeps the
            previous hard-coded behaviour.
        num_workers: number of loader worker processes; the default keeps the
            previous hard-coded value.

    Returns:
        A ``torch.utils.data.DataLoader`` over the transformed image folder.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # One (mean, std) entry per channel, so 3-channel tensors are expected here
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    image_set = datasets.ImageFolder(root=root, transform=transform)
    return torch.utils.data.DataLoader(
        image_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

In [3]:
# Training DataLoader; the root path is relative to the notebook's working directory.
train_data = load_images(root='./state-farm-distracted-driver-detection/imgs/train/')

In [4]:
class NN(nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the hidden layer (``fc1``) and the output layer (``out``)."""
        super().__init__()
        # The attribute names become the parameter-name prefixes
        # ('fc1.weight', 'fc1.bias', 'out.weight', 'out.bias').
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw output-layer activations for ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)

In [5]:
# Input size 28*28*3 is the flattened length of a 3-channel 28x28 image
# (28x28 is the default image_size of load_images); the hidden layer has 1024 units.
# The 10 outputs are presumably one per class in the training folder — TODO confirm.
net = NN(28*28*3, 1024, 10)
# Show the first layer's weight shape as the cell output
net.fc1.weight.shape

torch.Size([1024, 2352])

In [6]:
# Log-softmax taken over dim 1 of its input.
log_softmax = nn.LogSoftmax(dim=1)

In [7]:
def model(x_data, y_data):
    """Generative model: Normal(0, 1) priors over the parameters of ``net`` and a
    categorical likelihood on the labels.

    Args:
        x_data: inputs passed through the sampled network.
        y_data: labels observed at the "obs" sample site.
    """

    def unit_normal(param):
        # Normal with zero loc and unit scale, shaped like the given parameter
        return Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))

    priors = {
        'fc1.weight': unit_normal(net.fc1.weight),
        'fc1.bias': unit_normal(net.fc1.bias),
        'out.weight': unit_normal(net.out.weight),
        'out.bias': unit_normal(net.out.bias),
    }

    # Lift the module parameters to random variables drawn from the priors,
    # then draw one concrete network (this samples all weights and biases).
    sampled_net = pyro.random_module("module", net, priors)()

    # Log-softmax of the sampled network's outputs, used as the class logits
    class_log_probs = log_softmax(sampled_net(x_data))

    pyro.sample("obs", Categorical(logits=class_log_probs), obs=y_data)

In [8]:
softplus = torch.nn.Softplus()


def _variational_normal(name, param):
    """Return a Normal whose loc and scale are learnable Pyro params shaped like ``param``.

    The params are registered as ``<name>_mu`` and ``<name>_sigma``. The raw
    sigma param is passed through softplus so the scale is always positive.
    """
    # Random initial values, drawn in (mu, sigma) order
    mu_init = torch.randn_like(param)
    sigma_init = torch.randn_like(param)
    mu = pyro.param(name + "_mu", mu_init)
    sigma = softplus(pyro.param(name + "_sigma", sigma_init))
    return Normal(loc=mu, scale=sigma)


def guide(x_data, y_data):
    """Variational guide: an independent Normal posterior over each parameter of ``net``.

    Args:
        x_data: unused; present so the signature matches ``model``.
        y_data: unused; present so the signature matches ``model``.

    Returns:
        A copy of ``net`` whose parameters were sampled from the variational
        distributions.
    """
    # All four sites are plain Normal distributions with no event dimensions,
    # so they line up with the priors declared for the same sites in `model`.
    # Previously only 'out.weight' was wrapped with `.independent(1)`, which made
    # that one site's event dims disagree with its model counterpart.
    posteriors = {
        'fc1.weight': _variational_normal("fc1w", net.fc1.weight),
        'fc1.bias': _variational_normal("fc1b", net.fc1.bias),
        'out.weight': _variational_normal("outw", net.out.weight),
        'out.bias': _variational_normal("outb", net.out.bias),
    }

    lifted_module = pyro.random_module("module", net, posteriors)

    return lifted_module()

In [9]:
# Pyro Adam optimiser with learning rate 0.001.
# NOTE(review): this rebinds `optim`, which the import cell bound to `torch.optim`;
# after this cell the module is no longer reachable under that name.
optim = Adam({"lr": 0.001})
# Stochastic variational inference over model/guide using the Trace_ELBO loss.
svi = SVI(model, guide, optim, loss=Trace_ELBO())

In [None]:
num_iterations = 100

# Denominator for the reported loss: number of samples in the training set.
# It does not change between epochs, so it is computed once.
normalizer_train = len(train_data.dataset)

for j in range(num_iterations):
    # Sum of the values returned by svi.step over this epoch's batches
    loss = 0
    # Unpack into `images`/`labels` rather than `data`, so the
    # `torch.utils.data as data` import alias is not overwritten.
    for images, labels in train_data:
        # calculate the loss and take a gradient step
        loss += svi.step(images.view(-1,28*28*3), labels)
    total_epoch_loss_train = loss / normalizer_train

    print("Epoch ", j, " Loss ", total_epoch_loss_train)

Epoch  0  Loss  18520.7409996304
Epoch  1  Loss  17712.08378991156
Epoch  2  Loss  16999.36978667979
Epoch  3  Loss  16348.503787259182
Epoch  4  Loss  15767.532887493293
Epoch  5  Loss  15181.636984776991
Epoch  6  Loss  14599.644741889455
Epoch  7  Loss  14024.66643613596
Epoch  8  Loss  13455.147285065905
Epoch  9  Loss  12885.748798687118
Epoch  10  Loss  12296.385455903222
Epoch  11  Loss  11711.58268431081
Epoch  12  Loss  11119.172782433232
Epoch  13  Loss  10523.901676295125
Epoch  14  Loss  9927.232426880239
Epoch  15  Loss  9335.04536436638
Epoch  16  Loss  8747.920490177537
Epoch  17  Loss  8170.842604312478
Epoch  18  Loss  7604.387424177074
Epoch  19  Loss  7052.335863050396
Epoch  20  Loss  6517.800169902915
Epoch  21  Loss  5998.097869288441
Epoch  22  Loss  5503.5296583185445
Epoch  23  Loss  5028.672856087307
Epoch  24  Loss  4577.295961678432
Epoch  25  Loss  4153.741573238964
Epoch  26  Loss  3754.464536187843
Epoch  27  Loss  3380.5703786841
Epoch  28  Loss  3037.05

In [23]:
def load_test_images(image_size=(28,28), batch_size=128, root="../datasets/MainImageFolder"):
    """Build a DataLoader for the test image folder.

    Delegates to ``load_images`` so that train and test images always go
    through the same resize / tensor / normalisation pipeline instead of a
    copy-pasted duplicate of it.

    Args:
        image_size: target size for the resize transform.
        batch_size: number of samples per batch.
        root: directory handed to ``datasets.ImageFolder``.

    Returns:
        The DataLoader produced by ``load_images`` for ``root``. As before,
        this loader shuffles its data.
    """
    return load_images(image_size=image_size, batch_size=batch_size, root=root)

# Test DataLoader; the root path is relative to the notebook's working directory.
# NOTE(review): test_data is not referenced by any other cell visible here.
test_data = load_test_images(root='./state-farm-distracted-driver-detection/imgs/test_jpg/')
# test_data

In [36]:
num_samples = 10
def predict(x):
    """Predict class indices for ``x`` by averaging several sampled networks.

    Draws ``num_samples`` networks from ``guide``, runs each on ``x``, averages
    their raw outputs and takes the argmax along axis 1.

    Args:
        x: batch of flattened inputs accepted by the sampled networks.

    Returns:
        A NumPy array with one predicted class index per row of ``x``.
    """
    # Inference only: skip autograd bookkeeping instead of building graphs
    # and detaching the outputs afterwards.
    with torch.no_grad():
        sampled_models = [guide(None, None) for _ in range(num_samples)]
        # `sampled_model`, not `model`, to avoid shadowing the global model function
        yhats = [sampled_model(x) for sampled_model in sampled_models]
        mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
# Unpack into `images`/`labels` rather than `data`, so the
# `torch.utils.data as data` import alias is not overwritten.
for images, labels in train_data:
    predicted = predict(images.view(-1,28*28*3))
    total += labels.size(0)
    correct += (predicted == labels.numpy()).sum().item()
    # Only the first batch is scored, so the figure below is the accuracy on
    # a single batch drawn from the training loader.
    break
print("accuracy: %d %%" % (100 * correct / total))

Prediction when network is forced to predict
accuracy: 16 %
