In [1]:
from pyscf import gto, scf, cc
import numpy as np
from jax import numpy as jnp
from jax import vmap, jvp, jit
import jax
from functools import partial

# Linear H4 chain along x with equal spacing, UHF reference, then CCSD
# (unrestricted, since the reference is UHF).
a = 2  # H-H spacing in bohr (the molecule is built with unit='bohr')
nH = 4
atoms = ""
for i in range(nH):
    atoms += f"H {i*a:.5f} 0.00000 0.00000 \n"

# gto.M() constructs the Mole and already calls Mole.build(); a second
# mol.build() would only rebuild it and repeat the verbose log.
mol = gto.M(atom=atoms, basis="sto6g", unit='bohr', spin=0, verbose=4)

mf = scf.UHF(mol)
mf.kernel()

nfrozen = 0
mycc = cc.CCSD(mf,frozen=nfrozen)
mycc.kernel()[0]

System: uname_result(system='Linux', node='sharmagroup-rn', release='6.14.0-33-generic', version='#33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2', machine='x86_64')  Threads 16
Python 3.11.14 (main, Oct 21 2025, 18:31:21) [GCC 11.2.0]
numpy 2.3.1  scipy 1.16.2  h5py 3.14.0
Date: Thu Oct 30 11:37:18 2025
PySCF version 2.11.0
PySCF path  /home/sharmagroup/sharmagroup/pyscf
GIT HEAD (branch master) 3d1768f5e33b144b606c3d2c81c12ee54d794501

