In [1]:
from copy import deepcopy
import logging
import os

import xml.etree.ElementTree as ET
import numpy as np

In [2]:
# Configure the root logger to emit DEBUG-and-above records to both the
# console and a fresh ./log.log file.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Notebook cells get re-executed. Without this cleanup every re-run would add
# another console/file handler pair to the root logger (duplicating each log
# line) and leave the previous file handle open, which also makes deleting
# log.log fail on Windows. Only handlers created by this cell (identified by
# name) are removed; handlers installed by anything else are left alone.
for stale_handler in list(logger.handlers):
    if stale_handler.get_name() in ('xml-mixer-console', 'xml-mixer-file'):
        logger.removeHandler(stale_handler)
        stale_handler.close()

# Start every run with an empty log file
if os.path.exists('./log.log'):
    os.remove('./log.log')

# Create handler (console and file)
log_handler_console = logging.StreamHandler()
log_handler_file = logging.FileHandler("log.log")
log_handler_console.set_name('xml-mixer-console')
log_handler_file.set_name('xml-mixer-file')
log_handler_console.setLevel(logging.DEBUG)
log_handler_file.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
log_handler_console.setFormatter(formatter)
log_handler_file.setFormatter(formatter)
# Add handlers to the logger
logger.addHandler(log_handler_console)
logger.addHandler(log_handler_file)

In [None]:
# Discover the input sources: every sub-directory of ./inputs is one source.
# os.listdir returns entries in arbitrary order and also lists plain files, so
# keep directories only and sort them; this makes the processing order (and
# the choice of the first input, which later serves as the output template)
# reproducible between runs and machines.
INPUTS = sorted(
    entry for entry in os.listdir('./inputs')
    if os.path.isdir(os.path.join('./inputs', entry))
)
# 'vanilla', when present, always goes first
if 'vanilla' in INPUTS:
    INPUTS.remove('vanilla')
    INPUTS.insert(0, 'vanilla')
logger.debug(f"Found {len(INPUTS)} input sources: {INPUTS}.")

# Folder of each input source, the output folder, and the per-source weights
# used when mixing the numeric values.
PATH_INPUTS = {ip: f'./inputs/{ip}' for ip in INPUTS}
PATH_OUTPUTS = './outputs'
INPUTS_MIX_COEFFS = {'darker-nights': 0.8, 'visualv': 0.2}
logger.debug(f"Mixing coefficients: {INPUTS_MIX_COEFFS}.")

def validate_structure(input:ET, other:ET) -> bool:
    """Recursively compares the structure of two XML elements.
    
    Args:
    + input: First XML tree root.
    + other: Second XML tree root.
    
    Return: True if the structures are the same, False otherwise.
    """
    # Compare tags
    if input.tag != other.tag:
        return False

    # Compare the number of children
    if len(input) != len(other):
        return False

    # Recursively compare child elements
    for child_1, child_2 in zip(input, other):
        if not validate_structure(child_1, child_2):
            return False

    return True

def str_to_arr1d(x:str) -> np.ndarray:
    return np.array([float(i) for i in x.strip().split(sep=' ')])

def arr1d_to_str(x:list|np.ndarray) -> str:
    if isinstance(x, np.ndarray):
        x = x.tolist()
    return ' ' + ' '.join([f'{i:.4f}' for i in x])


# Check that every input folder holds the same set of files and that all of
# them are .xml files.
# os.listdir returns names in arbitrary order, so the listings are sorted
# before being compared; otherwise two folders with identical contents could
# be reported as inconsistent.
FILE_LIST = sorted(os.listdir(PATH_INPUTS[INPUTS[0]]))
for ip in INPUTS[1:]:
    # Explicit raise instead of `assert`, which is skipped under `python -O`
    if sorted(os.listdir(PATH_INPUTS[ip])) != FILE_LIST:
        logger.error(f"Inconsistent file contents between {PATH_INPUTS[ip]} and {FILE_LIST}.")
        raise AssertionError(
            f"Files in {PATH_INPUTS[ip]} differ from those in {PATH_INPUTS[INPUTS[0]]}."
        )
logger.debug(f"Consistent file contents. Found {len(FILE_LIST)} files: {FILE_LIST}")

for f in FILE_LIST:
    if not f.endswith('.xml'):
        logger.error(f"{f} is not an XML file.")
        raise AssertionError(f"{f} is not an XML file.")
logger.debug("All files are .xml.")

