In [1]:
import os
import random
import torch
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import models
import wandb
from train import evaluation, training 
from data import load_data
import numpy as np
from util import categorical_layer_factory, hadamard_layer_factory, dense_layer_factory, mixing_layer_factory

from Cirkit.cirkit.templates.region_graph import QuadTree
from Cirkit.cirkit.symbolic.circuit import Circuit
from Cirkit.cirkit.pipeline import PipelineContext

# Seed every RNG used by this notebook. The torch generator was previously
# left unseeded, so torch-side randomness (e.g. DataLoader shuffling) was not
# reproducible across kernel restarts.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.environ['WANDB_NOTEBOOK_NAME'] = 'hyperparameter_optimization.ipynb'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory handed to training() as `result_dir`. exist_ok=True avoids the
# check-then-create race and also works if the path is later made nested.
result_dir = 'models'
os.makedirs(result_dir, exist_ok=True)
name = 'pc'  # TODO: change to regularized

# Objective the sweep reports/optimizes: the value logged under "test_bpd".
metric = {
    'name': 'test_bpd',
    'goal': 'minimize',
}

# 'value' pins a hyperparameter; 'values' lists the grid points to enumerate.
# TODO: add momentum to the sweep? (it is currently hard-coded in the optimizer)
parameters_dict = {
    # Not read by hyperparameter_sweep in this cell; only recorded in the run config.
    'input_dim': {
        'value': 784
    },
    # Forwarded to training() as `lam`.
    'lam': {
        'values': [0.1, 0.5, 1.0]
    },
    'num_epochs': {
        'value': 1
    },
    'lr': {
        'values': [1e-1, 1e-2, 1e-3]
    },
    'batch_size': {
        'values': [64, 128, 256]
    },
    'num_input_units': {
        'value': 8
    },
    'num_sum_units': {
        'value': 8
    },
    # NOTE(review): with num_epochs fixed to 1, a patience of 30 presumably
    # never triggers early stopping -- confirm against training().
    'max_patience': {
        'value': 30
    },
}

# Exhaustive grid search over every combination listed in parameters_dict.
sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}
sweep_id = wandb.sweep(sweep_config, project="pc_hyperparameter_optimization")

