In [None]:
import jax
from jax import config
# Enable float64 in JAX before any arrays are created (the recorded
# R-factor output further down has dtype=float64).
config.update("jax_enable_x64", True)

from pathlib import Path

# NOTE(review): plt, TensorLEEDCalculator, benchmark_calculator and
# Phaseshifts are not used anywhere in this notebook — consider removing.
from matplotlib import pyplot as plt
import numpy as np


from viperleed_jax.tensor_calculator import TensorLEEDCalculator, benchmark_calculator
from viperleed_jax.files.phaseshifts import Phaseshifts

# Show which devices JAX will run on.
jax.devices()

In [None]:
# Render matplotlib figures inline (already the default in modern Jupyter).
%matplotlib inline

In [None]:
from viperleed_jax.from_state import calculator_from_state
from viperleed_jax.parameter_space import ParameterSpace

In [None]:
from viperleed.calc.files.displacements import readDISPLACEMENTS
from viperleed.calc import LOGGER as logger
from viperleed.calc.files.phaseshifts import readPHASESHIFTS
from viperleed.calc.files.iorfactor import beamlist_to_array

In [None]:
# Fe2O3(012), converged state: the calculation inputs come from the test
# data, the tensor archive from the separate large-files directory.
origin_path = Path('../tests/test_data/Fe2O3_012/converged/')
large_file_path = Path('../large_files/')
tensor_path = large_file_path.joinpath(
    'Fe2O3_012', 'converged', 'Tensors', 'Tensors_001.zip'
)
displacements_path = origin_path / "DISPLACEMENTS_z"

# Build the calculator together with the slab, parameters, reference data,
# phaseshifts, base scatterers, DISPLACEMENTS file and parameter space
# reconstructed from the stored state.
state_objects = calculator_from_state(
    origin_path,
    tensor_path,
    l_max=10,
    batch=100,
    recalculate_ref_t_matrices=False,
    displacements_file=displacements_path,
)
(
    calculator, slab, rpars, ref_data, phaseshifts,
    base_scatterers, disp_file_from_state, ps_from_state,
) = state_objects

In [None]:
# R-factor with all 5 parameters set to 0.5 (presumably the centre of a
# normalized [0, 1] range — confirm).
calculator.R([0.5, 0.5, 0.5, 0.5, 0.5])
# Previously recorded output: Array(0.17105404, dtype=float64)

In [None]:
calculator.parameter_space.dynamic_t_matrix_site_elements

In [None]:
print(ps_from_state.geo_subtree)

In [None]:
ps_from_state.geo_subtree.n_dynamic_propagators

In [None]:
# Per-leaf dynamic flags as reported by the library; the next cell
# re-derives these by hand.
ps_from_state.geo_subtree.leaf_is_dynamic

In [None]:
# Re-derive, per geometry leaf, whether it is dynamic:
# - zero the biases of the root-to-leaf transformer;
# - map it to booleans and feed it an all-True input;
# - mark the leaf dynamic if any output entry is True.
# Compare with `ps_from_state.geo_subtree.leaf_is_dynamic`.
is_dynamic = []
for leaf in ps_from_state.geo_subtree.leaves:
    leaf_transformer = (
        ps_from_state.geo_subtree.subtree_root.transformer_to_descendent(leaf)
    )
    # NOTE(review): this assigns to the object returned by
    # transformer_to_descendent; confirm it is a fresh copy and not state
    # shared with the tree.
    leaf_transformer.biases = np.zeros_like(leaf_transformer.biases)
    bool_transformer = leaf_transformer.boolify()
    # Named `all_true` (not `input`) to avoid shadowing the builtin.
    all_true = np.full(bool_transformer.in_dim, dtype=bool, fill_value=True)
    # Evaluate once and reuse for both the printout and the flag.
    leaf_output = np.array(bool_transformer(all_true))
    leaf_dynamic = np.any(leaf_output)
    print(leaf_output, leaf_dynamic)
    is_dynamic.append(leaf_dynamic)


In [None]:
# Biases of each dynamic t-matrix transformer.
# NOTE(review): the loop variable `t` leaks into the notebook namespace and
# is rebound by a later cell.
for t in calculator.parameter_space.dynamic_t_matrix_transformers:
    print(t.biases)

In [None]:
calculator.parameter_space.all_vib_amps_transformer.biases

In [None]:
ps_from_state.vib_subtree.leaf_order

In [None]:
# `base_scatterers` here is the object returned by calculator_from_state.
# NOTE(review): the library-import cell below rebinds this name to the
# `viperleed_jax.base_scatterers` module, so re-running this cell after it fails.
base_scatterers.scatterers