In [None]:
# Process each .xml file: parse it from every input source, mix the numeric
# values of each tag with the weights in INPUTS_MIX_COEFFS, and write the
# result to PATH_OUTPUTS using the first input's document as the template.
# Iterates the file list that was validated in the previous cell.
for f in FILE_LIST:
    logger.debug(f"Processing {f} ".ljust(80, '='))
    # Parse this file from every input source; the first input's tree is
    # copied so it can be filled with the mixed values without touching the
    # parsed inputs.
    xml_tree = {
        ip: ET.parse(source=os.path.join(PATH_INPUTS[ip], f))
        for ip in INPUTS
    }
    xml_template = deepcopy(xml_tree[INPUTS[0]])

    # Read data from multiple input folders:
    # data_in[input][region name][tag] -> 1-D array of the tag's values
    data_in = {}
    for ip in INPUTS:
        data_in[ip] = {}
        for r_xml in xml_tree[ip].find('cycle').findall('region'):
            r = r_xml.attrib['name']
            data_in[ip][r] = {t_xml.tag: str_to_arr1d(t_xml.text) for t_xml in r_xml}
            logger.debug(f"Successfully read data from {PATH_INPUTS[ip]}/{f}/{ip}/{r}.")
    # Every input must define the same regions (explicit raise instead of
    # `assert`, which is skipped under `python -O`)
    for ip in INPUTS[1:]:
        if data_in[INPUTS[0]].keys() != data_in[ip].keys():
            logger.error(f"Inconsistent regions. {INPUTS[0]}/{f} has {data_in[INPUTS[0]].keys()} but {ip}/{f} has {data_in[ip].keys()}.")
            raise AssertionError(f"Inconsistent regions between {INPUTS[0]}/{f} and {ip}/{f}.")
    logger.debug(f"Consistent regions. All {f} has regions {list(data_in[INPUTS[0]].keys())}.")

    # Inputs without an entry in INPUTS_MIX_COEFFS do not take part in the
    # weighted mix. Say so once per file, because a misspelled folder name or
    # coefficient key would otherwise be excluded silently.
    unweighted_inputs = [ip for ip in INPUTS if ip not in INPUTS_MIX_COEFFS]
    if unweighted_inputs:
        logger.warning(f"Inputs without a mixing coefficient are only used as fallback: {unweighted_inputs}.")

    # Compute output data: data_out[region name][tag] -> mixed 1-D array
    data_out = {}
    for r in data_in[INPUTS[0]].keys():
        # Find all unique tags
        all_tags = []
        for ip in INPUTS:
            all_tags.extend(list(data_in[ip][r].keys()))
            logger.debug(f"Found {len(data_in[ip][r].keys())} tags in {f}/{ip}/{r}.")
        all_tags = list(set(all_tags))
        logger.debug(f"Filter to {len(all_tags)} unique tags across all configs {f}/<input>/{r} for inputs {INPUTS}.")

        data_out[r] = {}
        for t in all_tags:
            # Inputs that define this tag (never empty: t comes from their union)
            ip_includes_t = [ip for ip in INPUTS if t in data_in[ip][r]]
            # Of those, the ones that have a mixing coefficient
            ip_weighted = [ip for ip in ip_includes_t if ip in INPUTS_MIX_COEFFS]
            total_weight = sum(INPUTS_MIX_COEFFS[ip] for ip in ip_weighted)
            if ip_weighted and total_weight != 0:
                # Weighted mean, renormalized over the inputs that have the tag
                mixed = sum(INPUTS_MIX_COEFFS[ip]*data_in[ip][r][t] for ip in ip_weighted)/total_weight
            else:
                # No weighted input provides this tag (or the weights cancel
                # out): fall back to the plain average of the inputs that do,
                # instead of raising KeyError or dividing by zero.
                mixed = sum(data_in[ip][r][t] for ip in ip_includes_t)/len(ip_includes_t)
            data_out[r][t] = np.round(mixed, decimals=4)
    logger.debug(f"Successfully computed output data for {f}.")

    # Compose output .xml file: overwrite the text of every tag in the
    # template with the mixed values
    for r_xml in xml_template.getroot().find(path='cycle').findall(path='region'):
        r = r_xml.attrib['name']
        for t_xml in r_xml:
            t_xml.text = arr1d_to_str(data_out[r][t_xml.tag])
    if not os.path.isdir(PATH_OUTPUTS):
        os.makedirs(PATH_OUTPUTS)
        logger.debug(f"Created folder {PATH_OUTPUTS}.")
    xml_template.write(
        file_or_filename=os.path.join(PATH_OUTPUTS, f),
        encoding="utf-8",
        xml_declaration=True,
    )
    logger.debug(f"Successfully exported to {PATH_OUTPUTS}/{f}.")