In [None]:
import torch
from torch import optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from fvib import calc_K, VIB, FVIB, Deterministic
from networks import MnistFeatureExtractor
from loss import FVIBLoss, KLLoss, DistortionTaylorApproxLoss
from utils import train_FVIB, train_VIB, fvib_ib_curve
from calibration import CalibrateFVIB, split_valset, calc_logits, ECELoss

In [None]:
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## Load Dataset

In [None]:
data_path = './data'  # torchvision download/cache directory for FashionMNIST
d = 10  # number of classes
batch_size = 100

# Seed the global RNGs so weight initialisation and DataLoader shuffling repeat across
# "Restart & Run All"; helpers that draw from these generators become repeatable as well.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Map pixels to [-1, 1]: ToTensor scales images to [0, 1], then Normalize applies (x - 0.5) / 0.5.
# Normalize takes per-channel sequences, hence the 1-tuples for the single grey channel.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Training data is reshuffled every epoch; test data keeps a fixed order.
trainset = torchvision.datasets.FashionMNIST(root=data_path, train=True,
                                             download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=4)

testset = torchvision.datasets.FashionMNIST(root=data_path, train=False,
                                            download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=4)

## Train Models

In [None]:
# Root directory where every model's training artefacts are written and later re-loaded.
result_path = "./results"

In [None]:
# Optimisation settings shared by every model trained below.
n_epochs = 200
lr = 1e-4
# Width of z for the non-FVIB models (VIB and deterministic baseline); FVIB uses d - 1 instead.
dim = 256

In [None]:
# Train a single FVIB encoder; the betas listed here are only those at which
# losses are recorded during training.
betas = [0.001, 0.1, 0.5]
net = MnistFeatureExtractor(1024, d - 1).to(device)  # FVIB latent size is d - 1
optimizer = optim.Adam(net.parameters(), lr=lr)
# Replace with None to train without learning-rate decay.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.97)
path = f"{result_path}/fvib"
train_FVIB(n_epochs, d, net, train_loader, test_loader, betas, path, optimizer, scheduler)

In [None]:
# Train one VIB model per beta; each run's outputs go to result_path/vib/beta_<beta>.
betas = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]
for beta in betas:
    # Encoder emits 2*dim features (presumably mean and scale of a dim-dimensional z -- TODO confirm).
    net = MnistFeatureExtractor(1024, 2*dim).to(device)
    # Put the VIB head on the same device as the encoder, as the deterministic baseline does;
    # otherwise its parameters stay wherever VIB.__init__ creates them (CPU by default).
    vib = VIB(2*dim, d, -5, 1).to(device)
    net_optimizer = optim.Adam(net.parameters(), lr=lr)
    net_scheduler = optim.lr_scheduler.StepLR(net_optimizer, step_size=2, gamma=0.97)
    vib_optimizer = optim.Adam(vib.parameters(), lr=lr)
    vib_scheduler = optim.lr_scheduler.StepLR(vib_optimizer, step_size=2, gamma=0.97)
    path = result_path + f"/vib/beta_{beta}"
    train_VIB(n_epochs, d, net, vib, train_loader, test_loader, beta, path, net_optimizer, vib_optimizer, net_scheduler, vib_scheduler)
    # Pass use_dta_for_loss=True to train_VIB to train the VIB Taylor-approximation variant,
    # or sq_vib=True to train sqVIB.

In [None]:
# Deterministic baseline: same encoder (dim outputs) followed by a Deterministic head.
net = MnistFeatureExtractor(1024, dim).to(device)
vib = Deterministic(dim, d).to(device)
net_optimizer = optim.Adam(net.parameters(), lr=lr)
vib_optimizer = optim.Adam(vib.parameters(), lr=lr)
net_scheduler = optim.lr_scheduler.StepLR(net_optimizer, step_size=2, gamma=0.97)
vib_scheduler = optim.lr_scheduler.StepLR(vib_optimizer, step_size=2, gamma=0.97)
path = f"{result_path}/base"
# The beta slot is 0 for the deterministic model.
train_VIB(
    n_epochs, d, net, vib, train_loader, test_loader, 0, path,
    net_optimizer, vib_optimizer, net_scheduler, vib_scheduler,
    is_deterministic=True,
)

## Plot IB Curve

In [None]:
conf_after_ts = 0.997  # value of c in Confidence Tuning; set None if not used
betas = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
H_Y = np.log(d)  # entropy of Y in nats; exact for a uniform label distribution