In [None]:
import importlib
from viperleed_jax import parameter_space
from viperleed_jax.parameters import (
    vib_parameters,
    hierarchical_linear_tree,
    geo_parameters,
    occ_parameters,
)
from viperleed_jax.parameters import linear_transformer
from viperleed_jax import base_scatterers
from viperleed_jax.files.displacements import lines

from viperleed_jax.files.displacements import file

from viperleed_jax.files.displacements import regex
from viperleed_jax.files.displacements import lines
from viperleed_jax.files.displacements import file
from viperleed_jax.files.displacements import reader
from viperleed_jax.parameters import linear_transformer
from viperleed_jax.parameters import vib_parameters
from viperleed_jax.parameters import geo_parameters
from viperleed_jax.parameters import meta_parameters
from viperleed_jax.parameters import occ_parameters
from viperleed_jax.parameters import hierarchical_linear_tree
from viperleed_jax import parameter_space

In [None]:
importlib.reload(base_scatterers)
importlib.reload(regex)
importlib.reload(lines)
importlib.reload(reader)
importlib.reload(file)
importlib.reload(linear_transformer)
importlib.reload(hierarchical_linear_tree)
importlib.reload(vib_parameters)
importlib.reload(geo_parameters)
importlib.reload(meta_parameters)
importlib.reload(occ_parameters)
importlib.reload(parameter_space)

In [None]:
# DISPLACEMENTS file for Fe2O3(012) (wildcard variant, per its file name).
# The path is relative to the notebook's working directory, like the other
# test-data paths in this notebook, instead of a machine-specific absolute
# path.
path = Path(
    "../tests/test_data/displacements/Fe2O3_012/DISPLACEMENTS_wildcard"
)
f = file.DisplacementsFile()
f.read(path)

# The first block is used as the search block; the offsets block is read
# separately.
search_block = f.blocks[0]
offsets_block = f.offsets_block()

In [None]:
# Build a fresh parameter space from the slab and rpars.
# NOTE(review): both names must refer to the library modules here; this cell
# breaks after `base_scatterers` or `parameter_space` is rebound to an object
# by another cell.
bs = base_scatterers.BaseScatterers(slab)
ps = parameter_space.ParameterSpace(bs, rpars)

In [None]:
# Only the search block is applied; the offsets block read above is not
# passed in (offset_block=None).
ps.apply_displacements(offset_block=None, search_block=search_block)

In [None]:
calculator.set_parameter_space(ps)

In [None]:
calculator.parameter_space.dynamic_t_matrix_site_elements

In [None]:
# NOTE(review): 5/12 is an unexplained magic value — document why.
calculator.R([5/12, 0.5, 0.5, 0.5, 0.5])

In [None]:
# NOTE(review): `trafo` is rebound later by a cell that unpacks
# collapse_transformer() into two values.
trafo = ps.geo_subtree.subtree_root.collapse_transformer()

In [None]:
ps.vib_subtree.n_dynamic_t_matrices

In [None]:
print(ps.info)

In [None]:
import numpy as np
from scipy.linalg import lstsq
from collections import defaultdict


def group_dependent_rows(matrix, tol=1e-10):
    """
    Group the rows of an (n, m) matrix that are pairwise scalar multiples.

    Two rows i < j are linked when the least-squares fit
    ``row_j ~= c * row_i`` leaves a sum of squared residuals below `tol`,
    or when ``row_i`` is all zeros. The groups are the connected components
    of the resulting graph, so linking is transitive. For example, a zero
    row joins every row it is linked with into one group. This is a pairwise
    collinearity test, not a full linear-dependence test.

    Parameters:
    matrix (array_like): The input matrix of shape (n, m).
    tol (float): Absolute tolerance on the sum of squared residuals of the
        pairwise fit.

    Returns:
    list of lists: Row indices of each group. Each group is sorted
        ascending, groups are ordered by their smallest index, and every row
        appears in exactly one group.
    """
    matrix = np.asarray(matrix)
    n_rows = matrix.shape[0]
    dependency_graph = defaultdict(set)

    # Build the dependency graph by testing each pair of rows.
    for i in range(n_rows):
        row_i = matrix[i : i + 1].T  # column vector, shape (m, 1)
        for j in range(i + 1, n_rows):
            row_j = matrix[j : j + 1].T
            coeff, _, rank, _ = lstsq(row_i, row_j)
            # lstsq only reports residues when rank == 1 and m > 1, so
            # compute the residual explicitly. Comparing an empty residues
            # array with `tol` has an ambiguous truth value (an error in
            # recent NumPy).
            residual = float(np.sum(np.abs(row_i @ coeff - row_j) ** 2))
            if rank < 1 or residual < tol:
                dependency_graph[i].add(j)
                dependency_graph[j].add(i)

    # Connected components via an explicit stack (no recursion-depth limit).
    visited = set()
    dependent_groups = []
    for start in range(n_rows):
        if start in visited:
            continue
        visited.add(start)
        stack = [start]
        group = []
        while stack:
            node = stack.pop()
            group.append(node)
            for neighbor in dependency_graph[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)
        dependent_groups.append(sorted(group))

    return dependent_groups


