In [1]:
%matplotlib inline
from tensorscaling import scale, unit_tensor, random_tensor, marginal
import numpy as np

# Tensor scaling

Scale the $3 \times 3 \times 3$ unit tensor to prescribed non-uniform marginals:

In [2]:
# One target spectrum per subsystem; each tuple sums to 1.
targets = [(.5, .25, .25), (.4, .3, .3), (.7, .2, .1)]

# NOTE(review): if `scale` randomizes its starting point, the iteration count and the
# scaling matrices shown below may differ between runs; consider seeding for reproducibility.
res = scale(unit_tensor(3, 3), targets, eps=1e-4)
res

Result(success=True, iterations=72, ...)

We can also access the scaling matrices (`res.gs`) and the final scaled state (`res.psi`). Here are the scaling matrices:

In [3]:
# Print the three scaling matrices stored in `res.gs`, one after another.
for k in range(3):
    print(res.gs[k])

[[-0.14907362+0.53353097j -0.17414545-0.76946419j  1.16184749+0.33971747j]
 [-0.16411122-0.46728859j -0.39399366-0.09599178j  0.03826509+0.24356391j]
 [-0.09684196+0.32313536j -0.29397191-0.04173878j -0.68412867+0.17580445j]]
[[-0.17445207-0.1083611j   0.80000172+0.11127813j  0.49239761-0.60694596j]
 [-0.31446824-1.22283817j -0.04915454-0.0734045j  -0.009626  +0.06874997j]
 [ 0.32962734+0.05314404j  0.21233906+0.48340287j -0.46319779-0.25999026j]]
[[-0.41485806+0.70843928j  0.26659378+1.05967763j -0.57279359-0.51416254j]
 [-0.26494763+0.0530557j  -0.20821574-0.52418207j -0.29644117-0.00606545j]
 [ 0.35019175+0.07862029j -0.23464085-0.06083859j -0.02421597-0.21748385j]]


Let's now check that the four-qubit W tensor *cannot* be scaled to uniform marginals: the algorithm fails to reach the targets within the iteration budget (`success=False`).

In [4]:
num_parties = 4
shape = [2] * num_parties

# W tensor: equal-weight sum of all basis states with exactly one 1,
# i.e. |10...0> + |01...0> + ... + |0...01>, normalized to unit length.
W = np.zeros(shape)
# Row k of the identity is the multi-index with a 1 in position k only; passing the rows
# as a tuple of per-axis index arrays sets all of these entries in one assignment.
W[tuple(np.eye(num_parties, dtype=int))] = 1 / np.sqrt(num_parties)
targets = [(.5, .5)] * num_parties

scale(W, targets, eps=1e-4, max_iterations=1000)

Result(success=False, iterations=1000, ...)

# Tuples of matrices and tensors

We can also prescribe the desired spectra for only some of the subsystems; in the example below, the two targets apply to the last two subsystems.
Note that prescribing two out of three marginals amounts to *operator scaling*.

In [5]:
shape = [3, 3, 3]
targets = [(.4, .3, .3), (.7, .2, .1)]

res = scale(unit_tensor(3, 3), targets, eps=1e-6)
res

Result(success=True, iterations=27, ...)

Indeed, the last two marginals are diagonal with the prescribed spectra, while the first marginal is unconstrained.

In [6]:
# Marginal of each of the three subsystems of the scaled state, rounded to 5 decimals.
for subsystem in (0, 1, 2):
    print(marginal(res.psi, subsystem).round(5))

[[0.38071+0.j      0.01339+0.00901j 0.00572-0.00352j]
 [0.01339-0.00901j 0.31782+0.j      0.00148-0.00277j]
 [0.00572+0.00352j 0.00148+0.00277j 0.30147+0.j     ]]
[[ 0.4+0.j -0. -0.j -0. -0.j]
 [-0. +0.j  0.3+0.j  0. +0.j]
 [-0. +0.j  0. -0.j  0.3+0.j]]
[[ 0.7+0.j -0. +0.j -0. +0.j]
 [-0. +0.j  0.2+0.j -0. +0.j]
 [-0. -0.j -0. -0.j  0.1+0.j]]
