In [1]:
import pyjuice as juice
import torch
import torchvision
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

    
def main():
    """Build an HCLT circuit over MNIST pixels and print samples drawn from it.

    Steps:
      1. Load the MNIST training images (downloading them to
         ``./examples/data`` if needed) and flatten each image to one row.
      2. Build an HCLT structure from the flattened data with pyjuice.
      3. Initialize the structure's parameters and compile it into a
         ``TensorCircuit`` placed on the GPU.
      4. Draw 500 samples and print them together with their shape.

    No parameter learning is performed in this function, so the samples are
    drawn from the circuit as it stands right after ``init_parameters``.

    Requires a CUDA device (``cuda:0``).
    """
    device = torch.device("cuda:0")

    train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)

    # One row per image; sizes are inferred from the tensor rather than
    # hard-coded (for MNIST this yields shape (60000, 784)).
    train_data = train_dataset.data.flatten(start_dim = 1)

    # The structure is derived from the training data itself, which must be
    # floating point and on the target device for this call.
    ns = juice.structures.HCLT(
        train_data.float().to(device),
        num_bins = 32,
        sigma = 0.5 / 32,
        num_latents = 128,
        chunk_size = 32
    )
    ns.init_parameters(perturbation = 2.0)

    # Compile the node structure into an executable circuit and move it to the GPU.
    pc = juice.TensorCircuit(ns)
    pc.to(device)

    samples = juice.queries.sample(pc, num_samples = 500)

    print(samples)
    print(samples.shape)


# Entry point: run the sampling demo when this cell/script is executed directly
# (not when the code is imported as a module).
if __name__ == "__main__":
    main()

Compiling 51 TensorCircuit layers...


100%|██████████| 51/51 [00:08<00:00,  5.86it/s]


tensor([[ 14, 124,  58,  ...,  64, 194, 113],
        [223,  89,  86,  ...,  82, 110, 164],
        [209,  88, 186,  ...,  28,   7, 194],
        ...,
        [ 11, 193, 210,  ..., 103, 232,  26],
        [251, 255,  45,  ..., 107, 210, 233],
        [229, 212, 216,  ..., 175,  35, 231]], device='cuda:0')
torch.Size([500, 784])
