In [1]:
import sys
import numpy as np
from pathlib import Path
import torchopt
import posteriors

# Make the notebook directory, its parent, and the parent's "baselines" folder
# importable so that the project imports that follow can be resolved.
current_dir = Path.cwd()
for _import_root in (current_dir, current_dir.parent, current_dir.parent / "baselines"):
    # Append only entries that are missing: re-running this cell must not keep
    # growing sys.path with duplicate entries.
    if str(_import_root) not in sys.path:
        sys.path.append(str(_import_root))


from src.nanogpt_utils import load_model, load_tokenizer, encode, decode
from src.bayesian_utils import create_training_batches, run_bayesian_pipeline
from config import CONFIG, MODEL_PATH, META_PATH, DATA_DIR

  from optree.integration.torch import tree_ravel


Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
number of parameters: 10.65M
Model loaded successfully!
Number of parameters: 10,745,088


#### Load pre-trained baseline model

In [2]:
# Restore the pre-trained baseline model; load_model also returns a second
# value that is kept here as `checkpoint`.
model, checkpoint = load_model(Path(MODEL_PATH))

# Read the tokenizer tables from the meta file; the vocabulary size is taken
# as the number of entries in `itos`.
stoi, itos = load_tokenizer(Path(META_PATH))
vocab_size = len(itos)

# Gather the model's named parameters into a name -> parameter mapping
# (the form used for posteriors).
params = {name: param for name, param in model.named_parameters()}

