In [93]:
import datetime as dt
import os
import pickle
import time
from typing import List

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

from tqdm import tqdm

from quara.data_analysis import data_analysis, physicality_violation_check
from quara.objects.composite_system import CompositeSystem
from quara.objects.elemental_system import ElementalSystem
from quara.objects.matrix_basis import get_normalized_pauli_basis
from quara.objects.povm import (
    Povm,
    get_x_measurement,
    get_y_measurement,
    get_z_measurement,
)
from quara.objects.qoperation import QOperation
from quara.objects.state import State, get_z0_1q, get_z1_1q, get_x0_1q
from quara.protocol.qtomography.standard.standard_qst import StandardQst
from quara.protocol.qtomography.standard.linear_estimator import LinearEstimator
from quara.protocol.qtomography.standard.projected_linear_estimator import ProjectedLinearEstimator

In [3]:
# Load IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the project sources are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [4]:
def estimate_sequence(name: str, true_object: State, num_data: List[int], iterations: int, on_para_eq_constraint: bool=True):
    """Repeat linear-estimator state tomography and collect every estimate.

    Each repetition generates a fresh sequence of empirical distributions for
    ``true_object`` and runs a new ``LinearEstimator`` on it.

    Note: the measurements are read from the notebook-level ``povms``
    variable, which must be defined before this function is called.

    Args:
        name: label of the run; not used by this function.
        true_object: state handed to ``generate_empi_dists_sequence``.
        num_data: data counts handed to ``generate_empi_dists_sequence``.
        iterations: number of repetitions.
        on_para_eq_constraint: forwarded to ``StandardQst``.

    Returns:
        A list holding one ``estimated_qoperation_sequence`` per repetition.
    """
    tomography = StandardQst(povms, on_para_eq_constraint=on_para_eq_constraint)

    estimates_per_run = []
    started_at = time.time()
    for _ in tqdm(range(iterations)):
        dists = tomography.generate_empi_dists_sequence(true_object, num_data)
        result = LinearEstimator().calc_estimate_sequence(tomography, dists)
        estimates_per_run.append(result.estimated_qoperation_sequence)
    elapsed = time.time() - started_at

    # Report the wall-clock time of the whole repetition loop.
    print(f"time(s)={elapsed}")
    return estimates_per_run

In [110]:
def estimate(name: str, true_object: State, num_data: List[int], iterations: int, on_para_eq_constraint: bool=True):
    """Repeat linear-estimator state tomography, pickle the estimates, return them.

    Each repetition generates a fresh sequence of empirical distributions for
    ``true_object`` and runs a new ``LinearEstimator`` on it. The collected
    list is written to ``qst_data/{name}_{timestamp}.pkl``.

    Note: the measurements are read from the notebook-level ``povms``
    variable, which must be defined before this function is called.

    Args:
        name: label of the run; used as the prefix of the output file name.
        true_object: state handed to ``generate_empi_dists_sequence``.
        num_data: data counts handed to ``generate_empi_dists_sequence``.
        iterations: number of repetitions.
        on_para_eq_constraint: forwarded to ``StandardQst``.

    Returns:
        A list holding one ``estimated_qoperation_sequence`` per repetition.
    """
    qst = StandardQst(povms, on_para_eq_constraint=on_para_eq_constraint)

    # Generate empirical distributions and compute the estimate, once per repetition.
    obj_sequences = []
    start = time.time()
    for _ in tqdm(range(iterations)):
        empi_dists_seq = qst.generate_empi_dists_sequence(
            true_object, num_data
        )

        estimator = LinearEstimator()
        obj_sequence = estimator.calc_estimate_sequence(qst, empi_dists_seq)
        obj_sequences.append(obj_sequence.estimated_qoperation_sequence)

    end = time.time()
    print(f"time(s)={end - start}")

    # NOTE(review): the file name does not encode on_para_eq_constraint, so
    # runs sharing the same `name` are told apart only by their timestamp.
    identify = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
    path = f"qst_data/{name}_{identify}.pkl"
    # Create the output directory if needed; otherwise open() raises
    # FileNotFoundError and the results of the whole loop are lost.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj_sequences, f)

    return obj_sequences

# True Object = State([1/np.sqrt(2), 0, 0, 1/np.sqrt(2)])

