In [1]:
import torch
from tqdm import tqdm

import gfn

from gfn.gflownet import TBGFlowNet  # We use a GFlowNet with the Trajectory Balance (TB) loss
from gfn.gym import HyperGrid  # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet  # NeuralNet is a simple multi-layer perceptron (MLP)

In [2]:
# 0 - We seed torch's global RNG before any network is created, so that a
# "Restart Kernel -> Run All" starts from the same random state every time.
# NOTE(review): this assumes the gfn networks and sampler draw their randomness
# from torch's global generator - confirm against the installed gfn version.
SEED = 0
torch.manual_seed(SEED)

# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,
)  # Neural network for the forward policy, with as many outputs as there are actions
module_PB = NeuralNet(
    input_dim=env.preprocessor.output_dim,
    # One output fewer than the forward policy.
    # NOTE(review): presumably because the backward policy has no exit action - confirm.
    output_dim=env.n_actions - 1,
    torso=module_PF.torso,  # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

# 4 - We define the GFlowNet.
# NOTE: this rebinds the name `gfn`, which the import cell bound to the `gfn` package;
# once this cell has run, `gfn` refers to the TBGFlowNet instance, not the package.
gfn = TBGFlowNet(off_policy=False, init_logZ=0., pf=pf_estimator, pb=pb_estimator)  # We initialize logZ to 0

# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy

# Materialise the name -> parameter mapping once; both parameter groups are built from it.
named_params = dict(gfn.named_parameters())

# Policy parameters have their own LR.
non_logz_params = [param for name, param in named_params.items() if name != "logZ"]
optimizer = torch.optim.Adam(non_logz_params, lr=1e-3)

# Log Z gets dedicated learning rate (typically higher).
logz_params = [named_params["logZ"]]
optimizer.add_param_group({"params": logz_params, "lr": 1e-1})

# 6 - We train the GFlowNet for N_ITERATIONS iterations, with BATCH_SIZE trajectories per iteration
N_ITERATIONS = 1000  # number of optimizer steps
BATCH_SIZE = 16  # trajectories sampled on-policy at every step
LOG_EVERY = 25  # the loss shown next to the progress bar is refreshed every LOG_EVERY steps

for i in (pbar := tqdm(range(N_ITERATIONS))):
    trajectories = sampler.sample_trajectories(env=env, off_policy=False, n_trajectories=BATCH_SIZE)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    # Also report on the very last step: otherwise the finished bar keeps showing
    # the loss from the last multiple of LOG_EVERY instead of the final one.
    if i % LOG_EVERY == 0 or i == N_ITERATIONS - 1:
        pbar.set_postfix({"loss": loss.item()})


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:58<00:00, 17.14it/s, loss=0.123]


In [6]:
# Draw 20 trajectories with the sampler that wraps the forward-policy estimator.
# The call is the cell's last expression, so the returned object's repr is
# displayed as the cell output.
sampler.sample_trajectories(env=env, off_policy=False, n_trajectories=20)

Trajectories(n_trajectories=20, max_length=27, First 10 trajectories:states=
[0 0 0 0]-> [0 0 0 1]-> [0 0 1 1]-> [0 0 1 2]-> [0 0 2 2]-> [1 0 2 2]-> [1 0 2 3]-> [2 0 2 3]-> [2 0 2 4]-> [2 0 3 4]-> [3 0 3 4]-> [3 0 4 4]-> [3 0 5 4]-> [3 1 5 4]-> [3 1 6 4]-> [3 1 7 4]-> [3 2 7 4]-> [3 2 7 5]-> [-1 -1 -1 -1]
[0 0 0 0]-> [0 0 0 1]-> [0 0 0 2]-> [0 0 0 3]-> [0 0 0 4]-> [0 0 0 5]-> [0 0 0 6]-> [0 0 1 6]-> [0 0 1 7]-> [0 0 2 7]-> [0 1 2 7]-> [0 1 3 7]-> [0 1 4 7]-> [0 1 5 7]-> [0 2 5 7]-> [0 3 5 7]-> [0 4 5 7]-> [1 4 5 7]-> [2 4 5 7]-> [2 5 5 7]-> [-1 -1 -1 -1]
[0 0 0 0]-> [1 0 0 0]-> [1 0 1 0]-> [1 0 2 0]-> [1 0 3 0]-> [2 0 3 0]-> [3 0 3 0]-> [3 0 4 0]-> [4 0 4 0]-> [5 0 4 0]-> [5 1 4 0]-> [5 1 5 0]-> [6 1 5 0]-> [7 1 5 0]-> [7 1 6 0]-> [-1 -1 -1 -1]
[0 0 0 0]-> [1 0 0 0]-> [1 1 0 0]-> [2 1 0 0]-> [3 1 0 0]-> [4 1 0 0]-> [5 1 0 0]-> [6 1 0 0]-> [7 1 0 0]-> [7 2 0 0]-> [7 3 0 0]-> [7 4 0 0]-> [7 4 0 1]-> [7 5 0 1]-> [7 6 0 1]-> [-1 -1 -1 -1]
[0 0 0 0]-> [0 0 1 0]-> [0 0 2 0]-> [0 0 3 0]-> [0 