def hyperparameter_sweep(config=None):
    """Run a single sweep trial and log its test bits-per-dimension to W&B.

    Used as the callback handed to ``wandb.agent``. Within one W&B run it
    loads MNIST, builds a QuadTree-based circuit, trains it via ``training``
    and evaluates the returned best model on the test split.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. The
            hyperparameters actually used are read back from ``wandb.config``
            (``batch_size``, ``num_input_units``, ``num_sum_units``, ``lr``,
            ``max_patience``, ``num_epochs``, ``lam``).

    Returns:
        None. The result is reported through ``wandb.log`` under ``test_bpd``,
        the metric name the sweep is configured to minimize.
    """
    with wandb.init(config=config):
        config = wandb.config

        # os.cpu_count() can return None when the count is undetermined;
        # DataLoader requires an int, so fall back to in-process loading.
        num_workers = os.cpu_count() or 0
        loader_kwargs = dict(batch_size=config.batch_size, num_workers=num_workers)

        train_data, val_data, test_data = load_data('mnist', binarize = False)
        # Only the training split is shuffled.
        train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
        test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

        region_graph = QuadTree(shape=(28, 28))
        symbolic_circuit = Circuit.from_region_graph(region_graph,
                                                    num_input_units=config.num_input_units,
                                                    num_sum_units=config.num_sum_units,
                                                    input_factory=categorical_layer_factory,
                                                    sum_factory=dense_layer_factory,
                                                    prod_factory=hadamard_layer_factory,
                                                    mixing_factory=mixing_layer_factory)

        ctx = PipelineContext(
            backend='torch',   # Choose the torch compilation backend
            fold=True,         # Fold the circuit, this is a backend-specific compilation flag
            semiring='lse-sum' # Use the (R, +, *) semiring, where + is the log-sum-exp and * is the sum
        )
        circuit = ctx.compile(symbolic_circuit).to(device)
        pf_circuit = ctx.integrate(circuit).to(device)
        # training() receives the compiled circuit together with its integrated counterpart.
        model = (circuit, pf_circuit)

        # Only parameters that require gradients are handed to the optimizer.
        trainable_params = [p for p in circuit.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(trainable_params, lr=config.lr, momentum=0.95)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

        # NOTE(review): every trial passes the same `name` and `result_dir`; if
        # training() writes a checkpoint keyed on them, later trials presumably
        # overwrite earlier ones -- confirm.
        _, _, model_best = training(name=name, result_dir=result_dir, max_patience=config.max_patience, num_epochs=config.num_epochs,
                   model=model, optimizer=optimizer, scheduler=scheduler, training_loader=train_loader,
                   val_loader=val_loader, device=device, lam=config.lam, batch_size = config.batch_size)

        # The first value returned by evaluation() is not used here; only
        # bits-per-dimension is reported to the sweep.
        _, test_bpd = evaluation(test_loader, device, model_best=model_best)
        wandb.log({"test_bpd": test_bpd})

# Start a W&B agent for the sweep created above; it invokes
# hyperparameter_sweep once per configuration it receives and keeps going
# until the sweep has no more jobs or the cell is interrupted.
wandb.agent(sweep_id, hyperparameter_sweep)

Create sweep with ID: spvvkf25
Sweep URL: https://wandb.ai/rajpal906/pc_hyperparameter_optimization/sweeps/spvvkf25


[34m[1mwandb[0m: Agent Starting Run: milzhmes with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8
[34m[1mwandb[0m: Currently logged in as: [33ms2592586[0m ([33mrajpal906[0m). Use [1m`wandb login --relogin`[0m to force relogin


Average test LL: 844.392
Bits per dimension: 1.55382665128959
Epoch: 0, train nll=864.5386352539062, val nll=844.3919609375
saved!
Average test LL: 840.470
Bits per dimension: 1.5466101545366937
FINAL LOSS: nll=840.4703189453126


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.54661


[34m[1mwandb[0m: Agent Starting Run: 9urzb652 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1021.116
Bits per dimension: 1.8790298467176378
Epoch: 0, train nll=1070.166748046875, val nll=1021.1162845052083
saved!
Average test LL: 1023.718
Bits per dimension: 1.8838176781500773
FINAL LOSS: nll=1023.7181232421875


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.88382


[34m[1mwandb[0m: Agent Starting Run: t4bcpunb with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 3203.703
Bits per dimension: 5.895365560774133
Epoch: 0, train nll=3219.151611328125, val nll=3203.7031171875
saved!
Average test LL: 3203.878
Bits per dimension: 5.895686550788579
FINAL LOSS: nll=3203.877551953125


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,5.89569


[34m[1mwandb[0m: Agent Starting Run: s8mqbuz0 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 846.055
Bits per dimension: 1.556887087756272
Epoch: 0, train nll=817.3111572265625, val nll=846.0550859375
saved!
Average test LL: 842.218
Bits per dimension: 1.5498265479373825
FINAL LOSS: nll=842.2181952148437


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.54983


[34m[1mwandb[0m: Agent Starting Run: khzlui5f with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1021.262
Bits per dimension: 1.879298699956081
Epoch: 0, train nll=1084.68310546875, val nll=1021.26238671875
saved!
Average test LL: 1026.825
Bits per dimension: 1.889534454289933
FINAL LOSS: nll=1026.8247759765625


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded\r'), FloatProgress(value=0.20465116279069767, max=1.…

0,1
test_bpd,▁

0,1
test_bpd,1.88953


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 32p8ddp5 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 3206.512
Bits per dimension: 5.900533799537878
Epoch: 0, train nll=3223.95849609375, val nll=3206.5116796875
saved!
Average test LL: 3207.234
Bits per dimension: 5.901863837939994
FINAL LOSS: nll=3207.234458984375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.


0,1
test_bpd,▁

0,1
test_bpd,5.90186


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: kdtkavyx with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 856.086
Bits per dimension: 1.575344991777448
Epoch: 0, train nll=873.740234375, val nll=856.0856165364584
saved!
Average test LL: 853.649
Bits per dimension: 1.5708610826420697
FINAL LOSS: nll=853.6489375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.57086


[34m[1mwandb[0m: Agent Starting Run: olr94dyu with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1026.145
Bits per dimension: 1.8882841843596108
Epoch: 0, train nll=992.9417114257812, val nll=1026.1453450520833
saved!
Average test LL: 1031.632
Bits per dimension: 1.8983802460188217
FINAL LOSS: nll=1031.631821484375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.89838


[34m[1mwandb[0m: Agent Starting Run: n2d6rq8k with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 3210.495
Bits per dimension: 5.90786377106342
Epoch: 0, train nll=3220.42626953125, val nll=3210.4949869791667
saved!
Average test LL: 3210.117
Bits per dimension: 5.907168294735436
FINAL LOSS: nll=3210.11704609375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,5.90717


[34m[1mwandb[0m: Agent Starting Run: ofvqtmgw with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 919.904
Bits per dimension: 1.6927814432168602
Epoch: 0, train nll=933.9168090820312, val nll=919.90380078125
saved!
Average test LL: 915.395
Bits per dimension: 1.684484783325975
FINAL LOSS: nll=915.3951685546875


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.68448


[34m[1mwandb[0m: Agent Starting Run: d4zdl0en with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1179.482
Bits per dimension: 2.1704504532160374
Epoch: 0, train nll=1183.0758056640625, val nll=1179.4822239583334
saved!
Average test LL: 1179.570
Bits per dimension: 2.170611064072646
FINAL LOSS: nll=1179.569504296875


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,2.17061


[34m[1mwandb[0m: Agent Starting Run: ilsl3w4s with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 3789.037
Bits per dimension: 6.972480285713149
Epoch: 0, train nll=3790.16650390625, val nll=3789.0367604166668
saved!
Average test LL: 3789.092
Bits per dimension: 6.972581550775217
FINAL LOSS: nll=3789.091790625


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,6.97258


[34m[1mwandb[0m: Agent Starting Run: 2ryhc3uh with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 919.329
Bits per dimension: 1.6917232555212602
Epoch: 0, train nll=937.5757446289062, val nll=919.3287526041667
saved!
Average test LL: 916.501
Bits per dimension: 1.6865205245497663
FINAL LOSS: nll=916.50144609375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.68652


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: oxk48vx2 with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1180.663
Bits per dimension: 2.172623071267564
Epoch: 0, train nll=1193.571533203125, val nll=1180.6628841145832
saved!
Average test LL: 1182.473
Bits per dimension: 2.175953614590032
FINAL LOSS: nll=1182.472792578125


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,2.17595


[34m[1mwandb[0m: Agent Starting Run: efti0fdm with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168429633330056, max=1.0…

Average test LL: 3789.869
Bits per dimension: 6.974011922412596
Epoch: 0, train nll=3796.458740234375, val nll=3789.86909375
saved!
Average test LL: 3790.395
Bits per dimension: 6.974980152041998
FINAL LOSS: nll=3790.39525625


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,6.97498


[34m[1mwandb[0m: Agent Starting Run: ujrjd5mz with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 926.360
Bits per dimension: 1.7046616263859908
Epoch: 0, train nll=931.0908203125, val nll=926.3598177083334
saved!
Average test LL: 922.403
Bits per dimension: 1.6973811805326697
FINAL LOSS: nll=922.4034240234375


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.69738


[34m[1mwandb[0m: Agent Starting Run: se5ymi4s with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1182.025
Bits per dimension: 2.17512892342754
Epoch: 0, train nll=1166.64306640625, val nll=1182.0246328125
saved!
Average test LL: 1181.733
Bits per dimension: 2.174593169498002
FINAL LOSS: nll=1181.733489453125


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,2.17459


[34m[1mwandb[0m: Agent Starting Run: r8j83iab with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 3791.691
Bits per dimension: 6.9773648791767835
Epoch: 0, train nll=3798.048095703125, val nll=3791.691182291667
saved!
Average test LL: 3791.842
Bits per dimension: 6.977641748425968
FINAL LOSS: nll=3791.841640625


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,6.97764


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: rv0rvoh8 with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 945.279
Bits per dimension: 1.739476295567498
Epoch: 0, train nll=975.5731811523438, val nll=945.2790625
saved!
Average test LL: 947.771
Bits per dimension: 1.7440611055641877
FINAL LOSS: nll=947.770573828125


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.74406


[34m[1mwandb[0m: Agent Starting Run: e0i1c51u with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 1862.467
Bits per dimension: 3.4272598477365395
Epoch: 0, train nll=1880.457763671875, val nll=1862.4668723958334
saved!
Average test LL: 1864.549
Bits per dimension: 3.4310913750777194
FINAL LOSS: nll=1864.549029296875


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,3.43109


[34m[1mwandb[0m: Agent Starting Run: zythikhj with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 4080.765
Bits per dimension: 7.50930971925546
Epoch: 0, train nll=4083.606689453125, val nll=4080.764578125
saved!
Average test LL: 4080.957
Bits per dimension: 7.509663631820312
FINAL LOSS: nll=4080.95690390625


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,7.50966


[34m[1mwandb[0m: Agent Starting Run: e4i6twat with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.1
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


Average test LL: 947.009
Bits per dimension: 1.7426589836856603
Epoch: 0, train nll=943.08056640625, val nll=947.0086223958333
saved!
Average test LL: 950.055
Bits per dimension: 1.7482650049376698
FINAL LOSS: nll=950.055087890625


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_bpd,▁

0,1
test_bpd,1.74827


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: hu6avx3y with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	input_dim: 784
[34m[1mwandb[0m: 	lam: 0.5
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	max_patience: 30
[34m[1mwandb[0m: 	num_epochs: 1
[34m[1mwandb[0m: 	num_input_units: 8
[34m[1mwandb[0m: 	num_sum_units: 8


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [None]:
# TODO: Convert the hyperparameter optimization to a .py script, figure out how to parallelize it, and run it on Eddie.
# TODO: In the meantime, write up on Overleaf and figure out how to evaluate FID and how to sample from the PC.