In [None]:
import jax
from jax import config
# Enable float64 arrays in JAX (JAX defaults to float32). This has to run
# before any JAX computation; the R-factor values below come out as float64.
config.update("jax_enable_x64", True)

from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np

from viperleed_jax.tensor_calculator import benchmark_calculator
from viperleed_jax.from_state import calculator_from_state
import importlib

# Show which devices JAX will run on.
jax.devices()

In [None]:
import numpy as np
import jax
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from scipy import interpolate

In [None]:
from viperleed.calc.files.displacements import readDISPLACEMENTS
from viperleed.calc import LOGGER as logger
from viperleed.calc.files.phaseshifts import readPHASESHIFTS
from viperleed.calc.files.iorfactor import beamlist_to_array

In [None]:
# NOTE(review): absolute, machine-specific path to the tensor-LEED sources;
# adjust it for your own checkout before running the notebook.
%env VIPERLEED_TENSORLEED=/Users/alexander/GitHub/viperleed-tensorleed/tensorleed/

In [None]:
input_base_path = Path('../tests/test_data/Cu_111')
large_file_path = Path('../large_files/')

# Which reference calculation to load: 'dynamic_l_max' or 'fixed_l_max'.
L_MAX_VARIANT = 'dynamic_l_max'

origin_path = input_base_path / L_MAX_VARIANT
tensor_path = (
    large_file_path / 'Cu_111' / L_MAX_VARIANT / 'Tensors' / 'Tensors_001.zip'
)

# Build the calculator and related objects from the stored state.
# The scatterers are unpacked as `base_scatterers_from_state` so they do not
# rebind the name of the `viperleed_jax.base_scatterers` module imported
# further down; otherwise re-running this cell after that import breaks
# `base_scatterers.BaseScatterers(...)` and `importlib.reload(base_scatterers)`.
(
    calculator,
    slab,
    rpars,
    ref_data,
    phaseshifts,
    base_scatterers_from_state,
    disp_file_from_state,
    ps_from_state,
) = calculator_from_state(
    origin_path,
    tensor_path,
    l_max=10,
    batch=100,
    recalculate_ref_t_matrices=False,
)

In [None]:
# R-factor for the trial parameter vector [0.5, 0.5, 0.5]; the trailing
# comment records the value obtained in an earlier run.
calculator.R([0.5, 0.5, 0.5])  # Array(0.00058721, dtype=float64)

In [None]:
# Inspect the geometry tree of the state-derived parameter space.
ps_from_state.geo_tree

In [None]:
# Shape of the calculator's reference t-matrices.
calculator.ref_t_matrices.shape

In [None]:
# Inspect `ref_vib_amps` of the vibration tree.
ps_from_state.vib_tree.ref_vib_amps

In [None]:
# Same evaluation as the first R-factor cell (duplicate; useful for
# re-checking after edits).
calculator.R([0.5, 0.5, 0.5])

In [None]:
# Inspect the vibration tree of the state-derived parameter space.
ps_from_state.vib_tree

In [None]:
# Inspect `n_static_t_matrices` of the vibration tree.
ps_from_state.vib_tree.n_static_t_matrices

In [None]:
from viperleed_jax import base_scatterers

In [None]:
from viperleed_jax.files.displacements import regex
from viperleed_jax.files.displacements import lines
from viperleed_jax.files.displacements import file
from viperleed_jax.files.displacements import reader
from viperleed_jax.parameters import linear_transformer
from viperleed_jax.parameters import vib_parameters
from viperleed_jax.parameters import geo_parameters
from viperleed_jax.parameters import meta_parameters
from viperleed_jax.parameters import occ_parameters
from viperleed_jax.parameters import hierarchical_linear_tree
from viperleed_jax import parameter_space

