In [1]:
import os
import torch
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import ssl
# Replacing ssl._create_default_https_context disables TLS certificate
# verification for every HTTPS request made through urllib / http.client in
# this process (possibly including the dataset download inside load_data).
# That is a man-in-the-middle risk, so it is opt-in: set ALLOW_UNVERIFIED_SSL=1
# only on a machine whose Python installation lacks a CA bundle.
if os.environ.get("ALLOW_UNVERIFIED_SSL") == "1":
    ssl._create_default_https_context = ssl._create_unverified_context
import models
import wandb
from train_pc import evaluation, training 
from data import load_data
import numpy as np
from util import categorical_layer_factory, hadamard_layer_factory, dense_layer_factory, mixing_layer_factory

from Cirkit.cirkit.templates.region_graph import QuadTree
from Cirkit.cirkit.symbolic.circuit import Circuit
from Cirkit.cirkit.pipeline import PipelineContext

# Reproducibility: seed numpy AND torch. Torch's global RNG drives DataLoader
# shuffling (shuffle=True below) and presumably the circuit's parameter
# initialisation, so seeding numpy alone does not make runs repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image geometry: the region graph is a quad-tree over a height x height grid.
height = 28
input_dim = height * height  # 784; derived so it cannot drift from `height`
batch_size = 256
region_graph = QuadTree(shape=(height, height))

# Training configuration.
result_dir = 'model'
name = 'pc'
# NOTE(review): max_patience exceeds num_epochs, so patience-based early
# stopping (if training() counts patience in epochs) can never trigger.
max_patience = 20
num_epochs = 3
lam = 1.0
lr = 0.1
num_input_units = 8
num_sum_units = 8

# Symbolic circuit over the quad-tree region graph, built from the layer
# factories imported from util and the unit counts set in the config above.
symbolic_circuit = Circuit.from_region_graph(
    region_graph,
    input_factory=categorical_layer_factory,
    prod_factory=hadamard_layer_factory,
    sum_factory=dense_layer_factory,
    mixing_factory=mixing_layer_factory,
    num_input_units=num_input_units,
    num_sum_units=num_sum_units,
)

# Compile with the torch backend and fold the layers. The 'lse-sum' semiring
# computes in log space: "sum" is log-sum-exp and "product" is addition.
ctx = PipelineContext(semiring='lse-sum', backend='torch', fold=True)

# `circuit` evaluates the model; `pf_circuit` is its integrated counterpart
# (presumably the log partition function used for normalisation).
circuit = ctx.compile(symbolic_circuit).to(device)
pf_circuit = ctx.integrate(circuit).to(device)
model = circuit, pf_circuit

# MNIST splits with binarize=False; see data.load_data for the exact encoding.
train_data, val_data, test_data = load_data('mnist', binarize=False)

# os.cpu_count() returns None when the CPU count cannot be determined, which
# DataLoader rejects; fall back to loading in the main process (0 workers).
num_workers = os.cpu_count() or 0
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Only the trainable parameters of `circuit` are optimised. This assumes
# pf_circuit (derived from circuit via ctx.integrate) shares them — TODO confirm.
optimizer = torch.optim.Adam([p for p in circuit.parameters() if p.requires_grad], lr=lr)
# NOTE(review): with step_size=100 and num_epochs=3 the learning rate is never
# halved if training() calls scheduler.step() once per epoch — confirm.
scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

In [3]:
# Train the (circuit, pf_circuit) pair with the configuration above. The result
# is stored as `nll_val` (presumably the best validation NLL — see train_pc).
nll_val = training(
    name=name,
    result_dir=result_dir,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    training_loader=train_loader,
    val_loader=val_loader,
    device=device,
    max_patience=max_patience,
    num_epochs=num_epochs,
    lam=lam,
    batch_size=batch_size,
)

Average test LL: 63.554
Bits per dimension: 0.11695029445732037
Epoch: 0, train nll=56.70426940917969, val nll=63.553993225097656
saved!
Average test LL: 63.206
Bits per dimension: 0.11630964654767778
Epoch: 1, train nll=51.171173095703125, val nll=63.20584760393415
saved!
Average test LL: 63.156
Bits per dimension: 0.11621876683685042
Epoch: 2, train nll=46.05095672607422, val nll=63.15646107991537
saved!


In [4]:
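# TODO:
# - validate losses (the "Average test LL" printed during training equals the
#   epoch's val nll, so it is a validation NLL, not a test LL)
# - figure out how to save models
# - integrate wandb
# - validate
# - change translations
# - hyperparameter optimization on Eddie
# - sample and validate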