In [4]:
import os, sys

# Make the repository root (two directories above the notebook's working
# directory) importable:
#   * sys.path      -> affects imports in this kernel right now;
#   * PYTHONPATH    -> only read at interpreter start-up, so it matters for child
#                      processes launched afterwards (they inherit os.environ).
# Any PYTHONPATH entries that were already present are kept instead of clobbered.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))

# Drop empty segments so an unset/empty PYTHONPATH does not leave a stray separator.
pythonpath_entries = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
if project_root not in pythonpath_entries:
    pythonpath_entries.insert(0, project_root)
os.environ['PYTHONPATH'] = os.pathsep.join(pythonpath_entries)

if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("PYTHONPATH manually set to:", os.environ['PYTHONPATH'])

PYTHONPATH manually set to: c:\Users\ndhaj\Desktop


In [None]:
import numpy as np
from gpie import Graph, SupportPrior, fft2, PhaseMaskPropagator, AmplitudeMeasurement, pmse
from gpie.core.linalg_utils import circular_aperture, random_phase_mask

# ==== 1. Parameters ====
rng = np.random.default_rng(seed=42)   # shared generator; StructuredRandomGraph draws its phase masks from it

H = W = 128                            # grid height and width
shape = (H, W)

var = 1e-4                             # passed to AmplitudeMeasurement as `var` (presumably the noise variance)

support_radius = 0.2                   # radius argument for circular_aperture (units/scale: confirm in gpie)
support = circular_aperture(shape, radius=support_radius)

class StructuredRandomGraph(Graph):
    """Factor graph: support-constrained object -> random structured layers -> amplitude measurement.

    The unknown wave is created from a ``SupportPrior`` labelled ``"object"``.
    It is then passed through ``n_layers`` layers, each a ``PhaseMaskPropagator``
    with a freshly drawn random phase mask followed by ``fft2``. The result
    is fed to an ``AmplitudeMeasurement`` declared inside ``self.observe()``.
    The graph is compiled at the end of construction.

    Besides its arguments, the constructor reads the notebook-level globals
    ``support``, ``shape`` and ``var``. It also reads ``rng`` when
    ``mask_rng`` is not given.

    Args:
        n_layers: Number of (random phase mask, fft2) layers to stack.
        mask_rng: Generator used to draw the random phase masks. Defaults to
            the notebook-level ``rng``. That shared generator advances every
            time a graph is built, so re-running the construction cell yields
            different masks. Pass a freshly seeded generator to make
            construction reproducible.
        damping: Damping passed to ``AmplitudeMeasurement``. Defaults to 0.3,
            the previously hard-coded value.
    """

    def __init__(self, n_layers=1, *, mask_rng=None, damping=0.3):
        super().__init__()
        # Fall back to the shared module-level generator (original behaviour).
        layer_rng = rng if mask_rng is None else mask_rng

        # Support-constrained prior on the unknown object; later retrieved via
        # get_wave("object").
        x = ~SupportPrior(support=support, label="object")

        # Random structured operator: a random phase mask followed by fft2, per layer.
        for _ in range(n_layers):
            phase = random_phase_mask(shape, rng=layer_rng)
            x = fft2(PhaseMaskPropagator(phase) @ x)

        # The amplitude measurement is declared inside the observe() context.
        # `z` is not used further in __init__.
        with self.observe():
            z = AmplitudeMeasurement(var=var, damping=damping) @ x
        self.compile()

# ==== 4. Initialization and inference ====
g = StructuredRandomGraph(n_layers=3)

# Two fixed, independent seeds: one handed to set_init_rng (presumably the
# initial messages), one used to draw the synthetic ground truth and the
# observation (update_observed=True).
g.set_init_rng(np.random.default_rng(1))
g.generate_sample(update_observed=True, rng=np.random.default_rng(999))

# Ground-truth object sample, used by `monitor` to score the reconstruction.
X = g.get_wave("object")
true_x = X.get_sample()

def monitor(graph, t):
    """Callback for ``g.run``: score the current ``"object"`` belief against ``true_x``.

    The belief is recomputed and scored on every call. Only the printing is
    throttled, to iterations where ``t`` is a multiple of 50.

    Args:
        graph: The graph being run. Its ``"object"`` wave is inspected.
        t: Current iteration index supplied by ``run``.
    """
    belief = graph.get_wave("object").compute_belief().data
    error = pmse(belief, true_x)

    is_report_step = t % 50 == 0
    if is_report_step:
        print(f"[t={t}] PMSE = {error:.5e}")

g.run(
    n_iter=300,
    callback=monitor,  # PMSE progress is printed by `monitor`
    verbose=False,
)

[t=0] PMSE = 2.47071e-01
[t=50] PMSE = 1.00108e-03
[t=100] PMSE = 5.87098e-05
[t=150] PMSE = 5.58564e-05
[t=200] PMSE = 5.58644e-05
[t=250] PMSE = 5.58638e-05
