Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions varipeps/expectation/three_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
Corner_Literal = Literal["top-left", "top-right", "bottom-left", "bottom-right"]


@partial(jit, static_argnums=(5, 6))
@partial(jit, static_argnums=(5,))
def _three_site_triangle_workhorse(
top_left: jnp.ndarray,
top_right: jnp.ndarray,
bottom_left: jnp.ndarray,
bottom_right: jnp.ndarray,
gates: Tuple[jnp.ndarray, ...],
traced: Corner_Literal,
real_result: bool = False,
) -> List[jnp.ndarray]:
if traced == "top-left":
one_physical_sites = jnp.tensordot(top_left, top_right, ((3, 4, 5), (2, 3, 4)))
Expand Down Expand Up @@ -64,13 +63,7 @@ def _three_site_triangle_workhorse(

norm = jnp.trace(density_matrix)

if real_result:
return [
jnp.real(jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm)
for g in gates
]
else:
return [
return [
jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm for g in gates
]

Expand Down Expand Up @@ -128,7 +121,6 @@ def calc_three_sites_triangle_without_top_left_multiple_gates(
[],
)

real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)

return _three_site_triangle_workhorse(
traced_density_matrix_top_left,
Expand All @@ -137,7 +129,6 @@ def calc_three_sites_triangle_without_top_left_multiple_gates(
density_matrix_bottom_right,
tuple(gates),
"top-left",
real_result,
)


Expand Down Expand Up @@ -233,7 +224,6 @@ def calc_three_sites_triangle_without_top_right_multiple_gates(
[],
)

real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)

return _three_site_triangle_workhorse(
density_matrix_top_left,
Expand All @@ -242,7 +232,6 @@ def calc_three_sites_triangle_without_top_right_multiple_gates(
density_matrix_bottom_right,
tuple(gates),
"top-right",
real_result,
)


Expand Down Expand Up @@ -338,7 +327,6 @@ def calc_three_sites_triangle_without_bottom_left_multiple_gates(
[],
)

real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)

return _three_site_triangle_workhorse(
density_matrix_top_left,
Expand All @@ -347,7 +335,6 @@ def calc_three_sites_triangle_without_bottom_left_multiple_gates(
density_matrix_bottom_right,
tuple(gates),
"bottom-left",
real_result,
)


Expand Down Expand Up @@ -443,7 +430,6 @@ def calc_three_sites_triangle_without_bottom_right_multiple_gates(
[],
)

real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)

return _three_site_triangle_workhorse(
density_matrix_top_left,
Expand All @@ -452,7 +438,6 @@ def calc_three_sites_triangle_without_bottom_right_multiple_gates(
traced_density_matrix_bottom_right,
tuple(gates),
"bottom-right",
real_result,
)


Expand Down