Model arguments: {'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'block_size': 256, 'bias': False, 'vocab_size': 65, 'dropout': 0.2}
number of parameters: 10.65M
Model loaded successfully!
Number of parameters: 10,745,088


#### Prepare training data

In [3]:
# Memory-map train.bin as a read-only uint16 array instead of loading it eagerly
train_data_path = Path(DATA_DIR / 'train.bin')
data = np.memmap(str(train_data_path), dtype=np.uint16, mode='r')

# Build the batches for Bayesian inference; the three size settings are read
# from CONFIG in the positional order create_training_batches is called with.
training_batches = create_training_batches(
    data,
    *[CONFIG[key] for key in ('batch_size', 'max_seq_length', 'train_samples')],
)

# Report how many batches were built and the shapes inside the first one
first_batch = training_batches[0]
print(f"Created {len(training_batches)} training batches")
print(f"Batch shape: {first_batch[0].shape}")
print(f"Target shape: {first_batch[1].shape}")

# Dataset size intended for posteriors, read straight from the config.
# NOTE(review): the recorded output shows 32 batches whose first batch has 16
# rows while this value is 500 — confirm the final batch is short (4 rows)
# rather than the batches adding up to 512 samples.
num_data = CONFIG['train_samples']
print(f"Total training samples: {num_data}")
    

Created 32 training batches
Batch shape: torch.Size([16, 128])
Target shape: torch.Size([16, 1])
Total training samples: 500


  x_tensor = torch.tensor(x_batch, dtype=torch.long, device=DEVICE)


#### Set up and run Variational Inference

In [4]:
# Run the project's Bayesian pipeline on the prepared batches with the 'vi' method.
# The recorded output of this cell shows that this single call initialises W&B,
# trains for 3 epochs x 32 batches (about 668 s in total) and then runs an
# evaluation step, so expect it to take several minutes.
# NOTE(review): the three return values are presumably the fitted VI state, the
# training metrics and the evaluation results — confirm in src.bayesian_utils.
state_vi, metrics_vi, eval_vi = run_bayesian_pipeline(training_batches, 'vi')

[34m[1mwandb[0m: Currently logged in as: [33mssophi-nikol[0m ([33mssophi-nikol-tum-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✓ W&B initialized: https://wandb.ai/ssophi-nikol-tum-ai/bayesian-nanogpt/runs/77srvth6

Setting up VI sampler
✓ VI configured with:
  - Learning rate: 5e-06
  - Temperature: 0.001
  - Samples per update: 1

Starting Bayesian Training with VI
Configuration:
  - Epochs: 3
  - Batches per epoch: 32
  - Total iterations: 96

Epoch 1/3
------------------------------------------------------------
NLL computed successfully: 62.042274475097656
Log prior computed successfully: -15255827.0
Log posterior computed successfully: -30573.697265625
NLL computed successfully: 0.9336455464363098
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.126953125
NLL computed successfully: 58.5260009765625
Log prior computed successfully: -15254245.0
Log posterior computed successfully: -30567.015625
NLL computed successfully: 1.0173876285552979
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.208984375
NLL computed successfully: 65.170

wandb-core(34848) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.0370897054672241
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.23046875
NLL computed successfully: 72.14247131347656
Log prior computed successfully: -15254542.0
Log posterior computed successfully: -30581.2265625


wandb-core(34884) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9287899732589722
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.123046875
NLL computed successfully: 72.15858459472656
Log prior computed successfully: -15251273.0
Log posterior computed successfully: -30574.705078125
NLL computed successfully: 1.748432993888855
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.94140625


wandb-core(34931) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 73.05329895019531
Log prior computed successfully: -15251471.0
Log posterior computed successfully: -30575.994140625
NLL computed successfully: 1.5459648370742798
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.740234375
NLL computed successfully: 84.74433135986328
Log prior computed successfully: -15251775.0
Log posterior computed successfully: -30588.294921875


wandb-core(34963) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9026210904121399
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.09375
  [ 30/32] Loss: 1.3149 | Log Post: -19766.4258
NLL computed successfully: 64.98748779296875
Log prior computed successfully: -15261370.0
Log posterior computed successfully: -30587.728515625


wandb-core(34986) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.0274252891540527
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.220703125
NLL computed successfully: 85.10589599609375
Log prior computed successfully: -15250609.0
Log posterior computed successfully: -30586.32421875


wandb-core(35033) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.7205038070678711
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.9140625
  ✓ Epoch Complete:
    Loss: 1.2392 | Log Post: -19766.3904 | Time: 245.04s

Epoch 2/3
------------------------------------------------------------
NLL computed successfully: 75.9001693725586
Log prior computed successfully: -15251218.0
Log posterior computed successfully: -30578.3359375


wandb-core(35070) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.0808196067810059
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.271484375
NLL computed successfully: 67.10269927978516
Log prior computed successfully: -15254685.0
Log posterior computed successfully: -30576.47265625
NLL computed successfully: 1.1044461727142334
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.296875
NLL computed successfully: 64.61165618896484
Log prior computed successfully: -15256072.0
Log posterior computed successfully: -30576.755859375
NLL computed successfully: 2.1443278789520264
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19767.337890625
NLL computed successfully: 61.96792984008789
Log prior computed successfully: -15256933.0
Log posterior computed successfully: -30575.833984375


wandb-core(35110) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.6193203926086426
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.8125
NLL computed successfully: 70.65680694580078
Log prior computed successfully: -15257719.0
Log posterior computed successfully: -30586.09375
NLL computed successfully: 1.6939077377319336
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.88671875
NLL computed successfully: 75.22013092041016
Log prior computed successfully: -15256784.0
Log posterior computed successfully: -30588.7890625
NLL computed successfully: 1.8016141653060913
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.994140625
  [  6/32] Loss: 1.6478 | Log Post: -19766.8656


wandb-core(35126) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 63.21418762207031
Log prior computed successfully: -15254910.0
Log posterior computed successfully: -30573.03515625
NLL computed successfully: 1.0721336603164673
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.265625
NLL computed successfully: 74.55235290527344
Log prior computed successfully: -15255206.0
Log posterior computed successfully: -30584.96484375
NLL computed successfully: 0.785775363445282
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.978515625
NLL computed successfully: 50.98648452758789
Log prior computed successfully: -15257997.0
Log posterior computed successfully: -30566.98046875


wandb-core(35158) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.5339596271514893
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.7265625
NLL computed successfully: 74.6753921508789
Log prior computed successfully: -15253374.0
Log posterior computed successfully: -30581.423828125


wandb-core(35200) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9997040629386902
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.193359375
NLL computed successfully: 51.18925094604492
Log prior computed successfully: -15253794.0
Log posterior computed successfully: -30558.77734375


wandb-core(35244) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.16982901096344
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.36328125
NLL computed successfully: 70.37142944335938
Log prior computed successfully: -15252376.0
Log posterior computed successfully: -30575.123046875
NLL computed successfully: 1.2224189043045044
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.416015625
  [ 12/32] Loss: 1.1365 | Log Post: -19766.3355
NLL computed successfully: 61.085941314697266
Log prior computed successfully: -15256661.0
Log posterior computed successfully: -30574.408203125


wandb-core(35269) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.358454704284668
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.552734375
NLL computed successfully: 81.8223648071289
Log prior computed successfully: -15257303.0
Log posterior computed successfully: -30596.427734375
NLL computed successfully: 1.7152122259140015
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.908203125
NLL computed successfully: 81.34150695800781
Log prior computed successfully: -15253277.0
Log posterior computed successfully: -30587.896484375
NLL computed successfully: 0.9167090654373169
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.109375
NLL computed successfully: 92.7585678100586
Log prior computed successfully: -15255989.0
Log posterior computed successfully: -30604.736328125


wandb-core(35309) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.1454998254776
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.337890625
NLL computed successfully: 74.09768676757812
Log prior computed successfully: -15253818.0
Log posterior computed successfully: -30581.734375
NLL computed successfully: 0.7582829594612122
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.951171875
NLL computed successfully: 79.14362335205078
Log prior computed successfully: -15259500.0
Log posterior computed successfully: -30598.14453125
NLL computed successfully: 1.380035161972046
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.57421875
  [ 18/32] Loss: 1.2138 | Log Post: -19766.3762
NLL computed successfully: 67.61595153808594
Log prior computed successfully: -15256233.0
Log posterior computed successfully: -30580.08203125
NLL computed successfully: 1.005818486213684
Log prior computed successfully: -9882597.0
Log poste

wandb-core(35334) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9995135068893433
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.193359375
NLL computed successfully: 74.50971221923828
Log prior computed successfully: -15257263.0
Log posterior computed successfully: -30589.03515625
NLL computed successfully: 0.7760564088821411
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.96875
NLL computed successfully: 75.22119140625
Log prior computed successfully: -15253700.0
Log posterior computed successfully: -30582.62109375
NLL computed successfully: 0.7494469285011292
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.943359375
NLL computed successfully: 75.26235961914062
Log prior computed successfully: -15252231.0
Log posterior computed successfully: -30579.724609375


wandb-core(35389) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.4449684619903564
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.638671875
NLL computed successfully: 57.9121208190918
Log prior computed successfully: -15253052.0
Log posterior computed successfully: -30564.015625
NLL computed successfully: 1.1306871175765991
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.32421875
  [ 24/32] Loss: 0.9941 | Log Post: -19766.2137
NLL computed successfully: 45.87874221801758
Log prior computed successfully: -15253258.0
Log posterior computed successfully: -30552.39453125
NLL computed successfully: 1.1819844245910645
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.375
NLL computed successfully: 68.13577270507812
Log prior computed successfully: -15258471.0
Log posterior computed successfully: -30585.078125


wandb-core(35408) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.0358941555023193
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.228515625
NLL computed successfully: 61.363197326660156
Log prior computed successfully: -15255101.0
Log posterior computed successfully: -30571.564453125
NLL computed successfully: 0.9576789140701294
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.150390625
NLL computed successfully: 66.87063598632812
Log prior computed successfully: -15258404.0
Log posterior computed successfully: -30583.6796875
NLL computed successfully: 1.7097022533416748
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.90234375
NLL computed successfully: 73.54645538330078
Log prior computed successfully: -15251633.0
Log posterior computed successfully: -30576.8125
NLL computed successfully: 1.407309651374817
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.6015625


wandb-core(35450) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 79.47528839111328
Log prior computed successfully: -15255183.0
Log posterior computed successfully: -30589.83984375
NLL computed successfully: 1.0055198669433594
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.19921875
  [ 30/32] Loss: 1.2531 | Log Post: -19766.4164
NLL computed successfully: 61.7320556640625
Log prior computed successfully: -15251533.0
Log posterior computed successfully: -30564.798828125
NLL computed successfully: 0.9667180180549622
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.16015625
NLL computed successfully: 80.45545196533203
Log prior computed successfully: -15255876.0
Log posterior computed successfully: -30592.20703125
NLL computed successfully: 0.5991674065589905
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.79296875
  ✓ Epoch Complete:
    Loss: 1.2163 | Log Post: -19766.3954 | Time: 185.04s

Epoch 3/3
------

wandb-core(35464) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 65.19196319580078
Log prior computed successfully: -15255369.0
Log posterior computed successfully: -30575.9296875
NLL computed successfully: 1.0109922885894775
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.205078125
NLL computed successfully: 79.09197998046875
Log prior computed successfully: -15253409.0
Log posterior computed successfully: -30585.91015625


wandb-core(35536) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 2.1679162979125977
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19767.361328125
NLL computed successfully: 64.98943328857422
Log prior computed successfully: -15257201.0
Log posterior computed successfully: -30579.392578125


wandb-core(35563) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.6131513118743896
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.8046875
NLL computed successfully: 76.13050842285156
Log prior computed successfully: -15255642.0
Log posterior computed successfully: -30587.4140625
NLL computed successfully: 1.7149986028671265
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.90625
NLL computed successfully: 83.43498992919922
Log prior computed successfully: -15254188.0
Log posterior computed successfully: -30591.8125


wandb-core(35604) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.6947252750396729
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.888671875
  [  6/32] Loss: 1.6672 | Log Post: -19766.8332
NLL computed successfully: 67.42707824707031
Log prior computed successfully: -15257154.0
Log posterior computed successfully: -30581.736328125
NLL computed successfully: 1.2426868677139282
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.435546875


wandb-core(35632) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 66.14804077148438
Log prior computed successfully: -15255184.0
Log posterior computed successfully: -30576.515625


wandb-core(35676) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.8988552093505859
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.091796875
NLL computed successfully: 59.701255798339844
Log prior computed successfully: -15253192.0
Log posterior computed successfully: -30566.0859375
NLL computed successfully: 1.4619450569152832
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.65625


wandb-core(35727) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 72.80996704101562
Log prior computed successfully: -15256352.0
Log posterior computed successfully: -30585.513671875
NLL computed successfully: 1.07167387008667
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.265625
NLL computed successfully: 69.14420318603516
Log prior computed successfully: -15252881.0
Log posterior computed successfully: -30574.90625
NLL computed successfully: 1.237592101097107
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.431640625


wandb-core(35769) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 74.4719467163086
Log prior computed successfully: -15255405.0
Log posterior computed successfully: -30585.283203125
NLL computed successfully: 1.0808414220809937
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.2734375
  [ 12/32] Loss: 1.1848 | Log Post: -19766.3438
NLL computed successfully: 47.45212173461914
Log prior computed successfully: -15254160.0
Log posterior computed successfully: -30555.771484375
NLL computed successfully: 1.5070172548294067
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.701171875


wandb-core(35794) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 78.50423431396484
Log prior computed successfully: -15253459.0
Log posterior computed successfully: -30585.421875
NLL computed successfully: 1.6842328310012817
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.876953125
NLL computed successfully: 73.43408966064453
Log prior computed successfully: -15250014.0
Log posterior computed successfully: -30573.4609375
NLL computed successfully: 1.0029313564300537
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.197265625


wandb-core(35829) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 60.74522018432617
Log prior computed successfully: -15254543.0
Log posterior computed successfully: -30569.83203125
NLL computed successfully: 1.302565097808838
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.49609375
NLL computed successfully: 71.66082000732422
Log prior computed successfully: -15253747.0
Log posterior computed successfully: -30579.154296875


wandb-core(35871) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.8103242516517639
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.00390625
NLL computed successfully: 71.21004486083984
Log prior computed successfully: -15254346.0
Log posterior computed successfully: -30579.90234375
NLL computed successfully: 1.3454995155334473
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.5390625
  [ 18/32] Loss: 1.1668 | Log Post: -19766.4227
NLL computed successfully: 56.1676025390625
Log prior computed successfully: -15252970.0
Log posterior computed successfully: -30562.107421875


wandb-core(35925) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 1.0712834596633911
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.263671875
NLL computed successfully: 83.09795379638672
Log prior computed successfully: -15253262.0
Log posterior computed successfully: -30589.62109375
NLL computed successfully: 1.194716215133667
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.388671875
NLL computed successfully: 83.68022918701172
Log prior computed successfully: -15251201.0
Log posterior computed successfully: -30586.08203125


wandb-core(35953) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9190771579742432
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.11328125
NLL computed successfully: 65.40257263183594
Log prior computed successfully: -15253064.0
Log posterior computed successfully: -30571.53125
NLL computed successfully: 0.7222734093666077
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.916015625
NLL computed successfully: 86.01776123046875
Log prior computed successfully: -15258938.0
Log posterior computed successfully: -30603.89453125
NLL computed successfully: 1.240797519683838
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.431640625


wandb-core(35979) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 57.33997344970703
Log prior computed successfully: -15252050.0
Log posterior computed successfully: -30561.439453125
NLL computed successfully: 1.0940619707107544
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.287109375
  [ 24/32] Loss: 0.9886 | Log Post: -19766.2273
NLL computed successfully: 84.73593139648438
Log prior computed successfully: -15256317.0
Log posterior computed successfully: -30597.37109375
NLL computed successfully: 1.0721856355667114
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.265625
NLL computed successfully: 74.73538208007812
Log prior computed successfully: -15255039.0
Log posterior computed successfully: -30584.814453125
NLL computed successfully: 1.0814028978347778
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.275390625
NLL computed successfully: 62.78050994873047
Log prior computed successfully: -15250434.0
Lo

wandb-core(35997) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9859414100646973
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.1796875
NLL computed successfully: 80.75811767578125
Log prior computed successfully: -15252901.0
Log posterior computed successfully: -30586.560546875
NLL computed successfully: 1.81009840965271
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19767.00390625
NLL computed successfully: 77.01741027832031
Log prior computed successfully: -15251786.0
Log posterior computed successfully: -30580.58984375
NLL computed successfully: 1.4407604932785034
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.6328125
NLL computed successfully: 91.06924438476562
Log prior computed successfully: -15254015.0
Log posterior computed successfully: -30599.09765625


wandb-core(36028) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


NLL computed successfully: 0.9465652108192444
Log prior computed successfully: -9882596.0
Log posterior computed successfully: -19766.138671875
  [ 30/32] Loss: 1.2504 | Log Post: -19766.4461
NLL computed successfully: 62.178382873535156
Log prior computed successfully: -15251051.0
Log posterior computed successfully: -30564.279296875
NLL computed successfully: 0.9634807705879211
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19766.15625
NLL computed successfully: 62.208953857421875
Log prior computed successfully: -15256130.0
Log posterior computed successfully: -30574.46875
NLL computed successfully: 0.5427213907241821
Log prior computed successfully: -9882597.0
Log posterior computed successfully: -19765.736328125
  ✓ Epoch Complete:
    Loss: 1.2237 | Log Post: -19766.4084 | Time: 237.77s

Training Complete!
  Final Loss: 1.2237
  Final Log Posterior: -19766.4084
  Total Time: 667.89s


Evaluating VI Predictions

0. Deterministic Model (Original):

Python(36044) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


✓ Model logged to W&B as artifact



0,1
batch_log_posterior,▆▆▃▄▃▄▅▆▇▇▆▆▆▃▇▄▃▅▇▄▆▇▅▃▆▁▃▃▅▆▆▆▆▆▇▆▆▆▃█
batch_loss,▃▂█▅▆▃▃▃▄▂▄▃▃▂▅▂▁▇▅▆▄▆▃▄▃▂▂▃▂▆▆▃▅▅▂▁▂▂▂▃
epoch,▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█████████████
epoch_epoch,▁▅█
epoch_log_posterior,█▆▁
epoch_loss,█▁▃
epoch_time,█▁▇
eval/avg_predictive_entropy,▁
eval/deterministic_loss,▁
eval/deterministic_perplexity,▁

0,1
batch_log_posterior,-19765.73633
batch_loss,1.10841
epoch,3
epoch_epoch,2
epoch_log_posterior,-19766.40845
epoch_loss,1.2237
epoch_time,237.77182
eval/avg_predictive_entropy,1.70291
eval/deterministic_loss,1.01713
eval/deterministic_perplexity,2.76524


✓ W&B run finished


#### Run Bayesian Training

#### Compare Deterministic vs. Bayesian Predictions