In [153]:
# Set up the system: one elemental system on the normalized Pauli basis,
# wrapped in a composite system.
e_sys = ElementalSystem(0, get_normalized_pauli_basis())
c_sys = CompositeSystem([e_sys])

# X, Y and Z measurements on that system. `povms` is the global that
# estimate() and estimate_sequence() read when they build StandardQst.
povm_x = get_x_measurement(c_sys)
povm_y = get_y_measurement(c_sys)
povm_z = get_z_measurement(c_sys)
povms = [povm_x, povm_y, povm_z]

In [154]:
# Case 1: the true state comes from get_z0_1q.
true_object = get_z0_1q(c_sys)
num_data = [1000]
iterations = 1000

# Run the same simulation with and without the equality constraint on the parameters.
param_affine_est_linear = estimate("z0_affine", true_object, num_data, iterations, on_para_eq_constraint=True)
param_linear_est_linear = estimate("z0_linear", true_object, num_data, iterations, on_para_eq_constraint=False)

# Keep only the first element of each per-repetition sequence
# (num_data has a single entry).
param_affine_est_linear = [p[0] for p in param_affine_est_linear]
param_linear_est_linear = [p[0] for p in param_linear_est_linear]

100%|██████████| 1000/1000 [00:11<00:00, 86.77it/s]
  1%|          | 6/1000 [00:00<00:16, 59.72it/s]

time(s)=11.529096126556396


100%|██████████| 1000/1000 [00:11<00:00, 90.57it/s]

time(s)=11.043689966201782





## on_para_eq_constraint = True

In [155]:
# Physicality check on the estimates obtained with on_para_eq_constraint=True.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_affine_est_linear)
# Show which entries the result provides.
violation_result.keys()

dict_keys(['violation_list'])

In [156]:
# Histogram of the "violation_list" values, using a bin size of 0.01.
physicality_violation_check.make_prob_dist_histogram(violation_result["violation_list"], bin_size=0.01)

## on_para_eq_constraint = False

In [157]:
# Physicality check on the estimates obtained with on_para_eq_constraint=False.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_linear_est_linear)
# Show which entries the result provides.
violation_result.keys()

100%|██████████| 1000/1000 [00:00<00:00, 7510.16it/s]

103: [1.0, -2.220446049250313e-16]





dict_keys(['sorted_eigenvalues_list', 'sum_of_eigenvalues'])

In [158]:
# Real parts of the true state's eigenvalues in ascending order, shown for reference.
true_eigs = sorted([eig.real for eig in true_object.calc_eigenvalues()])
true_eigs

[0.0, 0.9999999999999997]

In [159]:
# True eigenvalues (real parts), largest first, so that index i lines up with
# the i-th entry of "sorted_eigenvalues_list".
true_eigs = sorted((eig.real for eig in true_object.calc_eigenvalues()), reverse=True)

for i, values in enumerate(violation_result["sorted_eigenvalues_list"]):
    fig = physicality_violation_check.make_prob_dist_histogram(values, bin_size=0.0001)
    fig.update_layout(title=f"N={num_data[0]}, Nrep={len(values)}")
    fig.update_xaxes(title=f"Eigenvalue (i={i})")

    # Red vertical line at the true eigenvalue, spanning the full plot height.
    x_value = true_eigs[i]
    true_value_line = dict(
        type="line",
        line_color="red",
        line_width=2,
        opacity=0.5,
        x0=x_value,
        x1=x_value,
        xref="x",
        y0=0,
        y1=1,
        yref="paper",
    )
    fig.add_shape(**true_value_line)
    fig.show()

In [160]:
# TODO: find a way to pass iterations and num_data in explicitly.
n_rep = len(violation_result["sorted_eigenvalues_list"][0])
n = num_data[0]

# One histogram per kind of violation: (result key, annotated x position, axis label suffix).
for key, threshold, relation in (("less_than_zero", 0, "<0"), ("greater_than_one", 1, ">1")):
    value_list = violation_result["sum_of_eigenvalues"][key]
    fig = physicality_violation_check.make_prob_dist_histogram(
        value_list, bin_size=0.0001, annotation_vlines=[threshold]
    )
    title = f"N={n}, Nrep={n_rep}, Number of Unphysical estimates={len(value_list)}"
    fig.update_layout(title=title)
    fig.update_xaxes(title=f"Sum of unphysical eigenvalues ({relation})")
    fig.show()