# Example usage: rows 0-2 are multiples of [1, 2, 3]; row 3 is not a
# multiple of any other row.
matrix = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9], [4, 5, 6]])

dependent_groups = group_dependent_rows(matrix)
print("Groups of dependent rows:", dependent_groups)
# Order-insensitive check of the expected grouping.
assert sorted(sorted(group) for group in dependent_groups) == [[0, 1, 2], [3]]

In [None]:
# Inspect child 2 of the vibration subtree root.
# NOTE(review): the index 2 is hard-coded — name which node this is.
trans = ps.vib_subtree.subtree_root.children[2].collapse_transformer()

In [None]:
trans.biases

In [None]:
trans.boolify().biases

In [None]:
ps.vib_subtree.subtree_root.children[2].collapse_bounds()

In [None]:
print(ps.geo_subtree)

In [None]:
ps.geo_subtree.dynamic_displacements_transformers()

In [None]:
# NOTE(review): subtrees are called via `collapsed_transformer()` while nodes
# use `collapse_transformer()` — confirm both methods exist as intended.
ps.meta_param_subtree.collapsed_transformer()([0.75])

In [None]:
ps.vib_subtree

In [None]:
ps.vib_subtree.collapsed_transformer()([0.5])

In [None]:
explcnode = ps.vib_subtree.subtree_root.children[2].children[0]

In [None]:
explcnode.name, explcnode.transformer.weights, explcnode.transformer.biases

In [None]:
explcnode.children[0].collapse_bounds()

In [None]:
# NOTE(review): identical to the previous cell — delete one of them.
explcnode.children[0].collapse_bounds()

In [None]:
for c in explcnode.children:
    print(c.transformer.weights, c.transformer.biases)

In [None]:
def make_t(node):
    """Build a LinearTransformer from a node's harmonized bounds.

    The weights are ``harmonized_upper_bound - harmonized_lower_bound`` and
    the biases are ``harmonized_lower_bound``. The third argument ``(1,)`` is
    passed through to ``LinearTransformer`` unchanged (presumably a shape —
    confirm).
    """
    upper = node.harmonized_upper_bound
    lower = node.harmonized_lower_bound
    span = upper - lower
    return linear_transformer.LinearTransformer(span, lower, (1,))
    

In [None]:
explcnode.children

In [None]:
# collapse_bounds() returns three values: a mask plus lower/upper arrays
# (the mask is used below to select rows and bounds).
collapsed_transformer = explcnode.collapse_transformer()
user_mask, lower, upper = explcnode.collapse_bounds()
user_mask, lower, upper

In [None]:
collapsed_transformer.weights

In [None]:
# Restrict the transformer and bounds to the rows selected by user_mask.
masked_transformer = collapsed_transformer.select_rows(user_mask)
masked_lower, masked_upper = lower[user_mask], upper[user_mask]

In [None]:
masked_lower

In [None]:
masked_upper - masked_lower

In [None]:
lower

In [None]:
upper-lower

In [None]:
ps.vib_subtree.collapsed_transformer().biases

In [None]:
ps.vib_subtree

In [None]:
# NOTE(review): reads the private attribute `_bounds`; other cells use the
# public `bounds`.
for leaf in ps.vib_subtree.leaves:
    print(leaf._bounds)

In [None]:
# NOTE(review): presumably locks the parameter space against further
# changes — confirm which of the following cells depend on this.
ps.freeze()

In [None]:
ps.geo_subtree.leaf_plane_symmetry_operations

In [None]:
ps.geo_subtree.transformer_for_dynamic_propagator_inputs

In [None]:
print(ps.vib_subtree)

In [None]:
from viperleed_jax.parameters.hierarchical_linear_tree import HLSubtree
from viperleed_jax.parameters.hierarchical_linear_tree import HLLeafNode
from viperleed_jax.parameters.hierarchical_linear_tree import HLBound