In [None]:
# Re-import the project modules after editing their source. The sequence below
# is significant (presumably dependencies before dependents); keep it that way
# when adding modules.
for _module in (
    base_scatterers,
    regex,
    lines,
    reader,
    file,
    linear_transformer,
    hierarchical_linear_tree,
    vib_parameters,
    geo_parameters,
    meta_parameters,
    occ_parameters,
    parameter_space,
):
    importlib.reload(_module)

In [None]:
# Test DISPLACEMENTS file, resolved relative to the notebook like the other
# test-data paths instead of through a machine-specific absolute path.
# NOTE(review): assumes the notebook sits one directory below the repository
# root, as `input_base_path` already does.
path = Path('../tests/test_data/displacements/Cu_111/DISPLACEMENTS_constrain_z')
f = file.DisplacementsFile()
f.read(path)

# First block of the file, plus whatever `offsets_block()` returns.
search_block = f.blocks[0]
offsets_block = f.offsets_block()

In [None]:
# Fresh parameter space over the slab's base scatterers; apply the offsets and
# the search block read above, then print its summary.
scatterers_for_ps = base_scatterers.BaseScatterers(slab)
ps = parameter_space.ParameterSpace(scatterers_for_ps, rpars)
ps.apply_displacements(offsets_block, search_block)
print(ps.info)

In [None]:
# Print the bounds of every occupation-tree leaf.
# NOTE(review): this reads the private `_bounds`, while geo/vib leaves are
# inspected via the public `.bounds` elsewhere; check whether occ leaves
# expose `.bounds` too.
for leaf in ps.occ_tree.leaves:
    print(leaf._bounds)

In [None]:
# Inspect the biases of the occupation weight transformer.
ps.occ_weight_transformer.biases

In [None]:
ps.vib_tree

In [None]:
# R-factor with the first parameter at 0.0 instead of 0.5.
calculator.R([0.0, 0.5, 0.5])

In [None]:
# Intensities for the trial vector [0.5, 0.5, 0.5]; compared with the
# reference intensities in the plot below.
res = calculator.intensity([0.5, 0.5, 0.5])

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
# Compare one beam's calculated intensities with the reference intensities.
# NOTE(review): axis 0 of both arrays is presumably the energy grid.
beam = 6
fig, ax = plt.subplots()
ax.plot(res[:, beam], label="calculator.intensity")
ax.plot(calculator.reference_intensity[:, beam], label="reference")
ax.set(
    title=f"Beam {beam}",
    xlabel="energy index",
    ylabel="intensity",
)
ax.legend()
plt.show()

In [None]:
# For every child of the roots up to the implicit-constraint layer, print its
# collapsed transformer, the mask returned by `collapse_bounds()`, and the
# transformer restricted to the rows selected by that mask.
# NOTE(review): `DisplacementTreeLayers` is not imported anywhere in this
# notebook; import it before running this cell.
for node in ps.geo_tree.roots_up_to_layer(DisplacementTreeLayers.DisplacementTreeLayers.Implicit_Constraints):
    for child in node.children:
        collapsed_transformer = child.collapse_transformer()
        print(collapsed_transformer.weights)
        user_mask, lower, upper = child.collapse_bounds()
        print(user_mask)
        print(user_mask.sum())
        transformer = collapsed_transformer.select_rows(user_mask)
        lower, upper = lower[user_mask], upper[user_mask]
        print(transformer.weights)
        print("---")


In [None]:
# Call the object returned by `.boolify()` on the geometry subtree root's
# collapsed transformer with a two-entry boolean input.
ps.geo_tree.subtree_root.collapse_transformer().boolify()([True, True])

In [None]:
# Sanity check: NumPy reports rank 0 for a 1x1 zero matrix.
a = np.array([[0.0]])
np.linalg.matrix_rank(a)

In [None]:
# Show the geometry, vibration and occupation trees together (as a tuple).
ps.geo_tree, ps.vib_tree, ps.occ_tree

In [None]:
from anytree.exporter import UniqueDotExporter

