In [9]:
import os, sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
os.environ['PYTHONPATH'] = project_root
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("PYTHONPATH manually set to:", os.environ['PYTHONPATH'])

PYTHONPATH manually set to: c:\Users\ndhaj\Desktop\GraphicalPR


In [10]:
import numpy as np
from core.metrics import pmse
from core.linalg_utils import circular_aperture, random_unitary_matrix
from graph.structure.graph import Graph
from graph.prior.support_prior import SupportPrior
from graph.propagator.unitary_propagator import UnitaryPropagator
from graph.measurement.amplitude_measurement import AmplitudeMeasurement

# ==== 1. Parameters ====
n = 128                                # length of the active (supported) segment
shape = (3 * n,)                       # full signal length; the support is 1/3 of it
rng = np.random.default_rng(seed=123)  # seeds the random unitary U below
var = 1e-4                             # passed as `var=` to AmplitudeMeasurement (presumably noise variance)
# Derive the dimension from `shape` so U and the signal length cannot drift apart.
U = random_unitary_matrix(shape[0], rng=rng)

# Support mask: the middle third (indices n .. 2n-1 of 3n) is active.
support = np.zeros(shape, dtype=bool)
support[n:2 * n] = True

# ==== 2. Graph definition ====
class SimplePhaseGraph(Graph):
    """Graph model: SupportPrior X -> UnitaryPropagator Y -> AmplitudeMeasurement Z.

    With no arguments the graph is built from the module-level ``support``, ``U`` and
    ``var`` (looked up when the instance is created), exactly as before. Pass explicit
    values to build a variant without mutating those globals.
    """

    def __init__(self, support_mask=None, unitary=None, noise_var=None, damping=0.1):
        """Build and compile the graph.

        Args:
            support_mask: boolean mask for ``SupportPrior``; defaults to the global ``support``.
            unitary: matrix for ``UnitaryPropagator``; defaults to the global ``U``.
            noise_var: value passed as ``var=`` to ``AmplitudeMeasurement``; defaults to the
                global ``var``.
            damping: value passed as ``damping=`` to ``AmplitudeMeasurement`` (previously
                hard-coded to 0.1).
        """
        super().__init__()

        # Resolve defaults at call time (not def time) so the globals are read exactly as
        # the original implementation did.
        support_mask = support if support_mask is None else support_mask
        unitary = U if unitary is None else unitary
        noise_var = var if noise_var is None else noise_var

        self.X = SupportPrior(support=support_mask)

        self.Y = UnitaryPropagator(unitary) @ self.X

        self.Z = AmplitudeMeasurement(self.Y, var=noise_var, damping=damping)

        self.compile()

# ==== 3. Initialization and data generation ====
g = SimplePhaseGraph()
# Seed the graph's initialization RNG (what it seeds is defined in Graph.set_init_rng — confirm).
# Kept separate from the sampling seed below so the two can be varied independently.
g.set_init_rng(np.random.default_rng(seed=1))

# Generate a synthetic sample with its own seed (presumably forward-sampled through the
# graph — confirm in Graph.generate_sample), then read back X's sample as the ground truth
# that the PMSE checks compare against.
g.generate_sample(rng=np.random.default_rng(seed=999))
true_x = g.X.get_sample()

# Set the measurement's observed data from the generated sample.
g.Z.update_observed_from_sample()

# ==== 4. Inference (belief propagation) ====
def monitor(graph, t, every=10, reference=None):
    """Progress callback for ``Graph.run``: print the PMSE of X's belief every few iterations.

    Args:
        graph: the graph being run; ``graph.X.compute_belief().data`` is the current estimate.
        t: iteration index passed in by ``Graph.run`` (the output shows it starting at 0).
        every: report interval in iterations (previously hard-coded to 10).
        reference: ground truth to compare against; defaults to the global ``true_x``,
            looked up at call time as before.

    Returns:
        None. Prints one line on reporting iterations only.
    """
    if t % every != 0:
        return
    est = graph.X.compute_belief().data
    target = true_x if reference is None else reference
    err = pmse(est, target)
    print(f"[t={t}] PMSE = {err:.3e}")

n_iter = 100  # number of BP iterations; the summary message below is derived from it
g.run(n_iter=n_iter, callback=monitor)

# ==== 5. Results ====
final_est = g.X.compute_belief().data
final_err = pmse(final_est, true_x)
# Use n_iter so the reported count always matches the run above.
print(f"\nFinal PMSE after {n_iter} iterations: {final_err:.3e}")

[t=0] PMSE = 4.714e-01
[t=10] PMSE = 4.236e-01
[t=20] PMSE = 3.191e-01
[t=30] PMSE = 2.684e-01
[t=40] PMSE = 2.155e-01
[t=50] PMSE = 8.682e-02
[t=60] PMSE = 4.601e-03
[t=70] PMSE = 3.256e-04
[t=80] PMSE = 2.732e-04
[t=90] PMSE = 2.735e-04

Final PMSE after 30 iterations: 2.736e-04


In [11]:
# Profile 100 more iterations of the graph `g` that was already run above: it continues from
# that state, so PMSE is already at its final value from t=0 (see the output). The profile
# therefore measures steady-state iterations and includes the `monitor` callback's cost.
# -l 40: show at most 40 entries; -s cumulative: sort by cumulative time.
%prun -l 40 -s cumulative g.run(n_iter=100, callback=monitor)

[t=0] PMSE = 2.736e-04
[t=10] PMSE = 2.736e-04
[t=20] PMSE = 2.736e-04
[t=30] PMSE = 2.736e-04
[t=40] PMSE = 2.736e-04
[t=50] PMSE = 2.736e-04
[t=60] PMSE = 2.736e-04
[t=70] PMSE = 2.736e-04
[t=80] PMSE = 2.736e-04
[t=90] PMSE = 2.736e-04
 

         20831 function calls in 0.339 seconds

   Ordered by: cumulative time
   List reduced from 105 to 40 due to restriction <40>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.339    0.339 {built-in method builtins.exec}
        1    0.000    0.000    0.339    0.339 <string>:1(<module>)
        1    0.000    0.000    0.339    0.339 graph.py:58(run)
      100    0.001    0.000    0.324    0.003 graph.py:53(backward)
      100    0.001    0.000    0.279    0.003 unitary_propagator.py:91(backward)
      100    0.245    0.002    0.262    0.003 unitary_propagator.py:29(compute_belief)
      100    0.001    0.000    0.042    0.000 base.py:50(backward)
      100    0.001    0.000    0.041    0.000 amplitude_measurement.py:77(_compute_message)
      300    0.015    0.000    0.024    0.000 uncertain_array.py:121(__truediv__)
      200    0.006    0.000    0.022    0.000 linalg_utils.py:4(reduce_precision_to_scalar)
      100    0.0