# True Object = State([ 1/np.sqrt(2), 1/np.sqrt(6), 1/np.sqrt(6), 1/np.sqrt(6) ])

In [161]:
# Case 2: rebuild the system and define the true state from an explicit vector.
e_sys = ElementalSystem(0, get_normalized_pauli_basis())
c_sys = CompositeSystem([e_sys])

# Vector passed to State together with c_sys.
vec = np.array([1/np.sqrt(2), 1/np.sqrt(6), 1/np.sqrt(6), 1/np.sqrt(6)],dtype=np.float64)
true_object = State(c_sys, vec)
num_data = [1000]  # N
iterations = 1000  # Nrep

In [162]:
# Rebuild the X, Y and Z measurements on the new composite system and refresh
# the global `povms` read by estimate().
povm_x = get_x_measurement(c_sys)
povm_y = get_y_measurement(c_sys)
povm_z = get_z_measurement(c_sys)
povms = [povm_x, povm_y, povm_z]

In [163]:
# Both runs share the same `name`, so their pickle files differ only by timestamp.
name = "case_2"
param_affine_est_linear = estimate(name, true_object, num_data, iterations, on_para_eq_constraint=True)
param_linear_est_linear = estimate(name, true_object, num_data, iterations, on_para_eq_constraint=False)

# Keep only the first element of each per-repetition sequence
# (num_data has a single entry).
param_affine_est_linear = [p[0] for p in param_affine_est_linear]
param_linear_est_linear = [p[0] for p in param_linear_est_linear]

100%|██████████| 1000/1000 [00:11<00:00, 83.93it/s]
  1%|          | 7/1000 [00:00<00:15, 62.54it/s]

time(s)=11.921917200088501


100%|██████████| 1000/1000 [00:10<00:00, 90.94it/s]

time(s)=10.99937391281128





## on_para_eq_constraint = True

In [164]:
# Physicality check on the estimates obtained with on_para_eq_constraint=True.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_affine_est_linear)
# Show which entries the result provides.
violation_result.keys()

dict_keys(['violation_list'])

In [165]:
# Histogram of the "violation_list" values on a fixed x range of (-1, 2).
value_list = violation_result["violation_list"]
physicality_violation_check.make_prob_dist_histogram(value_list, x_range=(-1, 2), bin_size=0.01)

## on_para_eq_constraint = False

In [166]:
# Physicality check on the estimates obtained with on_para_eq_constraint=False.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_linear_est_linear)
# Show which entries the result provides.
violation_result.keys()

100%|██████████| 1000/1000 [00:00<00:00, 8015.24it/s]


