# Region Graphs

This notebook shows how to create a region graph for scikit-learn's handwritten digits dataset (8×8 grayscale images, a small MNIST-like dataset) and train the resulting probabilistic circuit.
We start by defining the variable for the class label.

In [None]:
import jax
from random_events.set import SetElement

# Number of leading variables (and data columns) that are kept: the Digit label
# followed by the first `include_variables - 1` pixel variables. Both `variables`
# and `data` are truncated with this value in the cells below, so they stay the same width.
include_variables = 20

class Digit(SetElement):
    """Values of the symbolic class-label variable ``Digit``.

    ZERO..NINE are the integer labels 0-9 of the digits dataset
    (the values in ``digits.target``).
    """
    # NOTE(review): presumably the sentinel that random_events' SetElement
    # requires to represent the empty set -- confirm against the library.
    EMPTY_SET = -1
    ZERO = 0
    ONE = 1
    TWO = 2
    THREE = 3
    FOUR = 4
    FIVE = 5
    SIX = 6
    SEVEN = 7
    EIGHT = 8
    NINE = 9

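Then, we add one continuous variable per pixel of the 8×8 images. Only the first `include_variables` variables (the class label followed by the leading pixels) are kept.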

In [None]:
from sortedcontainers import SortedSet
from random_events.variable import Symbolic, Continuous

# The digits images are 8x8; there is one continuous variable per pixel, named Pixel_<row>_<col>.
image_size = 8
pixel_names = [f"Pixel_{i}_{j}" for i in range(image_size) for j in range(image_size)]

variables = SortedSet([Symbolic("Digit", Digit)] + [Continuous(name) for name in pixel_names])
# Keep only the label plus the first pixels; `data` is truncated the same way when it is loaded.
variables = variables[:include_variables]

# The columns of `data` are [label, pixels in row-major order]. SortedSet
# reorders the variables (presumably by name), and that order matches the data
# column order only while the row/column indices are single digits
# ("Pixel_10_0" sorts before "Pixel_2_0"). Fail loudly rather than silently
# pairing variables with the wrong columns, e.g. for a 28x28 grid.
expected_names = (["Digit"] + pixel_names)[:include_variables]
assert [variable.name for variable in variables] == expected_names, (
    "variable order does not match the column order of the data"
)

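Next, we load the dataset, place the class label in the first column followed by the pixels, and rescale the pixel values to $[-1, 1]$.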

In [None]:
from sklearn import datasets
import numpy as np
from sklearn.preprocessing import MinMaxScaler
digits = datasets.load_digits(as_frame=False)
x = digits.data
y = digits.target

# Column 0 holds the class label, followed by the pixel columns; truncate to the
# same number of columns as there are entries in `variables`.
data = np.column_stack((y, x))[:, :include_variables]

# Rescale each pixel column independently to [-1, 1]; the label column is untouched.
# NOTE(review): some pixel columns of this dataset may be constant (e.g. always 0);
# MinMaxScaler maps those to the lower bound -1. Confirm that the circuit's
# continuous leaves cope with zero-variance columns.
scaler = MinMaxScaler(feature_range=(-1, 1))
data[:, 1:] = scaler.fit_transform(data[:, 1:])

Now, we construct a random region graph over these variables and convert it into a probabilistic circuit.

In [None]:
from probabilistic_model.learning.region_graph.region_graph import RegionGraph

import random

# Seed the global RNGs so that the random region graph is reproducible under
# Restart Kernel -> Run All.
# NOTE(review): assumes RegionGraph draws its randomness from Python's `random`
# and/or NumPy's global generator -- confirm in probabilistic_model.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

region_graph = RegionGraph(variables, repetitions=6, depth=3, partitions=2)
region_graph = region_graph.create_random_region_graph()
model = region_graph.as_probabilistic_circuit(input_units=16, sum_units=5)

Let's have a look at the structure of the resulting circuit.

In [None]:
from matplotlib import pyplot as plt

# Convert the circuit to its networkx-based representation and draw its structure.
nx_model = model.to_nx()
nx_model.plot_structure()
# Title the figure so it stands on its own when the notebook is skimmed.
# NOTE(review): assumes plot_structure draws onto the current matplotlib axes.
plt.title("Structure of the region-graph circuit (before training)")
plt.show()

In [None]:
# Number of trainable parameters reported by the circuit's root
# (presumably counting the whole circuit below it).
trainable_parameter_count = model.root.number_of_trainable_parameters
print(trainable_parameter_count)

In [None]:
import optax

# Train the circuit's parameters on `data` with the Adam optimizer.
EPOCHS = 100
LEARNING_RATE = 1e-3
adam = optax.adam(LEARNING_RATE)
model.fit(data, epochs=EPOCHS, optimizer=adam)



In [None]:
# Rebuild the networkx-based representation of the circuit after training.
# NOTE(review): the result is not displayed or plotted in the visible part of this
# cell; if the intent is to inspect the trained circuit, a follow-up call such as
# `nx_model.plot_structure()` may be missing.
nx_model = model.to_nx()