# Export the geometry and vibration subtrees as PDF graphs
# (anytree renders these through Graphviz's `dot`, which must be installed).
for tree_name, tree in (("geo", ps.geo_tree), ("vib", ps.vib_tree)):
    UniqueDotExporter(tree.subtree_root).to_picture(f"{tree_name}_subtree.pdf")

In [None]:
# `sblock` is never defined in this notebook; the first block read from the
# DISPLACEMENTS file is `search_block` (presumably the intended object).
list(search_block.sections.values())[0][0].direction._fractional

In [None]:
# Bounds of the first geometry leaf.
ps.geo_tree.leaves[0].bounds

In [None]:
print(ps.geo_tree)

In [None]:
# Hand-built constraint line for experimenting without a DISPLACEMENTS file.
# NOTE(review): the positional arguments presumably mirror the parsed pieces
# of the raw text; confirm against `lines.ConstraintLine`.
constrain_line = lines.ConstraintLine("geo Cu_def = linked", "geo", "Cu_def", None, "linked")

In [None]:
# Hand-built vibration-range line for Cu_surf, -0.5 to 0.5 (argument meaning
# presumed from the raw text; confirm against `lines.VibDeltaLine`).
vib_range_line = lines.VibDeltaLine("Cu_surf = -0.5, 0.5", "Cu_surf", None, -0.5, 0.5, None)

In [None]:
vib_range_line.targets

In [None]:
# Standalone vibration/geometry/occupation trees built from the slab's base
# scatterers, independent of `ps`; the following cells apply bounds and
# constraints to them.
bs = base_scatterers.BaseScatterers(slab)
vib_subtree = vib_parameters.VibHLSubtree(bs)
geo_subtree = geo_parameters.GeoTree(bs)
occ_subtree = occ_parameters.OccTree(bs)

In [None]:
# Test matrix with two identical rows, so its rank is 2.
A = np.array([[1, 0, 0],
              [1, 0, 0],
              [0, 0, 1]])

In [None]:
# Alternative 2x3 test matrix (also rank 2); uncomment to use it instead.
#A = np.array([[1, 0, 0], [0, 0.5, -0.5]])

In [None]:
# Expected: 2 for either test matrix above.
np.linalg.matrix_rank(A)

In [None]:
# Diagonal entries of R that are (numerically) zero mark columns of A that are
# linear combinations of the preceding ones. Compare with a tolerance: after a
# floating-point QR these entries are usually tiny rather than exactly 0.
np.isclose(np.diag(np.linalg.qr(A).R), 0.0)

In [None]:
# Check the bounds of every vibration root.
for root in vib_subtree.roots:
    root.check_bounds_valid()

In [None]:
# Apply the hand-built Cu_surf range to the vibration tree.
vib_subtree.apply_bounds(vib_range_line)

In [None]:
# Print the bounds of every vibration leaf.
for leaf in vib_subtree.leaves:
    print(leaf.bounds)

In [None]:
vib_subtree.apply_implicit_constraints()

In [None]:
# Apply the hand-built "geo Cu_def = linked" constraint line.
geo_subtree.apply_explicit_constraint(constrain_line)

In [None]:
print(vib_subtree)

In [None]:
# geo_subtree.create_subtree_root()
# Print all three trees (the geometry subtree root is created in the next cell).
print(geo_subtree)
print(vib_subtree)
print(occ_subtree)

In [None]:
geo_subtree.create_subtree_root()

In [None]:
# `tlt`: collapsed transformer of the geometry subtree root.
tlt = geo_subtree.subtree_root.collapse_transformer()

In [None]:
# Weights of the subtree root's first child.
geo_subtree.subtree_root.children[0].transformer.weights

In [None]:
# Weights of the subtree root's second child.
geo_subtree.subtree_root.children[1].transformer.weights

In [None]:
# Matrix product with a boolean vector; NumPy promotes the bools (False -> 0).
tlt.weights @ np.array([False, False])