# VIB: one run per beta. Take the last recorded loss values (presumably the final epoch)
# and map them onto the IB plane: I(Z;Y) ~ H(Y) - CE and I(X;Z) ~ KL.
vib_train_I_Z_Y = []
vib_train_I_X_Z = []
vib_test_I_Z_Y = []
vib_test_I_X_Z = []
for beta in betas:
    path = result_path + f'/vib/beta_{beta}'
    train_CE = np.load(path + "/train_ce_loss.npy")
    test_CE = np.load(path + "/test_ce_loss.npy")
    train_KL = np.load(path + "/train_kl_loss.npy")
    test_KL = np.load(path + "/test_kl_loss.npy")
    vib_train_I_Z_Y.append(-train_CE[-1] + H_Y)
    vib_train_I_X_Z.append(train_KL[-1])
    vib_test_I_Z_Y.append(-test_CE[-1] + H_Y)
    vib_test_I_X_Z.append(test_KL[-1])

# FVIB: a single trained encoder, evaluated at every beta.
path = result_path + '/fvib'
net = MnistFeatureExtractor(1024, d-1).to(device)
# map_location maps saved storages onto `device`, so GPU-saved weights also load on CPU-only machines.
net.load_state_dict(torch.load(path + '/fe_weight.pth', map_location=device))
test_CEs, test_KLs, test_DTAs, test_acc = fvib_ib_curve(betas, path, test_loader, net, 1, conf_after_ts)
train_CEs, train_KLs, train_DTAs, train_acc = fvib_ib_curve(betas, path, train_loader, net, 1, conf_after_ts)

# (I(X;Z), I(Z;Y)) coordinate pairs for plotting. np.asarray makes the H(Y) - CE
# conversion work whether fvib_ib_curve returns lists or arrays.
fvib_train = (train_KLs, -np.asarray(train_CEs) + H_Y)
vib_train = (vib_train_I_X_Z, vib_train_I_Z_Y)
fvib_test = (test_KLs, -np.asarray(test_CEs) + H_Y)
vib_test = (vib_test_I_X_Z, vib_test_I_Z_Y)

In [None]:
# Plot the IB plane (I(X,Z) on x, I(Z,Y) on y) for the training and test splits side by side.
fig, axes = plt.subplots(1, 2, figsize=(6,2), tight_layout=True)
panels = [("Training", vib_train, fvib_train), ("Test", vib_test, fvib_test)]
for ax, (title, vib_curve, fvib_curve) in zip(axes, panels):
    ax.set_title(title, fontsize=13)
    ax.plot(*vib_curve, label="VIB", marker="^", markersize=5, color="C0")
    ax.plot(*fvib_curve, label="FVIB", marker=".", markersize=8, color="C1")
    # x-range ends at the first FVIB point, i.e. the smallest beta in `betas`.
    ax.set_xlim(0, fvib_curve[0][0])
    ax.set_xlabel(r'$I(X, Z)$')
    ax.set_ylabel(r'$I(Z, Y)$')
axes[0].legend(fontsize=12, bbox_to_anchor=(1, 0), loc='lower right', borderaxespad=1)
fig.savefig(result_path+'/ib_curve.pdf', bbox_inches="tight", pad_inches=0.05)

## Continuous Optimization of Beta for Calibration

In [None]:
# Carve a validation split of `num_val_data` images out of `testset` for fitting beta.
# NOTE(review): this rebinds `test_loader` (presumably to the remaining test images -- confirm
# in calibration.split_valset), so earlier cells re-run after this one see the new loader.
num_val_data = 1_000
val_loader, test_loader = split_valset(testset, num_val_data)

In [None]:
# Re-load the trained FVIB encoder and its saved K, then wrap both in CalibrateFVIB.
# map_location=device maps saved storages onto `device`, so GPU-saved checkpoints
# also load on the CPU fallback chosen at the top of the notebook.
net = MnistFeatureExtractor(1024, d-1).to(device)
net.load_state_dict(torch.load(result_path+'/fvib/fe_weight.pth', map_location=device))
K = torch.load(result_path+'/fvib/K.pt', map_location=device)
# lr here is CalibrateFVIB's own argument (presumably the beta-search step size), not the training lr.
model = CalibrateFVIB(net, K, num_samples=30, conf_after_ts=0.997, max_iter=50, lr=0.1).to(device)

In [None]:
# Optimize the calibration model's beta using the validation loader (val_loader) split off above.
model.set_beta(val_loader)

In [None]:
# Report test accuracy and expected calibration error (ECE) of the calibrated model, in percent.
eceloss = ECELoss()
logits, labels = calc_logits(model, test_loader)
# Predicted class = index of the largest logit (first index on ties, same as torch.max(..., 1)).
predicted = logits.argmax(dim=1)
acc = (predicted == labels).sum().item() / len(labels)
# NOTE(review): the `False` flag and `[0]` indexing follow calibration.ECELoss's interface -- TODO document.
ece = eceloss(logits, labels, False)[0].item()
print(f"acc:{acc*100}, ece:{ece*100}")