In [3]:
import os
import random
import torch
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
# SECURITY NOTE(review): the assignment below swaps the ssl module's default
# HTTPS context factory for the *unverified* one, so every HTTPS connection
# that relies on that default in this process skips certificate verification.
# Presumably a workaround for a dataset download failing certificate checks --
# confirm it is still needed, and prefer fixing the local CA bundle instead.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import models
import wandb
from train import evaluation, training 
from data import load_data
import numpy as np
from util import categorical_layer_factory, hadamard_layer_factory, dense_layer_factory, mixing_layer_factory

from Cirkits.cirkit.templates.region_graph import QuadTree
from Cirkits.cirkit.symbolic.circuit import Circuit
from Cirkits.cirkit.pipeline import PipelineContext

# Seed every random number generator this notebook draws from so that a
# Restart & Run All reproduces the same run. torch has its own generator that
# is independent of `random` and NumPy; leaving it unseeded makes anything
# drawn from it (e.g. DataLoader shuffling order) differ between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Experiment configuration -------------------------------------------
# Everything a reader might want to tune lives here; the cells below only
# read these names.
height = 28                       # side length of the square grid given to QuadTree
input_dim = height * height       # derived from the grid (784 for 28 x 28) so the two cannot disagree
batch_size = 256
region_graph = QuadTree(shape=(height, height))
result_dir = 'model'              # forwarded to training() as result_dir
name = 'pc'                       # forwarded to training() as name
max_patience = 20                 # forwarded to training(); presumably early-stopping patience -- confirm
num_epochs = 3
lam = 0.                          # forwarded to training() as lam; presumably a regularisation weight -- confirm
lr = 0.1
num_input_units = 8
num_sum_units = 8

# Snapshot of the settings above for experiment tracking (currently only
# referenced by the commented-out wandb.init call). Every value is taken from
# the variable that the run actually uses, never from a repeated literal, so
# the logged configuration cannot drift from the real one.
hyperparameters = {
    'input_dim': input_dim,
    'lr': lr,
    'num_epochs': num_epochs,
    'max_patience': max_patience,
    'batch_size': batch_size,
    'lambda': lam,
    'num_input_units': num_input_units,
    'num_sum_units': num_sum_units,
}

#run = wandb.init(entity="rajpal906")#entity="rajpal906", project="MADE", name="unregularized", id="1", config=hyperparameters, settings=wandb.Settings(start_method="fork"))


# Layer factories (all imported from util) used for each kind of layer when
# the symbolic circuit is built from the region graph.
layer_factories = dict(
    input_factory=categorical_layer_factory,
    sum_factory=dense_layer_factory,
    prod_factory=hadamard_layer_factory,
    mixing_factory=mixing_layer_factory,
)

# Symbolic (backend-independent) circuit laid out over the quad-tree region graph.
symbolic_circuit = Circuit.from_region_graph(
    region_graph,
    num_input_units=num_input_units,
    num_sum_units=num_sum_units,
    **layer_factories,
)

# Compilation pipeline settings:
#   backend  -- compile to torch
#   fold     -- fold the circuit (a backend-specific compilation flag)
#   semiring -- 'lse-sum': the (R, +, *) semiring where + is log-sum-exp and * is the sum
ctx = PipelineContext(backend='torch', fold=True, semiring='lse-sum')

# Compile the circuit, then derive its integrated counterpart; both are moved
# to the selected device. (pf_circuit: the name suggests the partition-function
# circuit -- confirm against the cirkit docs.)
circuit = ctx.compile(symbolic_circuit).to(device)
pf_circuit = ctx.integrate(circuit).to(device)

# training() and evaluation() receive the two circuits together as one pair.
model = (circuit, pf_circuit)

# os.cpu_count() returns None when the core count cannot be determined, and
# None is not a valid num_workers value; fall back to 0 (load in the main
# process) in that case.
num_workers = os.cpu_count() or 0

# Train / validation / test splits of MNIST, loaded without binarisation.
train_data, val_data, test_data = load_data('mnist', binarize=False)

# Only the training loader shuffles and drops the last incomplete batch;
# validation and test iterate over every sample in a fixed order.
train_loader = DataLoader(train_data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Optimise only the parameters of `circuit` that require gradients.
trainable_params = [p for p in circuit.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=lr, momentum=0.95)
# Alternative optimiser: torch.optim.Adam(trainable_params, lr=lr)

# Halve the learning rate every 100 scheduler steps.
# NOTE(review): with num_epochs = 3, this never decays if training() steps the
# scheduler once per epoch -- confirm how often training() calls scheduler.step().
scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

ModuleNotFoundError: No module named 'cirkit'

In [6]:
# Run the training loop. The three returned values are unpacked as the
# validation NLL, the validation bits-per-dimension and the best model pair.
nll_val, bpd_val, model_best = training(
    name=name,
    result_dir=result_dir,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    training_loader=train_loader,
    val_loader=val_loader,
    device=device,
    lam=lam,
    batch_size=batch_size,
)

# Rebind the notebook-level circuit names to the pair returned as model_best.
circuit, pf_circuit = model_best

# Optional manual checkpointing of the two circuits (disabled):
#torch.save(circuit, 'models/pc_test_circuit.pt')
#torch.save(pf_circuit, 'models/pc_test_pf_circuit.pt')
#circuit = torch.load('models/pc_test_circuit.pt')
#pf_circuit = torch.load('models/pc_test_pf_circuit.pt')
#model_best = (circuit, pf_circuit)

# Score the best model on the held-out test split and report both metrics.
test_nll, test_bpd = evaluation(test_loader, device, model_best=model_best)
print(f'Test NLL ={test_nll}, Test BPD = {test_bpd}')

# Optional Weights & Biases logging (disabled; needs the `run` created by wandb.init):
#wandb.log({"test_bpd": test_bpd, "test_loss": test_nll})
#run.log_artifact(result_dir + '/' + name + '.model')
#run.finish()

Average test LL: 950.802
Bits per dimension: 1.7496399861924141
Epoch: 0, train nll=985.2160034179688, val nll=950.8022903645833
saved!
Average test LL: 920.619
Bits per dimension: 1.6940971064713326
Epoch: 1, train nll=913.71826171875, val nll=920.6187682291667
saved!
Average test LL: 874.754
Bits per dimension: 1.6096976926651563
Epoch: 2, train nll=859.733642578125, val nll=874.7538151041666
saved!
Average test LL: 869.612
Bits per dimension: 1.6002362537602357
FINAL LOSS: nll=869.6122100585937


In [None]:
# TODO: Delete the Cirkit checkout, re-clone it correctly, and add it as a git submodule; then write a script that replaces `cirkit` imports with `Cirkit.cirkit`.
# Alternatively, don't push it to GitHub at all: keep it cloned on both Eddie and the local machine and run the scripts there.
