Skip to content

Commit

Permalink
Ensure final dtype in Qobj creation functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ericgig committed Feb 8, 2024
1 parent fcb4117 commit 1316bd6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
9 changes: 6 additions & 3 deletions qutip/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ def _f_op(n_sites, site, action, dtype=None):
oper : qobj
Qobj for destruction operator.
"""
dtype = dtype or settings.core["default_dtype"] or _data.CSR
# get `tensor` and sigma z objects
from .tensor import tensor
s_z = 2 * jmat(0.5, 'z', dtype=dtype)
Expand All @@ -614,7 +615,7 @@ def _f_op(n_sites, site, action, dtype=None):

eye = identity(2, dtype=dtype)
opers = [s_z] * site + [operator] + [eye] * (n_sites - site - 1)
return tensor(opers)
return tensor(opers).to(dtype)


def _implicit_tensor_dimensions(dimensions):
Expand Down Expand Up @@ -798,10 +799,11 @@ def position(N, offset=0, *, dtype=None):
oper : qobj
Position operator as Qobj.
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dia
a = destroy(N, offset=offset, dtype=dtype)
position = np.sqrt(0.5) * (a + a.dag())
position.isherm = True
return position
return position.to(dtype)


def momentum(N, offset=0, *, dtype=None):
Expand All @@ -826,10 +828,11 @@ def momentum(N, offset=0, *, dtype=None):
oper : qobj
Momentum operator as Qobj.
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dia
a = destroy(N, offset=offset, dtype=dtype)
momentum = -1j * np.sqrt(0.5) * (a - a.dag())
momentum.isherm = True
return momentum
return momentum.to(dtype)


def num(N, offset=0, *, dtype=None):
Expand Down
33 changes: 24 additions & 9 deletions qutip/core/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,9 @@ def coherent_dm(N, alpha, offset=0, method='operator', *, dtype=None):
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dense
return coherent(N, alpha, offset=offset, method=method, dtype=dtype).proj()
return coherent(
N, alpha, offset=offset, method=method, dtype=dtype
).proj().to(dtype)


def fock_dm(dimensions, n=None, offset=None, *, dtype=None):
Expand Down Expand Up @@ -340,7 +342,7 @@ def fock_dm(dimensions, n=None, offset=None, *, dtype=None):
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dia
return basis(dimensions, n, offset=offset, dtype=dtype).proj()
return basis(dimensions, n, offset=offset, dtype=dtype).proj().to(dtype)


def fock(dimensions, n=None, offset=None, *, dtype=None):
Expand Down Expand Up @@ -550,8 +552,10 @@ def projection(N, n, m, offset=None, *, dtype=None):
Requested projection operator.
"""
dtype = dtype or settings.core["default_dtype"] or _data.CSR
return basis(N, n, offset=offset, dtype=dtype) @ \
basis(N, m, offset=offset, dtype=dtype).dag()
return (
basis(N, n, offset=offset, dtype=dtype) @ \
basis(N, m, offset=offset, dtype=dtype).dag()
).to(dtype)


def qstate(string, *, dtype=None):
Expand Down Expand Up @@ -1154,10 +1158,15 @@ def triplet_states(*, dtype=None):
trip_states : list
2 particle triplet states
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dense
return [
basis([2, 2], [1, 1], dtype=dtype),
np.sqrt(0.5) * (basis([2, 2], [0, 1], dtype=dtype) +
basis([2, 2], [1, 0], dtype=dtype)),
(
np.sqrt(0.5) * (
basis([2, 2], [0, 1], dtype=dtype) +
basis([2, 2], [1, 0], dtype=dtype)
)
).to(dtype),
basis([2, 2], [0, 0], dtype=dtype),
]

Expand All @@ -1181,12 +1190,13 @@ def w_state(N=3, *, dtype=None):
W : :obj:`.Qobj`
N-qubit W-state
"""
dtype = dtype or settings.core["default_dtype"] or _data.Dense
inds = np.zeros(N, dtype=int)
inds[0] = 1
state = basis([2]*N, list(inds), dtype=dtype)
for kk in range(1, N):
state += basis([2]*N, list(np.roll(inds, kk)), dtype=dtype)
return np.sqrt(1 / N) * state
return (np.sqrt(1 / N) * state).to(dtype)


def ghz_state(N=3, *, dtype=None):
Expand All @@ -1208,5 +1218,10 @@ def ghz_state(N=3, *, dtype=None):
G : qobj
N-qubit GHZ-state
"""
return np.sqrt(0.5) * (basis([2]*N, [0]*N, dtype=dtype) +
basis([2]*N, [1]*N, dtype=dtype))
dtype = dtype or settings.core["default_dtype"] or _data.Dense
return (
np.sqrt(0.5) * (
basis([2]*N, [0]*N, dtype=dtype) +
basis([2]*N, [1]*N, dtype=dtype)
)
).to(dtype)

0 comments on commit 1316bd6

Please sign in to comment.