---

<!-- <a href="https://github.com/rraadd88/roux/blob/master/examples/roux_viz_text.ipynb"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
 -->
 
## ðŸ“ˆ Param.s

In [1]:
# | default_exp pms

In [None]:
# | export

import logging
from pathlib import Path

from roux.lib.io import read_dict, to_dict

import ast
import re

def _code_string_to_dict(code_string):
    """
    Converts a string containing Python assignment statements into a dictionary.
    
    Args:
        code_string (str): String containing Python assignment statements.
    
    Returns:
        dict: Dictionary mapping variable names to their values.
    """
    result = {}
    
    # Remove comments (lines starting with #) but preserve inline comments in values
    lines = code_string.strip().split('\n')
    cleaned_lines = []
    
    for line in lines:
        # Remove full-line comments
        if line.strip().startswith('#'):
            continue
        
        # Remove inline comments (but not inside strings)
        # Find first # that's not inside quotes
        in_single_quote = in_double_quote = False
        for i, char in enumerate(line):
            if char == "'" and not in_double_quote and (i == 0 or line[i-1] != '\\'):
                in_single_quote = not in_single_quote
            elif char == '"' and not in_single_quote and (i == 0 or line[i-1] != '\\'):
                in_double_quote = not in_double_quote
            elif char == '#' and not in_single_quote and not in_double_quote:
                line = line[:i]
                break
        
        cleaned_lines.append(line)
    
    # Join cleaned lines and split into statements
    code = '\n'.join(cleaned_lines)
    
    # Parse the code with ast
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # If parsing fails, try to extract assignments manually
        return _fallback_extract_assignments(code_string)
    
    # Execute in a safe namespace
    namespace = {}
    
    # Define safe globals (only safe builtins)
    safe_globals = {
        'None': None,
        'True': True,
        'False': False,
        'dict': dict,
        'list': list,
        'tuple': tuple,
    }
    
    # Execute each assignment statement
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            # Get the target variable name(s)
            for target in node.targets:
                if isinstance(target, ast.Name):
                    # Evaluate the value
                    try:
                        # Compile and evaluate the expression
                        expr = ast.Expression(node.value)
                        code_obj = compile(expr, '<string>', 'eval')
                        value = eval(code_obj, {**safe_globals, **namespace}, {})
                        
                        # Store in both result dict and namespace for subsequent references
                        result[target.id] = value
                        namespace[target.id] = value
                    except:
                        # If evaluation fails, try to get the literal string
                        try:
                            value = ast.literal_eval(node.value)
                            result[target.id] = value
                            namespace[target.id] = value
                        except:
                            # Last resort: get the source as string
                            result[target.id] = ast.unparse(node.value) if hasattr(ast, 'unparse') else str(node.value)
    
    return result


def _fallback_extract_assignments(code_string):
    """
    Fallback method using regex for simple cases when ast parsing fails.
    """
    result = {}
    
    # Pattern to match variable assignments
    pattern = r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$'
    
    lines = code_string.strip().split('\n')
    current_var = None
    current_value_lines = []
    
    for line in lines:
        line = line.strip()
        
        # Skip comment lines
        if line.startswith('#'):
            continue
        
        # Remove inline comments
        if '#' in line:
            # Check if # is inside quotes
            in_quotes = False
            quote_char = None
            for i, char in enumerate(line):
                if char in ('"', "'"):
                    if not in_quotes:
                        in_quotes = True
                        quote_char = char
                    elif quote_char == char and (i == 0 or line[i-1] != '\\'):
                        in_quotes = False
                elif char == '#' and not in_quotes:
                    line = line[:i]
                    break
        
        # Check for variable assignment
        match = re.match(pattern, line)
        if match:
            # Save previous variable if exists
            if current_var is not None:
                value_str = '\n'.join(current_value_lines)
                try:
                    result[current_var] = ast.literal_eval(value_str)
                except:
                    result[current_var] = value_str
            
            # Start new variable
            current_var = match.group(1)
            current_value_lines = [match.group(2)]
        elif current_var is not None:
            # Continuation of multi-line value
            current_value_lines.append(line)
    
    # Save the last variable
    if current_var is not None:
        value_str = '\n'.join(current_value_lines)
        try:
            result[current_var] = ast.literal_eval(value_str)
        except:
            result[current_var] = value_str
    
    return result


