# Bits-per-dimension (BPD) evaluation examples

For the ImageNet examples, we use the standard downsampled validation sets downloaded from https://image-net.org/download-images.php (login required):
* Val (32x32)
* Val (64x64)

In [1]:
import os
import time
import torch
from torch import optim
from torch.utils import data
import torch.nn as nn
from utils import *
from model import *
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
from scipy.io import loadmat
from torch.optim import lr_scheduler
from tqdm import tqdm
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# `alpha` argument forwarded to `discretized_mix_logistic_uniform`; its meaning
# is defined by that function (provided by the `utils` / `model` star imports).
MIX_LOGISTIC_ALPHA = 0.0001


def criterion(real, fake):
    """Score model output `fake` against target images `real`.

    Thin wrapper around `discretized_mix_logistic_uniform` with a fixed alpha.
    NOTE(review): the BPD code divides this value by ln(2) * number of
    sub-pixels, so it is presumably the summed NLL in nats — confirm in utils.
    """
    return discretized_mix_logistic_uniform(real, fake, alpha=MIX_LOGISTIC_ALPHA)


def rescaling(x):
    """Map values from [0, 1] to [-1, 1]."""
    return (x - .5) * 2.


def rescaling_inv(x):
    """Inverse of `rescaling`: map values from [-1, 1] back to [0, 1]."""
    return .5 * x + .5

### Model trained on CIFAR-10, evaluated on CIFAR-10

In [6]:
# CIFAR-10 test split; ToTensor yields float tensors in [0, 1].
testset = torchvision.datasets.CIFAR10(
    root='../../data/cifar10',
    train=False,
    download=False,
    transform=torchvision.transforms.ToTensor(),
)
# The evaluation visits every image exactly once, so shuffling does not change
# the BPD and only costs reproducibility; keep a fixed order.
test = data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=3)
def test_bpd(net):
    with torch.no_grad():
        net.eval()
        bpd_cifar_sum=0.
        for i, (images, labels) in enumerate(test):
            images = rescaling(images).to(device)
            output = net(images)
            loss = criterion(images, output).item()
            bpd_cifar_sum+=loss/(np.log(2.)*(1000*32*32*3))
        bpd_cifar=bpd_cifar_sum/10
        print('bpd_cifar',bpd_cifar)
        
# Evaluate the CIFAR-10 checkpoints; they differ only in `res_num`.
for res_num in (0, 1, 3):
    net = LocalPixelCNN(res_num=res_num, in_kernel=7, in_channels=3, channels=256, out_channels=100).to(device)
    # map_location lets checkpoints saved on GPU load when `device` is the CPU.
    state_dict = torch.load(f'./model_save/rs{res_num}_cifar_h3.pt', map_location=device)
    net.load_state_dict(state_dict)
    print(f'rs={res_num}')
    test_bpd(net)

rs=0
bpd_cifar 3.3833506457358866
rs=1
bpd_cifar 3.281978964854339
rs=3
bpd_cifar 3.2469579184385937


### Model trained on ImageNet32, evaluated on ImageNet32

In [7]:
data_file = '/mnt/data/img32/val_data'
# Reuse `data_file` so the reported size always matches the file actually loaded.
size = os.path.getsize(data_file)
print('Size of test set is', size, 'bytes')
# NOTE(review): `unpickle` comes from the star imports and presumably uses
# pickle; only load this file from a trusted source.
d = unpickle(data_file)
# Each row is reshaped to channel-first (3, 32, 32) images.
x = d['data'].reshape(-1, 3, 32, 32)
# Scale to [0, 1]; `rescaling` later maps this to [-1, 1].
test_img32 = torch.tensor(x).float() / 255.
print('shape of test set', test_img32.size())

def test_bpd(net):
    with torch.no_grad():
        net.eval()
        bpd_img32_sum=0.
        for i in range(0,50):
            images=rescaling(test_img32[i*1000:(i+1)*1000]).to(device)
            output = net(images)
            loss = criterion(images, output).item()
            bpd_img32_sum+=loss/(np.log(2.)*(1000*32*32*3))
        bpd_img32=bpd_img32_sum/50

        print('bpd_img32',bpd_img32)
        
# Evaluate the ImageNet32 checkpoints on the ImageNet32 validation set.
for res_num in (0, 1, 3):
    net = LocalPixelCNN(res_num=res_num, in_kernel=7, in_channels=3, channels=256, out_channels=100).to(device)
    # map_location lets checkpoints saved on GPU load when `device` is the CPU.
    state_dict = torch.load(f'./model_save/imgnet32_rs{res_num}_h3.pt', map_location=device)
    net.load_state_dict(state_dict)
    print(f'rs={res_num}')
    test_bpd(net)

Size of test set is 153737544 bytes
shape of test set torch.Size([50000, 3, 32, 32])
rs=0
bpd_img32 3.9364352346796125
rs=1
bpd_img32 3.848284853541681
rs=3
bpd_img32 3.8112208203916493


### Model trained on ImageNet32, evaluated on ImageNet64

In [5]:
data_file = '/mnt/data/Imagenet64_val_npz/val_data.npz'
# Reuse `data_file` so the reported size always matches the file actually loaded.
size = os.path.getsize(data_file)
print('Size of test set is', size, 'bytes')
# Use the NpzFile as a context manager so the archive is closed after reading.
with np.load(data_file) as d:
    x = d['data'].reshape(-1, 3, 64, 64)
# NOTE: despite its name, this holds the 64x64 images. The name is shared with
# the ImageNet32 cell because the `test_bpd` defined in this cell reads the
# global `test_img32`.
test_img32 = torch.tensor(x).float() / 255.
print('shape of test set', test_img32.size())

def test_bpd(net):
    with torch.no_grad():
        net.eval()
        bpd_img32_sum=0.
        for i in range(0,500):
            images=rescaling(test_img32[i*100:(i+1)*100]).to(device)
            output = net(images)
            loss = criterion(images, output).item()
            bpd_img32_sum+=loss/(np.log(2.)*(100*64*64*3))
        bpd_img32=bpd_img32_sum/500

        print('bpd_img64',bpd_img32)
        
# Evaluate the ImageNet32-trained checkpoints on the 64x64 validation images.
for res_num in (0, 1, 3):
    net = LocalPixelCNN(res_num=res_num, in_kernel=7, in_channels=3, channels=256, out_channels=100).to(device)
    # map_location lets checkpoints saved on GPU load when `device` is the CPU.
    state_dict = torch.load(f'./model_save/imgnet32_rs{res_num}_h3.pt', map_location=device)
    net.load_state_dict(state_dict)
    print(f'rs={res_num}')
    test_bpd(net)

Size of test set is 529393262 bytes
shape of test set torch.Size([50000, 3, 64, 64])
rs=0
bpd_img64 3.6325303754905924
rs=1
bpd_img64 3.549048025647982
rs=3
bpd_img64 3.5152607124355035