In [None]:
class MetaParameterSubtree(HLSubtree):
    """Hierarchical subtree for the meta parameters (currently only V0r)."""

    def __init__(self):
        # Nothing beyond the base-class setup.
        # NOTE(review): read_from_rpars needs `v0r_node`, which only
        # build_subtree creates; presumably HLSubtree.__init__ calls
        # build_subtree — confirm.
        super().__init__()

    def build_subtree(self):
        """Create the single V0r leaf node and register it with this subtree."""
        v0r_leaf = V0rHLLeafNode()
        self.v0r_node = v0r_leaf
        self.nodes.append(v0r_leaf)

    def read_from_rpars(self, rpars):
        """Take the V0r bounds from ``rpars``, then create the subtree root."""
        self.v0r_node.update_bounds(rpars)
        self.create_subtree_root()

    @property
    def name(self):
        """Display name of this subtree."""
        return "Meta Parameters (V0r)"

    @property
    def subtree_root_name(self):
        """Display name of this subtree's root node."""
        return "V0r (root)"


class V0rHLLeafNode(HLLeafNode):
    """Leaf node for the scalar V0r parameter, bounded by ``rpars.IV_SHIFT_RANGE``."""

    def __init__(self):
        # The bound is attached before the base-class initializer runs, as in
        # the original order (in case the base class reads it — confirm).
        # NOTE(review): other leaves in this notebook are accessed as
        # `.bounds` / `._bounds`; this class stores `.bound` — confirm the
        # base class does not expect one of those names.
        self.bound = HLBound(1)
        super().__init__(dof=1, name="V0r")  # V0r is a single scalar

    def update_bounds(self, rpars):
        """Set the enforced V0r range to (start, stop) of ``rpars.IV_SHIFT_RANGE``."""
        shift_range = rpars.IV_SHIFT_RANGE
        v0r_range = (shift_range.start, shift_range.stop)
        self.bound.update_range(_range=v0r_range, offset=None, enforce=True)

In [None]:
# Build the meta-parameter subtree and read the V0r bounds from rpars.
meta = MetaParameterSubtree()
meta.read_from_rpars(rpars)

In [None]:
print(meta)

In [None]:
# NOTE(review): rebinds `t`, which was a loop variable in an earlier cell.
t = ps.vib_subtree.roots[0].collapse_transformer()

In [None]:
rpars.IV_SHIFT_RANGE

In [None]:
t((1,))

In [None]:
# Writes PDF drawings of the subtrees into the current working directory.
ps.vib_subtree.graphical_export("fe2O3_vib.pdf")

In [None]:
ps.geo_subtree.graphical_export("fe2O3_geo.pdf")

In [None]:
# NOTE(review): this scratch section does NOT survive Restart & Run All:
# - `fe1`/`fe2` are only assigned in the cell that calls
#   `vib_subtree.apply_bounds` further down;
# - `vib_subtree` is only constructed in a later cell;
# - `fe_def` is used one cell before it is assigned;
# - `thing` is never defined in this notebook.
# Reorder or delete before sharing.
fe1.bounds

In [None]:
fe1.transformer.weights, fe2.transformer.weights

In [None]:
# NOTE(review): here collapse_transformer() is used as a single object; a
# later cell unpacks it into two values — only one of these can be right.
sym = fe1.parent.collapse_transformer().weights

In [None]:
np.linalg.matrix_rank(sym)

In [None]:
# Keep only the first row; the mask hard-codes 6 rows.
z1_only = sym[[True, False, False, False, False, False],:]

In [None]:
# NOTE(review): immediately overwritten by the next cell.
vib_range_line = lines.VibDeltaLine(
    "Fe = -0.5, 0.5", "Fe", None, -0.5, 0.5, None
)

In [None]:
# NOTE(review): the raw line says "Fe 4" but the index passed is 6 —
# confirm which is intended.
vib_range_line = lines.VibDeltaLine(
    "Fe 4 = -0.5, 0.5", "Fe", 6, -0.5, 0.5, None
)

In [None]:
vib_subtree.apply_bounds(vib_range_line)
fe1 = vib_subtree.leaves[0]
fe2 = vib_subtree.leaves[1]

In [None]:
for root in vib_subtree.roots:
    print(root.name, root.check_bounds_valid())

In [None]:
fe_def.check_bounds_valid()

In [None]:
fe_def = vib_subtree.leaves[2].parent

In [None]:
fe_def_col_trafo= fe_def.collapse_transformer()

In [None]:
fe_def.collapse_bounds()

In [None]:
thing(fe_def)