def extract_pms(lines, fmt='dict'):
    """
    Extract keyword arguments from lines of code.
    
    Args:
        lines (str or list): String or list of strings containing assignment statements
        fmt (str): Output format - 'dict' for dictionary or 'list' for list of "key=value" strings
    
    Returns:
        dict or list: Depending on fmt parameter
    """
    # Convert lines to string if it's a list
    if isinstance(lines, list):
        code_string = '\n'.join(lines)
    else:
        code_string = lines
    
    # Use code_string_to_dict to parse the code
    params_dict = _code_string_to_dict(code_string)
    
    if fmt == 'dict':
        return params_dict
    else:
        # Convert to list of "key=value" strings
        parameters = []
        for key, value in params_dict.items():
            # Determine if value needs quotes
            if isinstance(value, str):
                # Check if it's already a string literal representation
                try:
                    # Try to see if it can be evaluated as a literal
                    ast.literal_eval(value)
                    # If it succeeds, it's a Python literal, keep as is
                    parameters.append(f"{key}={value}")
                except:
                    # If it fails, it's a regular string, add quotes
                    parameters.append(f"{key}='{value}'")
            else:
                # For non-strings, use repr() to get proper Python representation
                parameters.append(f"{key}={repr(value)}")
        return parameters

def expand_pms(
    pms
    ):
    """
    dict -> cli str
    """
    return ' '.join([f"--{k.replace('_','-')} {v}" for k,v in pms.items()])


## I/O
import json
import nbformat

def read_pms(notebook_path, tag='parameters'):
    """
    Extract code from cells tagged with 'parameters' in a Jupyter notebook.
    
    Args:
        notebook_path (str): Path to the .ipynb file
        tag (str): The tag to look for (default: 'parameters')
    
    Returns:
        list: List of code strings from tagged cells
    """
    # Read the notebook
    with open(notebook_path, 'r', encoding='utf-8') as f:
        notebook = nbformat.read(f, as_version=4)
        
    # Iterate through cells
    for cell in notebook.cells:
        if cell.cell_type == 'code':
            # Check if the cell has the specified tag
            cell_tags = cell.metadata.get('tags', [])
            if tag in cell_tags:
                # parameters_code.append()
                pms_str=cell.source
                break
                
    return extract_pms(pms_str)

## validators
def validate_pms(
    d: dict,
) -> bool:
    return ("input_path" in d) and ("output_path" in d)

def pre_params(
    params=None,
    inputs=None,
    output_path_base=None,
    flt_input_exists=False,
    flt_output_exists=False,
    drop_if_path_exists:str=None,
    drop_by_patterns:list=None,

    verbose=False,
    force=False,
    test1: bool = False,
    testn: int = None,
    outp=None, ## save
):
    """
    Unified pre-processing for params, used by both run_tasks_nb and run_tasks.
    Handles conversion, checks, output path inference, and filtering (including test1/testn).
    Returns a list of parameter dicts ready for execution.

    TODs:
        use to_cfgs to process inputs
    """
    # --- Handle input formats and output path inference ---
    param_list = params
    # print(len(param_list))
    if param_list is None and inputs is not None and output_path_base is not None:
        from roux.lib.sys import to_output_paths
        param_list = to_output_paths(
            inputs=inputs,
            output_path_base=output_path_base,
            encode_short=True,
            key_output_path="output_path",
            verbose=verbose,
            force=force,
        )
        # Optionally save all parameters (as in run_tasks_nb)
        for k, parameters in param_list.items():
            output_dir_path = output_path_base.split("{KEY}")[0]
            to_dict(
                parameters,
                f"{output_dir_path}/{k.split(output_dir_path)[1].split('/')[0]}/.parameters.yaml",
            )

    # print(len(param_list))
    if isinstance(param_list, str):
        param_list = read_dict(param_list)

    # print(len(param_list))
    if not param_list or (isinstance(param_list, (list, dict)) and len(param_list) == 0):
        logging.info("nothing to process. use `force`=True to rerun.")
        return []

    # print(len(param_list))
    # --- Convert dict to list if needed ---
    if isinstance(param_list, dict):
        if 'input_path' in param_list and 'output_path' in param_list:
            ## pms
            param_list=[param_list]
        else:
            if not any(['input_path' in d for d in param_list.values()]):
                logging.warning("setting keys of params as input_path s ..")
                param_list = {k: {**d, **{'input_path': k}} for k, d in param_list.items()}
            if validate_pms(list(param_list.values())[0]):
                param_list = list(param_list.values())
            else:
                raise ValueError(param_list)

    # --- Filtering by output existence, as in flt_params ---
    before = len(param_list)

    if flt_output_exists==True:
        ## keep
        print(len(param_list),end='->')
        param_list = [
            d
            for d in param_list
            if Path(d["output_path"]).exists()
        ]
        # print(len(param_list))
    else:
        ## drop
        param_list = [
            d
            for d in param_list
            if (force or not Path(d["output_path"]).expanduser().exists())
        ]
    
    if drop_if_path_exists:
        ## drop if sub-path exists
        print(len(param_list),end=f' -drop_if_path_exists-> ')
        assert isinstance(drop_if_path_exists,str), drop_if_path_exists
        param_list = [
            d
            for d in param_list
            if (force or not Path(Path(d['output_path']).with_suffix('').as_posix()+"/"+drop_if_path_exists).exists())
        ]

    if drop_by_patterns:
        ## drop if sub-path exists
        print(len(param_list),end=f' -drop_by_patterns-> ')
        print(drop_by_patterns)
        assert isinstance(drop_by_patterns,list), drop_by_patterns
        param_list = [
            d
            for d in param_list
            if (force or not any([s in d['input_path'] for s in drop_by_patterns]))
        ]

    # print(len(param_list))
    if flt_input_exists:
        print(len(param_list),end=' -> ')
        param_list = [
            d
            for d in param_list
            if (force if force else Path(d["input_path"]).expanduser().exists())
        ]
    
        
    if not force:
        if before - len(param_list) != 0:
            logging.info(
                f"parameters_list_flt reduced because force=False: {before} -> {len(param_list)}"
            )

    # --- Filtering by test1 and testn ---
    if test1:
        testn = 1
    if testn is not None:
        param_list = param_list[:testn]
        logging.warning(f"filtered to {len(param_list)} jobs ..")

    if len(param_list) == 0:
        # logging.info("No tasks remaining after filtering.")
        return []

    # --- Final assertions ---
    assert len(set([d["output_path"] for d in param_list])) == len(param_list), \
        "Duplicate output_path found in params."
    assert all([Path(d["input_path"]) != Path(d["output_path"]) if isinstance(d["input_path"],str) else True for d in param_list]), \
        "Some input_path == output_path in params."
    
    print(len(param_list))

    if not outp:
        return param_list
    else:
        to_dict(param_list,outp)