In [None]:
np.linalg.matrix_rank(tlt.weights)

In [None]:
tlt.biases

In [None]:
# Q factor of the QR decomposition of the collapsed weights.
np.linalg.qr(tlt.weights).Q

In [None]:
# R factor; (near-)zero diagonal entries indicate linearly dependent columns.
np.linalg.qr(tlt.weights).R

In [None]:
np.linalg.eig(a_test.T)

In [None]:
a_test = geo_subtree.roots[1].collapse_transformer().weights

In [None]:
np.linalg.matrix_rank(a_test)

In [None]:
# Stacked transformer of the first geometry leaf's parent.
geo_subtree.leaves[0].parent.stacked_transformer()

In [None]:
geo_subtree.roots

In [None]:
# NOTE(review): `DisplacementRange` is not imported anywhere in this notebook;
# import it first (presumably from `viperleed_jax.parameters` — TODO confirm).
DisplacementRange.DisplacementRange(-0.5, 0.5)

In [None]:
l = linear_transformer.LinearTransformer(np.eye(2), np.zeros(2))

In [None]:
l.weights

In [None]:
print(vib_subtree)

In [None]:

# The BaseScatterers object built from the slab above.
bs

In [None]:
bs.scatterers

In [None]:
import importlib

In [None]:
from viperleed_jax.files.displacements import file
from viperleed_jax.files.displacements import reader
from viperleed_jax.files.displacements import regex
from viperleed_jax.files.displacements import lines

# Reload in the same order as the main reload cell: `reader` before `file`,
# presumably so that `file` picks up the freshly reloaded reader.
importlib.reload(regex)
importlib.reload(lines)
importlib.reload(reader)
importlib.reload(file)

In [None]:
# Same test file, resolved relative to the notebook (like `input_base_path`)
# instead of through a machine-specific absolute path.
path = Path('../tests/test_data/displacements/Cu_111/DISPLACEMENTS_constrain_z')
f = file.DisplacementsFile()
f.read(path)

In [None]:
# Inspect the parsed DISPLACEMENTS file.
f

In [None]:
# First block of the file.
b = f.blocks[0]

In [None]:
# First two entries of the last section of `b` (sections presumably keep file
# order — TODO confirm).
line = b.sections[list(b.sections)[-1]][0]
line1 = b.sections[list(b.sections)[-1]][1]

In [None]:
# Key of that last section.
list(b.sections)[-1]

In [None]:
from itertools import compress

In [None]:
list(compress(vib_subtree.leaves, const.targets.select(bs)))


In [None]:
const = list(b.sections.values())[-1][0]


In [None]:
bs[const.targets.select(bs)]

In [None]:
# NOTE(review): `DisplacementFileSections` is not imported in this notebook;
# import it before running this cell.
b.sections[DisplacementFileSections(5)]

In [None]:
# First column of `slab.ucell`.
slab.ucell[:,0]

In [None]:
# `slab.ab_cell` applied to the vector [1, 0].
slab.ab_cell @ np.array([1, 0])

In [None]:
slab.ab_cell

In [None]:
# NOTE(review): a plain NumPy array has no `.volume` attribute; check the type
# of `ab_cell` if this raises.
slab.ab_cell.volume

In [None]:
parameter_space = ParameterSpace(slab)

## GEOMETRY
# Fix layer 0
parameter_space.geo_params.fix_layer(1, z_offset=0.)
parameter_space.geo_params.fix_layer(2, z_offset=0.)
parameter_space.geo_params.fix_layer(3, z_offset=0.)
parameter_space.geo_params.fix_layer(4, z_offset=0.)

# symmetry constrained xyz movements ± 0.15 A for layer 2
for param in [p for p in parameter_space.geo_params.terminal_params if p.bound is None]:
    param.set_bound(GeoParamBound(-0.05, + 0.05))

