In [4]:
import sys, os
sys.path.insert(0, os.path.abspath(".."))

import numpy as np
from core.uncertain_array import UncertainArray
from graph.wave import Wave


class DummyFactor:
    """Minimal stand-in for a factor node, used to exercise Wave message passing.

    Each instance remembers the most recent message delivered to it and prints
    a one-line trace, so the notebook output shows which factor was reached.
    """

    def __init__(self, name):
        # Label shown in the trace line printed by receive_message.
        self.name = name
        # Most recent message handed to receive_message; None until one arrives.
        self.last_received_message = None

    def receive_message(self, wave, message):
        """Print a trace line for ``message`` and store it as the latest message.

        Args:
            wave: The sender (presumably the Wave). It is accepted to match the
                factor interface but is not used here.
            message: The incoming message; only its ``.shape`` is read.
        """
        trace = "{} received message from Wave: shape={}".format(self.name, message.shape)
        print(trace)
        self.last_received_message = message

# Fixed seed so the random messages (and therefore the belief data) are
# reproducible under Restart & Run All.
# NOTE(review): this only takes effect if UncertainArray.random draws from
# NumPy's global RNG — confirm against core.uncertain_array.
SEED = 0
np.random.seed(SEED)

# Single source of truth for the array shape shared by the Wave and its messages.
SHAPE = (2, 2)

# Create a Wave holding complex128 values of shape SHAPE.
wave = Wave(shape=SHAPE, dtype=np.complex128)

# Dummy factors standing in for the wave's parent and two children.
parent = DummyFactor("ParentFactor")
child1 = DummyFactor("ChildFactor1")
child2 = DummyFactor("ChildFactor2")

# Random UncertainArray messages with distinct scalar precisions.
# NOTE(review): dtype is left at UncertainArray.random's default; confirm it
# matches the Wave's complex128.
msg_from_parent = UncertainArray.random(SHAPE, precision=2.0)
msg_from_child1 = UncertainArray.random(SHAPE, precision=1.0)
msg_from_child2 = UncertainArray.random(SHAPE, precision=1.5)

# Connect parent and children to the wave.
wave.set_parent(parent)
wave.add_child(child1)
wave.add_child(child2)

In [5]:
# Deliver each neighbour's message to the wave.
wave.receive_message(parent, msg_from_parent)
wave.receive_message(child1, msg_from_child1)
wave.receive_message(child2, msg_from_child2)

# Compute the belief from the received messages.
belief = wave.compute_belief()

# Inspect the result.
print("Computed belief:")
print("Shape:", belief.shape)
print("Precision (as array):")
print(belief.precision)

# Sanity checks: overall shape and numerical validity.
assert belief.shape == (2, 2), "Belief shape mismatch"
assert np.all(belief.precision > 0), "Belief precision must be positive"

# Consistency check: the belief should match UncertainArray.combine over the
# same messages. np.testing.assert_allclose reports the offending elements and
# rejects non-scalar shape mismatches, unlike a bare `assert np.allclose(...)`.
# rtol/atol are pinned to np.allclose's defaults so the tolerance is unchanged.
incoming = [msg_from_parent, msg_from_child1, msg_from_child2]
expected = UncertainArray.combine(incoming)
np.testing.assert_allclose(
    belief.data, expected.data, rtol=1e-5, atol=1e-8, err_msg="Belief data mismatch"
)
np.testing.assert_allclose(
    belief.precision, expected.precision, rtol=1e-5, atol=1e-8,
    err_msg="Belief precision mismatch",
)

# Independent check that does not route through combine (which compute_belief
# may itself call): fusing the messages should add their precisions
# (here 2.0 + 1.0 + 1.5 = 4.5).
expected_precision = sum(msg.precision for msg in incoming)
np.testing.assert_allclose(
    belief.precision, expected_precision, rtol=1e-5, atol=1e-8,
    err_msg="Belief precision is not the sum of incoming precisions",
)

print("Belief computation test passed.")


Computed belief:
Shape: (2, 2)
Precision (as array):
[[4.5 4.5]
 [4.5 4.5]]
Belief computation test passed.