[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 4
[INPUT] num. electrons = 4
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = bohr
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
[INPUT]  1 H      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.000000000000 Bohr   0.0
[INPUT]  2 H      1.058354421840   0.000000000000   0.0

np.float64(-0.0765994008130508)

In [2]:
# CCS (T2 = 0) and CCSD energies, first with the converged T1, then with T1
# amplified by T1_SCALE to make the T1 contribution clearly visible.
# NOTE: mycc.t1 is intentionally left scaled -- later cells pass mycc to
# prep_afqmc.
T1_SCALE = 10

eris = mycc.ao2mo(mycc.mo_coeff)
zero_t2 = tuple(0 * t2_blk for t2_blk in mycc.t2)

# Remember the converged T1 on the solver object so that re-running this cell
# rescales the converged amplitudes instead of compounding the factor
# (10x, 100x, ...). A fresh mycc from the first cell has no such attribute.
if getattr(mycc, "_t1_converged", None) is None:
    mycc._t1_converged = mycc.t1
t1_converged = mycc._t1_converged

eccs = mycc.energy(t1_converged, zero_t2, eris)
print(mf.e_tot)
print(mf.e_tot+eccs)

mycc.t1 = tuple(t1_blk * T1_SCALE for t1_blk in t1_converged)
eccs = mycc.energy(mycc.t1, zero_t2, eris)
print(mf.e_tot+eccs)
eccsd = mycc.energy(mycc.t1, mycc.t2, eris)
print(mf.e_tot+eccsd)

-2.088692381947721
-2.0886942747955763
-2.0888799973483225
-2.165477505313518


In [3]:
def thouless_trans(t1):
    ''' Thouless transformation |psi'> = exp(T1)|psi>.

    Returns orthonormal occupied orbitals of |psi'> expressed in the original
    MO basis, i.e. the matrix <psi_p|psi'_i> (rows: original MOs, occupied
    first then virtual; columns: new occupied orbitals).

    t1 holds the singles amplitudes t_ia, so it is (nocc, nvir) by that index
    convention. With the complete QR t1 = Q R (Q square),
    [Q; R^T] = [I; t1^T] Q, which spans the same space as the unnormalized
    Thouless orbitals [I; t1^T]. A second QR then orthonormalizes the columns.
    The overall sign/rotation of the columns is whatever the QR
    factorizations produce; no sign convention is imposed.
    '''
    q_occ, r_occ_vir = jnp.linalg.qr(t1, mode='complete')
    stacked = jnp.concatenate([q_occ, r_occ_vir.T], axis=0)
    mo_t = jnp.linalg.qr(stacked)[0]
    return mo_t

In [5]:
from ad_afqmc import config
from ad_afqmc.prop_unrestricted import prop_unrestricted
import time
from jax import random

# AFQMC settings: UHF walkers with the 'uccsd_pt2_ad' trial.
options = dict(
    n_eql=4,
    n_prop_steps=50,
    n_ene_blocks=20,
    n_sr_blocks=10,
    n_blocks=10,
    n_walkers=3,
    seed=2,
    walker_type='uhf',
    trial='uccsd_pt2_ad',
    dt=0.005,
    free_projection=False,
    ad_mode=None,
    use_gpu=False,
)

# Prepare the AFQMC inputs from the CCSD object (Cholesky cutoff 1e-6), then
# build the Hamiltonian, propagator, trial and wave-function data from them.
# `options` is rebound to the options returned by _prep_afqmc.
prop_unrestricted.prep_afqmc(mycc, options, chol_cut=1e-6)
(ham_data, ham, prop, trial, wave_data,
 sampler, observable, options, _) = prop_unrestricted._prep_afqmc(options)


#
# Preparing AFQMC calculation
# If you import pyscf cc modules and use MPI for AFQMC in the same script, finalize MPI before calling the AFQMC driver.
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (2, 2)
# Number of basis functions: 4
# Number of Cholesky vectors: 9
#
# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
# Number of MPI ranks: 1
#
# norb: 4
# nelec: (2, 2)
#
# n_eql: 4
# n_prop_steps: 50
# n_ene_blocks: 20
# n_sr_blocks: 10
# n_blocks: 10
# n_walkers: 3
# seed: 2
# walker_type: uhf
# trial: uccsd_pt2_ad
# dt: 0.005
# free_projection: False
# use_gpu: False
# n_exp_terms: 6
# orbital_rotation: True
# do_sr: True
# symmetry: False
# save_walkers: False
# ene0: 0.0
# n_batch: 1
# LNO: False
# orbE: 0
# maxError: 0.001
#


In [10]:
config.setup_jax()
MPI = config.setup_comm()
init = time.time()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
seed = options["seed"]
neql = options["n_eql"]

# Trial 1-RDM; only stored in wave_data when it is not already present.
trial_rdm1 = trial.get_rdm1(wave_data)
if "rdm1" not in wave_data:
    wave_data["rdm1"] = trial_rdm1

ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(
    ham_data, prop, trial, wave_data
)
prop_data = prop.init_prop_data(trial, wave_data, ham_data, None)

# One PRNG stream per MPI rank.
prop_data["key"] = random.PRNGKey(seed + rank)

# Sanity checks: the initial walker energy against the mean-field energy, and
# the energy of the exp(T1)HF determinant (wave_data['mo_ta'/'mo_tb'])
# against PySCF's CCS (T2 = 0) energy.
print('init walker energy: ', prop_data['e_estimate'])
print('mf energy: ', mf.e_tot)
print('err', mf.e_tot - prop_data['e_estimate'])
walker_up = prop_data['walkers'][0][0]
walker_dn = prop_data['walkers'][1][0]
print('exp(T1)HF energy: ',
      trial._calc_energy(wave_data['mo_ta'], wave_data["mo_tb"], ham_data, wave_data))

eris = mycc.ao2mo(mycc.mo_coeff)
# CCS correlation energy, kept under its own name so it is not mistaken for
# (or overwrite) the CCSD energy used by later cells.
zero_t2 = tuple(0 * t2_blk for t2_blk in mycc.t2)
e_corr_ccs = mycc.energy(mycc.t1, zero_t2, eris)
print('ccs energy: ', mf.e_tot + e_corr_ccs)
print('ccsd energy: ', mycc.e_tot)

# Hostname: sharmagroup-rn
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64
init walker energy:  -2.0886923811239098
mf enegry:  -2.088692381947721
err -8.238112414460375e-10
-2.0888799980569117
ccs energy:  -2.0888799973483225
ccsd energy:  -2.1652917827607716


In [96]:
from jax import lax
@jax.jit
def _tls_olp(
    walker_up: jax.Array,
    walker_dn: jax.Array,
    wave_data: dict,
) -> complex:
    '''<exp(T1)HF|walker>'''

    olp = jnp.linalg.det(wave_data["mo_ta"].T.conj() @ walker_up
        ) * jnp.linalg.det(wave_data["mo_tb"].T.conj() @ walker_dn)

    return olp

@partial(jit, static_argnums=5)
def _tls_exp1(x: float, h1_mod: jax.Array, walker_up: jax.Array,
              walker_dn: jax.Array, wave_data: dict, trial):
    '''
    Unrestricted ratio <exp(T1)HF|exp(x*h1_mod)|walker> / trial._calc_overlap(walker).

    exp(x*h1_mod) is truncated at first order, 1 + x*h1_mod, which is exact
    for the first derivative at x = 0. h1_mod[0] acts on the up walker and
    h1_mod[1] on the dn walker. `trial` is a static jit argument.
    '''
    def first_order(op, w):
        return w + x * op.dot(w)

    rotated_olp = _tls_olp(first_order(h1_mod[0], walker_up),
                           first_order(h1_mod[1], walker_dn), wave_data)
    base_olp = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return rotated_olp / base_olp

@partial(jit, static_argnums=5)
def _tls_exp2(x: float, chol_i: jax.Array, walker_up: jax.Array,
                walker_dn: jax.Array, wave_data: dict, trial) -> complex:
    '''
    Ratio <exp(T1)HF|exp(x*L)|walker> / trial._calc_overlap(walker) for one
    Cholesky pair L = (chol_i[0], chol_i[1]) acting on the (up, dn) walker.

    exp(x*L) is truncated at second order, 1 + x*L + x^2/2 * L^2, which is
    exact for the second derivative at x = 0. `trial` is a static jit argument.
    '''
    def second_order(op, w):
        op_w = op.dot(w)
        return w + x * op_w + x**2 / 2.0 * op.dot(op_w)

    rotated_olp = _tls_olp(second_order(chol_i[0], walker_up),
                           second_order(chol_i[1], walker_dn), wave_data)
    base_olp = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return rotated_olp / base_olp

@partial(jit, static_argnums=3)
def _ut2_walker_olp(
    walker_up: jax.Array, walker_dn: jax.Array, wave_data: dict, trial
) -> complex:
    '''Doubles-weighted overlap <exp(T1)HF|T2|walker> for a UHF walker.

    Only the T2 term is evaluated (there is no singles contribution):
        (1/2 t2AA_iajb Ga_ia Ga_jb + 1/2 t2BB_iajb Gb_ia Gb_jb
         + t2AB_iajb Ga_ia Gb_jb) * <exp(T1)HF|walker>
    G is the occupied-virtual block of the walker Green's function built
    against the rotated orbitals wave_data['mo_ta'] / ['mo_tb']. The amplitudes
    are wave_data["rot_t2AA"], ["rot_t2BB"] and ["rot_t2AB"]; the occupation
    numbers come from trial.nelec.
    '''
    def ov_green(mo, walker, nocc):
        # (walker @ (mo^H walker)^-1)^T, restricted to rows [:nocc] and columns [nocc:]
        full = (walker.dot(jnp.linalg.inv(mo.T.conj() @ walker))).T
        return full[:nocc, nocc:]

    n_occ_a, n_occ_b = trial.nelec[0], trial.nelec[1]
    g_a = ov_green(wave_data['mo_ta'], walker_up, n_occ_a)
    g_b = ov_green(wave_data['mo_tb'], walker_dn, n_occ_b)
    ref_olp = _tls_olp(walker_up, walker_dn, wave_data)
    doubles = (0.5 * jnp.einsum("iajb, ia, jb", wave_data["rot_t2AA"], g_a, g_a)
               + 0.5 * jnp.einsum("iajb, ia, jb", wave_data["rot_t2BB"], g_b, g_b)
               + jnp.einsum("iajb, ia, jb", wave_data["rot_t2AB"], g_a, g_b))
    return doubles * ref_olp

@partial(jit, static_argnums=5)
def _ut2_exp1(x: float, h1_mod: jax.Array, walker_up: jax.Array,
                walker_dn: jax.Array, wave_data: dict, trial):
    '''
    Unrestricted ratio <exp(T1)HF|T2 exp(x*h1_mod)|walker> / trial._calc_overlap(walker).

    The one-body propagator is truncated at first order, 1 + x*h1_mod (exact
    for the first derivative at x = 0). h1_mod[0] / h1_mod[1] act on the
    up / dn walker. `trial` is a static jit argument.
    '''
    def first_order(op, w):
        return w + x * op.dot(w)

    t2_olp = _ut2_walker_olp(first_order(h1_mod[0], walker_up),
                             first_order(h1_mod[1], walker_dn), wave_data, trial)
    base_olp = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return t2_olp / base_olp

@partial(jit, static_argnums=5)
def _ut2_exp2(x: float, chol_i: jax.Array, walker_up: jax.Array,
                walker_dn: jax.Array, wave_data: dict, trial) -> complex:
    '''
    Ratio <exp(T1)HF|T2 exp(x*L)|walker> / trial._calc_overlap(walker) for one
    Cholesky pair L = (chol_i[0], chol_i[1]) acting on the (up, dn) walker.

    exp(x*L) is truncated at second order, which is exact for the second
    derivative at x = 0. `trial` is a static jit argument.
    '''
    def second_order(op, w):
        op_w = op.dot(w)
        return w + x * op_w + x**2 / 2.0 * op.dot(op_w)

    t2_olp = _ut2_walker_olp(second_order(chol_i[0], walker_up),
                             second_order(chol_i[1], walker_dn), wave_data, trial)
    base_olp = trial._calc_overlap(walker_up, walker_dn, wave_data)
    return t2_olp / base_olp

@partial(jit, static_argnums=4)
def d2_tls_exp2_i(chol_i,walker_up,walker_dn,wave_data,trial):
    '''d^2/dx^2 of _tls_exp2(x, chol_i, ...) at x = 0 for one Cholesky pair.

    Forward-over-forward AD: the inner jvp yields f'(a), and the outer jvp
    differentiates that once more at a = 0.
    '''
    def f(a):
        return _tls_exp2(a, chol_i, walker_up, walker_dn, wave_data, trial)

    def f_prime(a):
        return jax.jvp(f, [a], [1.0])[1]

    return jax.jvp(f_prime, [0.0], [1.0])[1]

@partial(jit, static_argnums=4)
def d2_ut2_exp2_i(chol_i,walker_up,walker_dn,wave_data,trial):
    '''d^2/dx^2 of _ut2_exp2(x, chol_i, ...) at x = 0 for one Cholesky pair.

    Forward-over-forward AD: the inner jvp yields f'(a), and the outer jvp
    differentiates that once more at a = 0.
    '''
    def f(a):
        return _ut2_exp2(a, chol_i, walker_up, walker_dn, wave_data, trial)

    def f_prime(a):
        return jax.jvp(f, [a], [1.0])[1]

    return jax.jvp(f_prime, [0.0], [1.0])[1]

@partial(jit, static_argnums=4)
def d2_tls_exp2(walker_up,walker_dn,ham_data,wave_data,trial):
    '''Two-body term for the exp(T1)HF bra: 1/2 * sum_g d2_tls_exp2_i(chol_g).

    ham_data["chol"] is viewed as (2, nchol, norb, norb) and regrouped to
    (nchol, 2, norb, norb), so each mapped element is an (up, dn) pair.
    '''
    n = trial.norb
    chol_pairs = jnp.transpose(ham_data["chol"].reshape(2, -1, n, n), (1, 0, 2, 3))
    per_vector = jax.vmap(d2_tls_exp2_i, in_axes=(0, None, None, None, None))(
        chol_pairs, walker_up, walker_dn, wave_data, trial)
    return jnp.sum(per_vector) / 2

@partial(jit, static_argnums=4)
def d2_ut2_exp2(walker_up,walker_dn,ham_data,wave_data,trial):
    '''Two-body term with T2 in the bra: 1/2 * sum_g d2_ut2_exp2_i(chol_g).

    ham_data["chol"] is viewed as (2, nchol, norb, norb) and regrouped to
    (nchol, 2, norb, norb), so each mapped element is an (up, dn) pair.
    '''
    n = trial.norb
    chol_pairs = jnp.transpose(ham_data["chol"].reshape(2, -1, n, n), (1, 0, 2, 3))
    per_vector = jax.vmap(d2_ut2_exp2_i, in_axes=(0, None, None, None, None))(
        chol_pairs, walker_up, walker_dn, wave_data, trial)
    return jnp.sum(per_vector) / 2

@partial(jit, static_argnums=0)
def _calc_energy_pt(trial, walker_up, walker_dn, ham_data, wave_data):
    '''
    Ingredients of the exp(T1)-dressed, T2-perturbative mixed energy for one
    UHF walker.

    Every quantity is divided by o0 = trial._calc_overlap(walker_up, walker_dn, wave_data):
      t1 = <exp(T1)HF|walker> / o0
      t2 = <exp(T1)HF|T2|walker> / o0
      e0 = <exp(T1)HF|h1 + h2|walker> / o0
      e1 = <exp(T1)HF|T2 (h1 + h2)|walker> / o0
    The one-body part is the first derivative along ham_data['h1_mod']. The
    two-body part is 1/2 * sum of second derivatives along each Cholesky pair
    (d2_tls_exp2 / d2_ut2_exp2). The constant h0 is NOT included.

    Returns the real parts (t1, t2, e0, e1).
    '''
    h1_mod = ham_data['h1_mod']

    # Zeroth order: bra exp(T1)HF.
    f_ref = lambda a: _tls_exp1(a, h1_mod, walker_up, walker_dn, wave_data, trial)
    t1, d_exp1_0 = jvp(f_ref, [0.0], [1.0])
    d2_exp2_0 = d2_tls_exp2(walker_up, walker_dn, ham_data, wave_data, trial)
    e0 = d_exp1_0 + d2_exp2_0

    # First order: bra exp(T1)HF with T2 applied.
    f_t2 = lambda a: _ut2_exp1(a, h1_mod, walker_up, walker_dn, wave_data, trial)
    t2, d_exp1_1 = jvp(f_t2, [0.0], [1.0])
    d2_exp2_1 = d2_ut2_exp2(walker_up, walker_dn, ham_data, wave_data, trial)
    e1 = d_exp1_1 + d2_exp2_1

    return jnp.real(t1), jnp.real(t2), jnp.real(e0), jnp.real(e1)

In [33]:
# Total (HF + correlation) energies with the current mycc.t1: with T2 zeroed
# (CCS) and with the full T2 (CCSD). These are the references for the PT check.
eris = mycc.ao2mo(mycc.mo_coeff)
eccs = mf.e_tot + mycc.energy(mycc.t1, tuple(0 * t2_blk for t2_blk in mycc.t2), eris)
print('ccs energy: ', eccs)
eccsd = mf.e_tot + mycc.energy(mycc.t1, mycc.t2, eris)
print('ccsd energy: ', eccsd)

ccs energy:  -2.0888799973483225
ccsd energy:  -2.165477505313518


In [107]:
h0 = ham_data['h0']
walker_up, walker_dn = prop_data['walkers'][0][0], prop_data['walkers'][1][0]
import time

# JAX dispatches asynchronously: block on the outputs so the timing includes
# the actual computation and not only its dispatch.
start = time.perf_counter()
t1, t2, e0, e1 = jax.block_until_ready(
    _calc_energy_pt(trial, walker_up, walker_dn, ham_data, wave_data)
)
end = time.perf_counter()

# Zeroth order: mixed energy with the exp(T1)HF bra (compare with CCS `eccs`).
e_pt0 = h0 + 1/t1 * e0
# First order in T2: (e0 + e1)/(t1 + t2) expanded to linear order in the T2
# terms (compare with CCSD `eccsd`).
e_pt1 = e_pt0 + 1/t1 * e1 - 1/t1**2 * t2 * e0

print(t1,t2,e0,e1)
print(e_pt0)
print(e_pt0 - eccs)
print(e_pt1)
print(e_pt1 - eccsd)
print(end-start)

0.9942401330066222 0.0 -4.2310352819506605 -0.07615631736724715
-2.0888799980569144
-7.085918518612289e-10
-2.165477506887047
-1.573529306853061e-09
0.0009604920014680829


In [66]:
@partial(jit, static_argnums=4)
def d2_exp2_olp(walker_up,walker_dn,ham_data,wave_data,trial):
    '''Finite-difference reference for the two-body term 1/2 * sum_g f_g''(0).

    f_g(x) = _tls_exp2(x, chol_g, ...) is evaluated at x = +eps, 0 and -eps for
    every Cholesky pair chol_g. The second derivative is then the central
    difference (f(+eps) - 2 f(0) + f(-eps)) / eps^2, with eps = 1e-4.
    '''
    eps = 1e-4
    n = trial.norb
    chol_pairs = ham_data["chol"].reshape(2, -1, n, n).transpose(1, 0, 2, 3)

    def scan_overlaps(step):
        # The carry is passed through unchanged; only the per-vector outputs matter.
        def body(state, chol_g):
            x_val, w_up, w_dn, wd = state
            return state, _tls_exp2(x_val, chol_g, w_up, w_dn, wd, trial)
        _, values = lax.scan(body, (step, walker_up, walker_dn, wave_data), chol_pairs)
        return values

    f_plus = scan_overlaps(eps)
    f_zero = scan_overlaps(0.0)
    f_minus = scan_overlaps(-1.0 * eps)
    second_derivs = (f_plus - 2.0 * f_zero + f_minus) / eps / eps
    return jnp.sum(second_derivs) / 2

In [67]:
# d2_exp2_i / d2_exp2_new were verbatim copies of d2_tls_exp2_i / d2_tls_exp2
# from the cell that defines the _tls_* helpers. They are aliased here instead
# of redefined, so the benchmark below times exactly the code that
# _calc_energy_pt uses, and the two copies can never silently diverge.
#   d2_exp2_i(chol_i, walker_up, walker_dn, wave_data, trial)
#       -> d^2/dx^2 _tls_exp2 at x = 0 for one Cholesky pair
#   d2_exp2_new(walker_up, walker_dn, ham_data, wave_data, trial)
#       -> 1/2 * sum over all Cholesky pairs of d2_exp2_i
d2_exp2_i = d2_tls_exp2_i
d2_exp2_new = d2_tls_exp2

In [82]:
import time

# Finite-difference (old) vs. nested-jvp (new) two-body term.
# JAX dispatches asynchronously, so each result is blocked on; otherwise the
# timings would measure only the dispatch. The first call of a jitted
# function also includes tracing/compilation time.
start = time.perf_counter()
result_old = jax.block_until_ready(
    d2_exp2_olp(walker_up,walker_dn,ham_data,wave_data,trial))
end = time.perf_counter()
print("Old version:", end - start, "s")
print(result_old)

start = time.perf_counter()
result_new = jax.block_until_ready(
    d2_exp2_new(walker_up,walker_dn,ham_data,wave_data,trial))
end = time.perf_counter()
print("New version:", end - start, "s")
print(result_new)

Old version: 0.010482298999704653 s
(3.9853718214466483+0j)
New version: 0.00014304000069387257 s
(3.9853717449989774+0j)


In [38]:
# Per-Cholesky-pair overlap ratios _tls_exp2(0, chol_g, ...). exp2_0 is only a
# local inside d2_exp2_olp, so it is recomputed here rather than read from
# leftover kernel state (which fails on a fresh kernel).
norb = trial.norb
chol = ham_data["chol"].reshape(2, -1, norb, norb).transpose(1, 0, 2, 3)
exp2_0 = jax.vmap(_tls_exp2, in_axes=(None, 0, None, None, None, None))(
    0.0, chol, walker_up, walker_dn, wave_data, trial)
exp2_0

Array([0.99424013+0.j, 0.99424013+0.j, 0.99424013+0.j, 0.99424013+0.j,
       0.99424013+0.j, 0.99424013+0.j, 0.99424013+0.j, 0.99424013+0.j,
       0.99424013+0.j], dtype=complex128)

In [37]:
# Overlap ratios _tls_exp2(0, chol_g, ...) for all Cholesky pairs at once.
# chol is rebuilt from ham_data as (nchol, 2, norb, norb), the same layout
# the d2_* helpers use, rather than taken from whatever an earlier cell left
# in the kernel.
norb = trial.norb
chol = ham_data["chol"].reshape(2, -1, norb, norb).transpose(1, 0, 2, 3)
tls_exp2_batch = jax.vmap(_tls_exp2, in_axes=(None,0,None,None,None,None))
print(tls_exp2_batch(0,chol,walker_up,walker_dn,wave_data,trial))

[0.99424013+0.j 0.99424013+0.j 0.99424013+0.j 0.99424013+0.j
 0.99424013+0.j 0.99424013+0.j 0.99424013+0.j 0.99424013+0.j
 0.99424013+0.j]


In [60]:
# Second derivative at x = 0 of _tls_exp2 for the first Cholesky pair only,
# via nested forward-mode jvp. chol is rebuilt from ham_data as
# (nchol, 2, norb, norb) instead of relying on kernel state.
norb = trial.norb
chol = ham_data["chol"].reshape(2, -1, norb, norb).transpose(1, 0, 2, 3)
chol_0 = chol[0]
x = 0.0
f = lambda a: _tls_exp2(a,chol_0,walker_up,walker_dn,wave_data,trial)
_, d2f = jax.jvp(lambda x: jax.jvp(f, [x], [1.0])[1], [x], [1.0])
print(d2f)

(3.5938839146231603+0j)


In [None]:
@partial(jit, static_argnums=4)
def d2_exp2_from_chol(chol,walker_up,walker_dn,wave_data,trial):
    '''Two-body term 1/2 * sum_g d^2/dx^2 _tls_exp2(x, chol[g], ...) at x = 0.

    Variant that takes pre-arranged Cholesky pairs `chol` (mapped over the
    leading axis) instead of ham_data. It previously reused the name
    d2_exp2_new, whose other definition has the signature
    (walker_up, walker_dn, ham_data, wave_data, trial). Running this cell
    would then have silently broken every call of that version, so it gets
    its own name. The per-pair derivative is d2_tls_exp2_i, which is
    identical to the d2_exp2_i copy this cell used to redefine.
    '''
    d2_exp2_batch = jax.vmap(d2_tls_exp2_i, in_axes=(0,None,None,None,None))
    d2_exp2s = d2_exp2_batch(chol,walker_up,walker_dn,wave_data,trial)
    return jnp.sum(d2_exp2s)/2

In [None]:
import time

# f_old / f_new are not defined at top level in any cell here (NameError on a
# fresh kernel). Assumed intent -- TODO confirm: the same old-vs-new
# comparison as the benchmark cell above, i.e. finite-difference d2_exp2_olp
# vs. AD-based d2_exp2_new. Results are blocked on because JAX dispatches
# asynchronously.
bench_args = (walker_up, walker_dn, ham_data, wave_data, trial)
f_old = partial(d2_exp2_olp, *bench_args)
f_new = partial(d2_exp2_new, *bench_args)

start = time.perf_counter()
result_old = jax.block_until_ready(f_old())
end = time.perf_counter()
print("Old version:", end - start, "s")

start = time.perf_counter()
result_new = jax.block_until_ready(f_new())
end = time.perf_counter()
print("New version:", end - start, "s")

In [None]:
def vector_tls_exp2(x,chol,walker_up,walker_dn,wave_data,trial):
    '''Evaluate _tls_exp2 at a single x for every Cholesky pair in `chol`.

    `chol` is mapped over its leading axis; all other arguments are shared.
    Returns one overlap ratio per Cholesky pair.
    '''
    mapped = jax.vmap(_tls_exp2, in_axes=(None, 0, None, None, None, None))
    return mapped(x, chol, walker_up, walker_dn, wave_data, trial)

In [None]:
# First derivative at x = 0 of _tls_exp2 for the first Cholesky pair.
# chol is rebuilt from ham_data as (nchol, 2, norb, norb). The previous
# `chol = chol.transpose(1,0,2,3)` transposed whatever chol was already in the
# kernel, so its layout flipped on every re-run. chol_i was also not defined
# at top level.
x = 0.0
norb = trial.norb
chol = ham_data["chol"].reshape(2, -1, norb, norb).transpose(1, 0, 2, 3)
chol_i = chol[0]

f1 = lambda a:  _tls_exp2(a,chol_i,walker_up,walker_dn,wave_data,trial)
t1, d_exp1 = jvp(f1, [x], [1.0])