In [None]:
# NOTE(review): prefer `import scipy.linalg`, and compute thing(fe_def) once.
import scipy
scipy.linalg.solve(thing(fe_def).weights, thing(fe_def).biases)

In [None]:
fe1.bounds, fe2.bounds

In [None]:
# NOTE(review): rebinds `trafo` (earlier a single transformer) and expects a
# 2-tuple from collapse_transformer().
trafo, user_set = fe1.parent.collapse_transformer()

In [None]:
trafo.weights[user_set, :]

In [None]:
fe2.bounds.fixed

In [None]:
fe1.parent.collapse_transformer().weights, fe2.parent.collapse_transformer().weights

In [None]:
vib_subtree.links

In [None]:
np.linalg.matrix_rank(z1_only)

In [None]:
np.linalg.qr(sym).Q.shape, np.linalg.qr(sym).R.shape

In [None]:
# Build standalone vibration / geometry / occupation subtrees from `bs`.
# NOTE(review): several cells above use `vib_subtree`; this cell must run
# before them.
vib_subtree = vib_parameters.VibHLSubtree(bs)
vib_subtree.create_subtree_root()
print(vib_subtree)

In [None]:
geo_subtree = geo_parameters.GeoHLSubtree(bs)
geo_subtree.create_subtree_root()
print(geo_subtree)

In [None]:
occ_subtree = occ_parameters.OccHLSubtree(bs)
occ_subtree.create_subtree_root()
print(occ_subtree)

In [None]:
# Upper-left 2x2 block of the unit cell (presumably the in-plane lattice
# vectors — confirm).
slab.ucell[:2, :2]

In [None]:
# NOTE(review): this cell targets an older ParameterSpace API. Elsewhere in
# this notebook ParameterSpace is built as ParameterSpace(bs, rpars), and
# GeoParamBound, VibParamBound and V0rParamBound are not imported anywhere,
# so the cell fails on a fresh kernel. Port it or delete it.
# The object is stored under its own name so that it no longer shadows the
# imported `parameter_space` module, which other cells use as
# `parameter_space.ParameterSpace(...)` and `importlib.reload(parameter_space)`.
legacy_parameter_space = ParameterSpace(slab)

## GEOMETRY
# Fix layers 2 and 1 with zero z offset.
# NOTE(review): the original comment said "layers 0 and 1" — confirm.
legacy_parameter_space.geo_params.fix_layer(2, z_offset=0.)
legacy_parameter_space.geo_params.fix_layer(1, z_offset=0.)

# Bound every geometry parameter that has no bound yet to ± 0.05 A.
# NOTE(review): the original comment said "± 0.15 A for layer 2" — confirm.
# The list is built before the loop so set_bound() cannot affect the filter.
unbounded_geo_params = [
    p for p in legacy_parameter_space.geo_params.terminal_params
    if p.bound is None
]
for param in unbounded_geo_params:
    param.set_bound(GeoParamBound(-0.05, +0.05))

## VIBRATIONS
# Fix *_def sites (O_def, Fe_def); None fixes them to the default value.
def_vib_params = [
    p for p in legacy_parameter_space.vib_params.terminal_params
    if p.site_element.site.endswith('_def')
]
for param in def_vib_params:
    legacy_parameter_space.vib_params.fix_site_element(param.site_element, None)

# *_surf sites may vary ± 0.05 A. Sites matching neither suffix are left
# untouched.
surf_vib_params = [
    p for p in legacy_parameter_space.vib_params.terminal_params
    if p.site_element.site.endswith('_surf')
]
for param in surf_vib_params:
    param.set_bound(VibParamBound(-0.05, + 0.05))

## CHEMISTRY
# No free parameters.
legacy_parameter_space.occ_params.remove_remaining_vacancies()

## V0R
# Allow ± 2 eV.
legacy_parameter_space.v0r_param.set_bound(V0rParamBound(-2., +2.))

calculator.set_parameter_space(legacy_parameter_space)

In [None]:
# 18 normalized parameters (1 + 2 + 3 + 3 + 9 entries): the first is 0.375,
# all others 0.5.
# NOTE(review): assumes the parameter space currently set on `calculator`
# has 18 free parameters; earlier cells call calculator.R with 5 values.
test_p = np.array([0.375]+[0.5]*2 +  [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5] + [0.5]*9)
calculator.parameter_space.expand_params(test_p)

In [None]:
amps = calculator.jit_delta_amplitude(test_p)

In [None]:
amps

In [None]:
# ref R-factor
# `params` is element-wise identical to `test_p` above.
params = np.array([0.375]+[0.5]*17)
calculator.R(params)