In [4]:
import torch

import tensorly as tl
from tensorly.cp_tensor import cp_to_tensor

from src.environments import MarkovChainSimulator
from src.utils import ProbabilityTensorEstimator, reconstruct_transition_tensor

In [5]:
# Make tensorly operate on torch tensors so cp_to_tensor accepts and returns
# the torch tensors built in the cells below.
tl.set_backend('pytorch')

# Seed the global torch RNG so "Restart & Run All" reproduces every torch.rand
# draw in this notebook. NOTE(review): this only covers torch; if
# MarkovChainSimulator samples with numpy or `random`, seed those too — TODO confirm.
SEED = 0
torch.manual_seed(SEED);

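# 1) Example: dense random transition tensor

Sample trajectories from an unstructured random transition tensor, fit the empirical estimator, and check that its transition tensor matches the one reconstructed from its estimated joint and origin-marginal distributions.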

In [9]:
# Dense example: a 10 x 10 x 10 state space with unstructured random transitions.
d1 = d2 = d3 = 10

# Random 6-way tensor; the first three axes index the current state and the
# last three the next state. Dividing by the sum over the destination axes
# makes every origin state's outgoing slice sum to 1.
transition_probs = torch.rand((d1, d2, d3) * 2)
transition_probs = transition_probs.div(transition_probs.sum(dim=(3, 4, 5), keepdim=True))

mc_simulator = MarkovChainSimulator(transition_probs)

# Draw independent trajectories from the chain.
num_trajectories, num_steps = 100, 100
trajectories = [mc_simulator.simulate(num_steps) for _traj in range(num_trajectories)]

# Fit the empirical estimator and pull out the three estimated distributions
# (getter order kept as-is).
estimator = ProbabilityTensorEstimator((d1, d2, d3))
estimator.fit(trajectories)
transition_tensor = estimator.get_transition_tensor()
marginal_tensor = estimator.get_marginal_origin_state()
joint_tensor = estimator.get_joint_distribution()

# Consistency check: the estimator's transition tensor must match the one
# rebuilt from its own joint and origin-marginal estimates
# (reconstruct_transition_tensor presumably computes joint / marginal — TODO confirm).
recomputed_transition_tensor = reconstruct_transition_tensor(
    joint_tensor, marginal_tensor, (d1, d2, d3)
)
assert torch.allclose(transition_tensor, recomputed_transition_tensor, atol=1e-6)

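# 2) Example: low-rank (CP) tensor

Build a rank-$K$ CP model of the joint distribution over $2D$ modes (the first $D$ for the origin state, the last $D$ for the destination state), derive the transition tensor from it, simulate trajectories, and re-estimate the distributions from the samples.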

In [11]:
def generate_probability_matrix(dims, K):
    """Draw a random ``(dims, K)`` matrix whose columns are probability vectors.

    Entries are sampled uniformly from [0, 1) with ``torch.rand``. Each column
    is then divided by its own total, so every column is non-negative and sums
    to 1 (up to float rounding).

    Args:
        dims: number of rows (an int — the size of the mode this factor covers).
        K: number of columns (used as the CP rank by ``generate_tensor``).

    Returns:
        torch.Tensor of shape ``(dims, K)``.
    """
    raw = torch.rand(dims, K)
    column_totals = raw.sum(dim=0, keepdim=True)
    return raw / column_totals

def generate_probability_vector(D):
    """Draw a random length-``D`` probability vector.

    Entries are sampled uniformly from [0, 1) with ``torch.rand`` and divided
    by their total, so the result is non-negative and sums to 1 (up to float
    rounding).

    Args:
        D: length of the vector (used as the CP rank by ``generate_tensor``).

    Returns:
        1-D torch.Tensor of length ``D``.
    """
    draws = torch.rand(D)
    return draws.div(draws.sum())

def generate_tensor(D, K, d):
    """Sample the factors and weights of a random rank-``K`` nonnegative CP model.

    Every factor matrix has columns summing to 1 and the weight vector sums
    to 1. The dense tensor ``cp_to_tensor((weights, factors))`` is therefore a
    probability tensor whose entries sum to 1.

    Args:
        D: number of modes (one factor matrix per mode).
        K: CP rank (columns per factor, length of the weight vector).
        d: size of every mode (rows per factor).

    Returns:
        ``(factors, weights)``: a list of ``D`` tensors of shape ``(d, K)`` and
        a length-``K`` tensor.
    """
    # Factors are drawn before the weights; keep this order so seeded runs
    # reproduce the same draws.
    factors = []
    for _mode in range(D):
        factors.append(generate_probability_matrix(d, K))
    weights = generate_probability_vector(K)
    return factors, weights

In [19]:
# Low-rank setup: D modes per state, CP rank K, d values per mode.
D = 4
K = 3
d = 10
# NOTE(review): N is not used in this cell. It matches num_trajectories * num_steps
# below — TODO confirm whether a later cell needs it, otherwise drop it.
N = 10_000

# Shape of one (origin or destination) state, derived from D so that changing
# D does not require editing hard-coded (d, d, d, d) tuples.
state_shape = (d,) * D

# Rank-K CP model over 2*D modes: the first D modes index the origin state and
# the last D the destination state. Every factor column and the weight vector
# sum to 1, so P_joint is a probability tensor.
# NOTE: with d=10 and 2*D=8 modes, P_joint and P_conditional each hold
# d**(2*D) = 1e8 float32 entries (~400 MB apiece).
factors, weights = generate_tensor(2 * D, K, d)
P_joint = cp_to_tensor((weights, factors))

# Marginal over the origin state. Summing out the destination modes turns each
# of their (column-normalised) factors into ones, which leaves the SAME K
# weights with only the first D factors. The weight vector has length K, not D,
# so it must not be sliced by D (that only worked by accident when D >= K).
P_marginal = cp_to_tensor((weights, factors[:D]))

# Sanity checks on the ground-truth model before simulating from it.
assert abs(P_joint.sum().item() - 1.0) < 1e-4, "joint distribution must sum to 1"
assert torch.allclose(
    P_joint.sum(dim=tuple(range(D, 2 * D))), P_marginal, rtol=1e-4, atol=1e-8
), "P_marginal must equal P_joint summed over the destination modes"

P_conditional = reconstruct_transition_tensor(P_joint, P_marginal, state_shape)
mc_simulator = MarkovChainSimulator(P_conditional)

num_trajectories = 100
num_steps = 100
trajectories = [mc_simulator.simulate(num_steps) for _ in range(num_trajectories)]

# Re-estimate the three distributions from the sampled trajectories.
estimator = ProbabilityTensorEstimator(state_shape)
estimator.fit(trajectories)
P_conditional_sampled = estimator.get_transition_tensor()
P_marginal_sampled = estimator.get_marginal_origin_state()
P_joint_sampled = estimator.get_joint_distribution()