In [3]:
# Example usage with a list of lines
# Test with the multi-line string from earlier
print("\n\nTesting with the full configuration string:")
config_str = """## params
input_path=None #'../examples/inputs/20.pt'
output_path=None #'../examples/outputs/20.yaml'

## pre.
tasks_loss_weights={
## base
'recon': 1,

## specialised    
'group1': 1,

# not implemented: signal gain
# 'noise_reduce': 1,    

# ## unused
# 'effect_size' : 0, ## ~
}

kws_enc=dict(
rescale_mode='sum', # ~ conserve the sum of values for each sample
)

force=False
table_ext='pqt'"""

params_dict = extract_pms(config_str, fmt='dict')
print(f"\nExtracted {len(params_dict)} parameters")
print(params_dict)
print("\nSample of extracted parameters:")
sample_keys = ['input_path', 'edge_attr_min', 'in_channels', 'tasks_loss_weights', 
               'kws_enc', 'epochs', 'lr', 'stop_early', 'log_mode', 'dbug']

for key in sample_keys:
    if key in params_dict:
        value = params_dict[key]
        if isinstance(value, dict) and len(str(value)) > 50:
            print(f"  {key}: {type(value).__name__} with {len(value)} items")
        else:
            print(f"  {key}: {value} (type: {type(value).__name__})")



Testing with the full configuration string:

Extracted 6 parameters
{'input_path': None, 'output_path': None, 'tasks_loss_weights': {'recon': 1, 'group1': 1}, 'kws_enc': {'rescale_mode': 'sum'}, 'force': False, 'table_ext': 'pqt'}

Sample of extracted parameters:
  input_path: None (type: NoneType)
  tasks_loss_weights: {'recon': 1, 'group1': 1} (type: dict)
  kws_enc: {'rescale_mode': 'sum'} (type: dict)


In [4]:
read_pms(
    'roux_viz_io.ipynb'
)

{'output_path': 'tests/output/plot/plot_modified.png',
 'kws_plot_modified': {'y': 'petal_length'}}

## Outputs

In [1]:
from roux.workflow.io import get_source_path,to_mod
to_mod(get_source_path())

All checks passed!


'../roux/workflow/pms.py'

### Documentation
[`roux.viz.text`](https://github.com/rraadd88/roux#module-rouxvizline)