## VIBRATIONS
# fix *_def sites (O_def, Fe_def)
for param in [p for p in parameter_space.vib_params.terminal_params if p.site_element.site.endswith('_def')]:
    parameter_space.vib_params.fix_site_element(param.site_element, None) # None fixes to the default value


# the rest can vary ± 0.05 A
for param in [p for p in parameter_space.vib_params.terminal_params if p.site_element.site.endswith('_surf')]:
    param.set_bound(VibParamBound(-0.05, + 0.05))

## CHEMISTRY
# no free parameters
parameter_space.occ_params.remove_remaining_vacancies()

# V0R
# set ± 2 eV
parameter_space.v0r_param.set_bound(V0rParamBound(-2., +2.))

calculator.set_parameter_space(parameter_space)

In [None]:
slab.ab_cell

In [None]:
calculator.delta_amplitude(np.asarray((0.5, 0.5, 0.5, 0.5, 0.5)))

In [None]:
calculator.jit_delta_amplitude(np.asarray((0.5, 0.5, 0.5, 0.5, 0.5)))

In [None]:
calculator.R((0.5, 0.5, 0.5, 0.5, 0.5))

In [None]:
calculator.jit_R((0.5, 0.5, 0.5, 0.5, 0.5))

In [None]:
parameter_space

# Parameter trees

In [None]:
import importlib
import viperleed_jax
from viperleed_jax import parameter_space
from viperleed_jax.parameters import vib_parameters, hierarchical_linear_tree, geo_parameters, occ_parameters
from viperleed_jax.parameters import linear_transformer


In [None]:
# Reload dependencies before the modules that build on them, matching the main
# reload cell. `parameter_space` goes last so that (presumably) it picks up the
# freshly reloaded vib/geo/occ parameter modules.
importlib.reload(linear_transformer)
importlib.reload(hierarchical_linear_tree)

importlib.reload(viperleed_jax)
importlib.reload(vib_parameters)
importlib.reload(geo_parameters)
importlib.reload(occ_parameters)
importlib.reload(parameter_space)

In [None]:
from anytree import RenderTree

In [None]:
from viperleed_jax.parameter_space import get_base_scatterers, get_site_elements

# `ase`: result of get_base_scatterers(slab); `se`: result of get_site_elements(slab).
ase, se = get_base_scatterers(slab), get_site_elements(slab)

In [None]:
# NOTE(review): these constructors take (slab, ase, se), while an earlier cell
# builds the same trees from a single `BaseScatterers` object; presumably only
# one of the two signatures matches the current code. These cells also rebind
# `vib_subtree`, `geo_subtree` and `occ_subtree`.
vib_subtree = vib_parameters.VibHLSubtree(slab, ase, se)
vib_subtree.create_subtree_root()
print(vib_subtree)

In [None]:
geo_subtree = geo_parameters.GeoTree(slab, ase, se)
geo_subtree.create_subtree_root()
print(geo_subtree)

In [None]:
occ_subtree = occ_parameters.OccTree(slab, ase, se)
occ_subtree.create_subtree_root()
print(occ_subtree)

# Parameter Space and DISPLACEMENTS

In [None]:
from importlib import reload
import viperleed_jax.files.displacements
import viperleed_jax.files.displacements.file

reload(viperleed_jax.files.displacements)
reload(viperleed_jax.files.displacements.file)
# Bind the class only after the reloads: a name taken with `from ... import`
# before reloading keeps pointing at the old class object.
from viperleed_jax.files.displacements.file import DisplacementsFile

In [None]:
from viperleed_jax.files.displacements.file import DisplacementsFile

# Test file DISPLACEMENTS_mixed, resolved relative to the notebook (like
# `input_base_path`) instead of through a machine-specific absolute path.
path = Path('../tests/test_data/displacements/DISPLACEMENTS_mixed')

df = DisplacementsFile()
df.read(path)

In [None]:
# First block of the DISPLACEMENTS file read above.
block = df.blocks[0]

In [None]:
print(block)