0: [0.9987273804394546, 0.0012726195605456778]
0: [0.9987273804394546, 0.0012726195605456778]
2: [0.9873089369178452, 0.012691063082155089]
2: [0.9873089369178452, 0.012691063082155089]
3: [0.9943864884885103, 0.005613511511489588]
3: [0.9943864884885103, 0.005613511511489588]
4: [0.9833063210842581, 0.01669367891574171]
4: [0.9833063210842581, 0.01669367891574171]
9: [0.9994126550258814, 0.0005873449741184082]
9: [0.9994126550258814, 0.0005873449741184082]
12: [0.9890674800065936, 0.010932519993406099]
12: [0.9890674800065936, 0.010932519993406099]
13: [0.9940313755218388, 0.005968624478161305]
13: [0.9940313755218388, 0.005968624478161305]
14: [0.987532563015026, 0.01246743698497419]
14: [0.987532563015026, 0.01246743698497419]
17: [0.9926611411507913, 0.007338858849208327]
17: [0.9926611411507913, 0.007338858849208327]
19: [0.9911629057654903, 0.008837094234509357]
19: [0.9911629057654903, 0.008837094234509357]
20: [0.9962146712865312, 0.00378532871346885]
20: [0.9962146712865312, 0

636: [0.9898407904615543, 0.010159209538445988]
636: [0.9898407904615543, 0.010159209538445988]
637: [0.9873253533318374, 0.012674646668162533]
637: [0.9873253533318374, 0.012674646668162533]
642: [0.9959627002104091, 0.004037299789590937]
642: [0.9959627002104091, 0.004037299789590937]
643: [0.9969185043847733, 0.003081495615226612]
643: [0.9969185043847733, 0.003081495615226612]
648: [0.9870533851642959, 0.01294661483570391]
648: [0.9870533851642959, 0.01294661483570391]
649: [0.999426671294195, 0.0005733287058047131]
649: [0.999426671294195, 0.0005733287058047131]
653: [0.9838532835478129, 0.016146716452187175]
653: [0.9838532835478129, 0.016146716452187175]
654: [0.9772274091038777, 0.02277259089612202]
654: [0.9772274091038777, 0.02277259089612202]
655: [0.9977770585312267, 0.002222941468773417]
655: [0.9977770585312267, 0.002222941468773417]
658: [0.9914580348310522, 0.008541965168947415]
658: [0.9914580348310522, 0.008541965168947415]
662: [0.9892800833878288, 0.0107199166121719

dict_keys(['sorted_eigenvalues_list', 'sum_of_eigenvalues'])

In [167]:
# True eigenvalues (real parts), largest first, so that index i lines up with
# the i-th entry of "sorted_eigenvalues_list".
true_eigs = sorted([eig.real for eig in true_object.calc_eigenvalues()], reverse=True)

for i, values in enumerate(violation_result["sorted_eigenvalues_list"]):
    fig = physicality_violation_check.make_prob_dist_histogram(values, bin_size=0.0001)
    # Include N in the title, as the other histogram titles in this notebook
    # do, so the figure can be read on its own.
    fig.update_layout(title=f"N={num_data[0]}, Nrep={len(values)}")
    fig.update_xaxes(title=f"Eigenvalue (i={i})")

    # Red vertical line at the true eigenvalue, spanning the full plot height.
    x_value = true_eigs[i]
    fig.add_shape(
        type="line",
        line_color="red",
        line_width=2,
        opacity=0.5,
        x0=x_value,
        x1=x_value,
        xref="x",
        y0=0,
        y1=1,
        yref="paper",
    )
    fig.show()

In [30]:
# TODO: find a way to pass iterations and num_data in explicitly.
n_rep = len(violation_result["sorted_eigenvalues_list"][0])
n = num_data[0]

# One histogram per kind of violation: (result key, annotated x position, axis label suffix).
for key, threshold, relation in (("less_than_zero", 0, "<0"), ("greater_than_one", 1, ">1")):
    value_list = violation_result["sum_of_eigenvalues"][key]
    fig = physicality_violation_check.make_prob_dist_histogram(
        value_list, bin_size=0.0001, annotation_vlines=[threshold]
    )
    title = f"N={n}, Nrep={n_rep}, Number of Unphysical estimates={len(value_list)}"
    fig.update_layout(title=title)
    fig.update_xaxes(title=f"Sum of unphysical eigenvalues ({relation})")
    fig.show()

# True Object = State([1/np.sqrt(2), 0, 0, 0])

In [205]:
# Case 3: rebuild the system and define the true state from an explicit vector.
e_sys = ElementalSystem(0, get_normalized_pauli_basis())
c_sys = CompositeSystem([e_sys])

# Vector passed to State together with c_sys; only the first component is non-zero.
vec = np.array([1/np.sqrt(2), 0, 0, 0],dtype=np.float64)
true_object = State(c_sys, vec)
num_data = [1000]  # N
iterations = 1000  # Nrep

In [206]:
# Display the density matrix of the true state.
true_object.to_density_matrix()

array([[0.5+0.j, 0. +0.j],
       [0. +0.j, 0.5+0.j]])

In [207]:
# Rebuild the X, Y and Z measurements on the new composite system and refresh
# the global `povms` read by estimate().
povm_x = get_x_measurement(c_sys)
povm_y = get_y_measurement(c_sys)
povm_z = get_z_measurement(c_sys)
povms = [povm_x, povm_y, povm_z]

In [208]:
# Both runs share the same `name`, so their pickle files differ only by timestamp.
name = "case_3"
param_affine_est_linear = estimate(name, true_object, num_data, iterations, on_para_eq_constraint=True)
param_linear_est_linear = estimate(name, true_object, num_data, iterations, on_para_eq_constraint=False)

# Keep only the first element of each per-repetition sequence
# (num_data has a single entry).
param_affine_est_linear = [p[0] for p in param_affine_est_linear]
param_linear_est_linear = [p[0] for p in param_linear_est_linear]

100%|██████████| 1000/1000 [00:10<00:00, 91.13it/s]
  1%|          | 9/1000 [00:00<00:11, 86.26it/s]

time(s)=10.980865001678467


100%|██████████| 1000/1000 [00:11<00:00, 88.93it/s]


time(s)=11.248534679412842


## on_para_eq_constraint = True

In [209]:
# Physicality check on the estimates obtained with on_para_eq_constraint=True.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_affine_est_linear)
# Show which entries the result provides.
violation_result.keys()

dict_keys(['violation_list'])

In [210]:
# Histogram of the "violation_list" values on a fixed x range of (-1, 2).
value_list = violation_result["violation_list"]
physicality_violation_check.make_prob_dist_histogram(value_list, x_range=(-1, 2), bin_size=0.01)

## on_para_eq_constraint = False

In [221]:
def f(x: float, true_eigenvalue: float, num_data: int = 1000) -> float:
    """Evaluate the reference curve overlaid on the eigenvalue histograms.

    With ``rho_2 = true_eigenvalue * (1 - true_eigenvalue) / num_data`` and
    ``d2 = (x - true_eigenvalue) ** 2`` the returned value is::

        2 * sqrt(2 * pi / rho_2) * d2 * exp(-d2 / rho_2)

    Args:
        x: point at which the curve is evaluated.
        true_eigenvalue: eigenvalue of the true state the curve is centred on.
            If a complex value is passed, the result is complex as well.
        num_data: number of data N used to scale ``rho_2``. Defaults to 1000,
            the value that was previously hard-coded, so existing two-argument
            calls return exactly the same result.

    Returns:
        The value of the curve at ``x``.

    Note:
        ``rho_2`` is zero when ``true_eigenvalue`` is 0 or 1; the division by
        ``rho_2`` is then undefined.
    """
    rho_2 = (true_eigenvalue * (1 - true_eigenvalue)) / num_data  # ρ**2
    w_1 = (x - true_eigenvalue) ** 2
    w_2 = - (1 / rho_2) * w_1

    y = 2 * np.sqrt((2 * np.pi) / rho_2)
    y = y * w_1
    y = y * np.exp(w_2)

    return y

In [222]:
# Evaluate f at x=0.5 against the first true eigenvalue. The eigenvalue is
# passed as returned by calc_eigenvalues (complex-typed), so the result is
# complex-typed too.
print(true_object.calc_eigenvalues()[0])
f(0.5, true_object.calc_eigenvalues()[0])

(0.4999999999999999+0j)


(3.90814244959578e-30+0j)

In [223]:
# Display the eigenvalues of the true state.
true_object.calc_eigenvalues()

array([0.5+0.j, 0.5+0.j])

In [224]:
# Physicality check on the estimates obtained with on_para_eq_constraint=False.
violation_result = physicality_violation_check.get_physicality_violation_result_for_state(param_linear_est_linear)
# Show which entries the result provides.
violation_result.keys()

100%|██████████| 1000/1000 [00:00<00:00, 5439.96it/s]


0: [0.515362291495737, 0.48463770850426247]
0: [0.515362291495737, 0.48463770850426247]
1: [0.5341320963317523, 0.46586790366824754]
1: [0.5341320963317523, 0.46586790366824754]
2: [0.509848857801796, 0.4901511421982039]
2: [0.509848857801796, 0.4901511421982039]
3: [0.5191311264697088, 0.48086887353029095]
3: [0.5191311264697088, 0.48086887353029095]
4: [0.5253771550808992, 0.474622844919101]
4: [0.5253771550808992, 0.474622844919101]
5: [0.5208086520466846, 0.4791913479533149]
5: [0.5208086520466846, 0.4791913479533149]
6: [0.5050990195135927, 0.4949009804864072]
6: [0.5050990195135927, 0.4949009804864072]
7: [0.5211187120819429, 0.478881287918057]
7: [0.5211187120819429, 0.478881287918057]
8: [0.5327566787083184, 0.46724332129168156]
8: [0.5327566787083184, 0.46724332129168156]
9: [0.5054772255750516, 0.4945227744249482]
9: [0.5054772255750516, 0.4945227744249482]
10: [0.5227156333832008, 0.47728436661679874]
10: [0.5227156333832008, 0.47728436661679874]
11: [0.5366606055596466, 0.4

297: [0.5203224014329018, 0.47967759856709835]
297: [0.5203224014329018, 0.47967759856709835]
298: [0.5240416305603427, 0.47595836943965736]
298: [0.5240416305603427, 0.47595836943965736]
299: [0.5294278779391242, 0.4705721220608755]
299: [0.5294278779391242, 0.4705721220608755]
300: [0.5290860791444976, 0.4709139208555017]
300: [0.5290860791444976, 0.4709139208555017]
301: [0.5349284983931457, 0.46507150160685373]
301: [0.5349284983931457, 0.46507150160685373]
302: [0.5464112055434889, 0.4535887944565105]
302: [0.5464112055434889, 0.4535887944565105]
303: [0.5250399680510974, 0.47496003194890185]
303: [0.5250399680510974, 0.47496003194890185]
304: [0.5240831891575845, 0.4759168108424153]
304: [0.5240831891575845, 0.4759168108424153]
305: [0.5053851648071345, 0.4946148351928654]
305: [0.5053851648071345, 0.4946148351928654]
306: [0.5173493515728974, 0.4826506484271024]
306: [0.5173493515728974, 0.4826506484271024]
307: [0.5063245553203364, 0.4936754446796631]
307: [0.5063245553203364, 

762: [0.5123288280059377, 0.48767117199406174]
763: [0.5094339811320568, 0.4905660188679432]
763: [0.5094339811320568, 0.4905660188679432]
764: [0.5429534631898292, 0.457046536810171]
764: [0.5429534631898292, 0.457046536810171]
765: [0.5365650105975643, 0.46343498940243544]
765: [0.5365650105975643, 0.46343498940243544]
766: [0.520856653614614, 0.4791433463853855]
766: [0.520856653614614, 0.4791433463853855]
767: [0.5168226038412607, 0.483177396158739]
767: [0.5168226038412607, 0.483177396158739]
768: [0.5282842712474618, 0.4717157287525379]
768: [0.5282842712474618, 0.4717157287525379]
769: [0.5214009345590326, 0.47859906544096736]
769: [0.5214009345590326, 0.47859906544096736]
770: [0.5148996644257514, 0.48510033557424864]
770: [0.5148996644257514, 0.48510033557424864]
771: [0.5289999999999999, 0.471]
771: [0.5289999999999999, 0.471]
772: [0.5080622577482984, 0.4919377422517013]
772: [0.5080622577482984, 0.4919377422517013]
773: [0.5072801098892805, 0.49271989011071937]
773: [0.5072

dict_keys(['sorted_eigenvalues_list', 'sum_of_eigenvalues'])

In [225]:
# True eigenvalues (real parts), largest first, so that index i lines up with
# the i-th entry of "sorted_eigenvalues_list".
true_eigs = sorted([eig.real for eig in true_object.calc_eigenvalues()], reverse=True)
# Collects the reference-curve values of every figure for later inspection.
test_list = []
for i, values in enumerate(violation_result["sorted_eigenvalues_list"]):
    fig = physicality_violation_check.make_prob_dist_histogram(values, bin_size=0.0001)

    # Red vertical line at the true eigenvalue, spanning the full plot height.
    x_value = true_eigs[i]
    fig.add_shape(
        type="line",
        line_color="red",
        line_width=2,
        opacity=0.5,
        x0=x_value,
        x1=x_value,
        xref="x",
        y0=0,
        y1=1,
        yref="paper",
    )

    # Reference curve: sampled upwards from 0.5 for the first (largest)
    # eigenvalue and downwards from 0.5 for the others.
    if i == 0:
        x_list = np.arange(0.5, 1.5, 0.0001)
    else:
        x_list = np.arange(0.5, -0.5, -0.0001)

    y_list = [f(x, true_eigs[i]) for x in x_list]
    line_trace = go.Scatter(x=x_list, y=y_list, line_color='rgb(0,176,246)')
    fig.add_trace(line_trace)
    test_list.append(y_list)

    # Include N in the title, as the other histogram titles in this notebook
    # do, so the figure can be read on its own.
    fig.update_layout(title=f"N={num_data[0]}, Nrep={len(values)}")
    fig.update_xaxes(title=f"Eigenvalue (i={i})")
    fig.show()

In [219]:
# Case lambda_i = 1.5
# NOTE(review): this equals the expression in f() with rho_2 = 0.25 and
# (x - true_eigenvalue)**2 = 1 — confirm that this is the intended setting.
2 * np.sqrt((2 * np.pi) / 0.25) * np.exp(-1/0.25)

0.18364199322572883

In [151]:
# TODO: find a way to pass iterations and num_data in explicitly.
n_rep = len(violation_result["sorted_eigenvalues_list"][0])
n = num_data[0]

# One histogram per kind of violation: (result key, annotated x position, axis label suffix).
for key, threshold, relation in (("less_than_zero", 0, "<0"), ("greater_than_one", 1, ">1")):
    value_list = violation_result["sum_of_eigenvalues"][key]
    fig = physicality_violation_check.make_prob_dist_histogram(
        value_list, bin_size=0.0001, annotation_vlines=[threshold]
    )
    title = f"N={n}, Nrep={n_rep}, Number of Unphysical estimates={len(value_list)}"
    fig.update_layout(title=title)
    fig.update_xaxes(title=f"Sum of unphysical eigenvalues ({relation})")
    fig.show()

In [215]:
# Display the eigenvalues of the true state.
true_object.calc_eigenvalues()

array([0.5+0.j, 0.5+0.j])

In [216]:
# Evaluate f at x=0.5 against the first true eigenvalue (complex-typed input,
# so the result is complex-typed as well).
f(0.5, true_object.calc_eigenvalues()[0])

(1.2358631561112461e-31+0j)

In [55]:
# Mean over every estimated eigenvalue: all entries of every list in
# "sorted_eigenvalues_list" are pooled together.
total = 0
count = 0

for eigs in violation_result["sorted_eigenvalues_list"]:
    total += sum(eigs)
    count += len(eigs)
    
total / count

0.49999999999999967

In [77]:
# Element-wise mean of the `vec` attribute over all estimates, accumulated
# in a length-4 array.
sum_of_vecs = np.zeros(4, )

for estimated in param_linear_est_linear:
    sum_of_vecs += estimated.vec
   
mean_of_vecs = sum_of_vecs / len(param_linear_est_linear)

mean_of_vecs

array([ 7.07106781e-01,  8.44285497e-04,  4.83661038e-04, -2.27688384e-04])

In [81]:
# Render the mean vector as comma-separated fixed-point numbers.
", ".join([f"{v:f}" for v in mean_of_vecs])

'0.707107, 0.000844, 0.000484, -0.000228'

In [80]:
# Display the vector of the true state for comparison with the mean estimate.
true_object.vec

array([0.70710678, 0.        , 0.        , 0.        ])

In [83]:
# Same element-wise mean computed in one step with np.mean over the stacked vectors.
np.mean(np.array([est.vec for est in param_linear_est_linear]), axis=0)

array([ 7.07106781e-01,  8.44285497e-04,  4.83661038e-04, -2.27688384e-04])

In [86]:
# Build a State from the averaged vector on the current composite system.
mean_of_estimated_state = State(vec=mean_of_vecs,c_sys=c_sys)

In [89]:
# Display the eigenvalues of the state built from the averaged vector.
mean_of_estimated_state.calc_eigenvalues()

array([0.49929339+2.6392094e-20j, 0.50070661-2.6392094e-20j])

In [90]:
# Same eigenvalues with the imaginary parts dropped.
[v.real for v in mean_of_estimated_state.calc_eigenvalues()]

[0.49929339261255284, 0.5007066073874543]

In [103]:
# Timestamp string (YYYYmmdd_HHMMSS) used to tag output file names.
identify = dt.datetime.now().strftime("%Y%m%d_%H%M%S")

In [105]:
# Persist the on_para_eq_constraint=False estimates under a timestamped name.
identify = dt.datetime.now().strftime("%Y%m%d_%H%M%S")

path = f"data/param_linear_est_linear_{identify}.pkl"
# Create the output directory if needed; otherwise open() raises FileNotFoundError.
os.makedirs(os.path.dirname(path), exist_ok=True)
# Do not name the handle `f`: at notebook top level `as f` rebinds the global
# `f` and shadows the function `f` defined in this notebook, so any later call
# to f(...) would fail.
with open(path, "wb") as pkl_file:
    pickle.dump(param_linear_est_linear, pkl_file)