In [None]:
import polars as pl
import blitzbeaver as bb
import random
from typing import Any
from Levenshtein import distance
import json

In [None]:
# Input/output locations, written relative to the notebook's working directory.
csv_path = "../../aptihramy/data/csv_cleaned"  # folder holding one "<year>.csv" file per year
beaver_folder_path = "../data/beaver_files"  # folder where tracking graphs are saved with bb.save_beaver
json_folder_path = "../data/json_files"  # folder where JSON summaries are written and merged
# Columns removed from a materialized chain dataframe before the Levenshtein-based verifier runs.
verifier_dropped_cols = ["frame_idx", "enfants_chez_parents_prenom"]

In [None]:
def _read_year_range(first_year: int, last_year: int) -> list[pl.DataFrame]:
    """
    Read one CSV file per year from `csv_path`.

    Args:
        first_year (int): First year to load (included).
        last_year (int): Last year to load (included).

    Returns:
        list[pl.DataFrame]: One dataframe per year, in chronological order.
    """
    frames = []
    for year in range(first_year, last_year + 1):
        frames.append(pl.read_csv(f"{csv_path}/{year}.csv", infer_schema_length=10000))
    return frames


# One list of yearly dataframes per studied period (both bounds included).
dataframes_35_45 = _read_year_range(1835, 1845)
dataframes_60_70 = _read_year_range(1860, 1870)
dataframes_87_97 = _read_year_range(1887, 1897)

In [None]:
def clean_children(dataframes: list[pl.DataFrame]) -> None:
    """
    Convert the pipe-separated children column of every dataframe into a list column.

    The "enfants_chez_parents_prenom" column is split on "|" and empty strings
    are removed from each resulting list. The list passed in is modified in
    place: every element is replaced by its transformed dataframe.

    Args:
        dataframes (list[pl.DataFrame]): Dataframes to transform; each must
            contain a string column named "enfants_chez_parents_prenom".

    Returns:
        None: The result is written back into `dataframes`.
    """
    column = "enfants_chez_parents_prenom"

    # NOTE(review): calling this twice on the same list presumably fails, since
    # the column is no longer a string column after the first pass - confirm
    # before re-running the calling cell without reloading the data.
    for idx, df in enumerate(dataframes):
        dataframes[idx] = df.with_columns(
            df[column]
            .str.split("|")
            .list.eval(pl.element().filter(pl.element() != ""))
            .alias(column)
        )


# Normalise the children column of every period, in place.
for period_frames in (dataframes_35_45, dataframes_60_70, dataframes_87_97):
    clean_children(period_frames)

In [None]:
# Schema over the raw columns: six single-string fields followed by the
# multi-string field holding the children's first names.
record_schema_base = bb.RecordSchema(
    [
        *(
            bb.FieldSchema(column, bb.ElementType.String)
            for column in (
                "nom_rue",
                "chef_prenom",
                "chef_nom",
                "chef_origine",
                "epouse_nom",
                "chef_vocation",
            )
        ),
        bb.FieldSchema("enfants_chez_parents_prenom", bb.ElementType.MultiStrings),
    ]
)

In [None]:
# Same layout as the base schema, but pointing at the "*_norm" variant of each
# single-string column; the children column keeps its original name.
record_schema_normalized = bb.RecordSchema(
    [
        *(
            bb.FieldSchema(column, bb.ElementType.String)
            for column in (
                "nom_rue_norm",
                "chef_prenom_norm",
                "chef_nom_norm",
                "chef_origine_norm",
                "epouse_nom_norm",
                "chef_vocation_norm",
            )
        ),
        bb.FieldSchema("enfants_chez_parents_prenom", bb.ElementType.MultiStrings),
    ]
)

In [None]:
# Value shared by every distance metric configuration below.
# NOTE(review): the exact meaning of the threshold is defined by blitzbeaver - confirm in its docs.
caching_threshold = 4

# One "lv_substring" metric configuration is built per weight listed here.
lv_substring_weights = [0.6, 0.7, 0.8]

# "lv_opti" configuration; it is not part of `distance_metrics`, so the
# configuration grid built from that list does not include it.
distance_metric_lv_opti = bb.DistanceMetricConfig(
    metric="lv_opti",
    caching_threshold=caching_threshold,
    use_sigmoid=False,
)

# Candidate distance metrics, in the same order as `lv_substring_weights`.
distance_metrics = [
    bb.DistanceMetricConfig(
        metric="lv_substring",
        caching_threshold=caching_threshold,
        use_sigmoid=False,
        lv_substring_weight=w,
    )
    for w in lv_substring_weights
]

In [None]:
# Candidate values for the `min_weight_ratio` argument of the record scorer.
min_weight_ratios = [0.7]
# Candidate weight vectors; each holds seven values.
# NOTE(review): presumably one weight per schema field, in schema order - confirm.
weights = [
    [
        0.15,
        0.25,
        0.30,
        0.15,
        0.15,
        0.15,
        0.15,
    ],
]

# One "weighted-average" scorer per (weight vector, min weight ratio) pair.
record_scorers = [
    bb.RecordScorerConfig(
        record_scorer="weighted-average",
        weights=w,
        min_weight_ratio=ratio,
    )
    for w in weights
    for ratio in min_weight_ratios
]

In [None]:
# Single resolver configuration shared by every tracking configuration.
resolver_config = bb.ResolverConfig(
    resolving_strategy="best-match",
)

In [None]:
# Memory strategies to try; one MemoryConfig is built per entry.
memory_strategies = [
    "ls-median",
]
memory_configs = [bb.MemoryConfig(memory_strategy=m) for m in memory_strategies]

# Memory configurations passed as `multistring_memory_config` when building the
# tracking configurations. The distance metric used is the second entry of
# `distance_metrics`, i.e. "lv_substring" with a weight of 0.7.
multistring_memory_config = [
    bb.MemoryConfig(
        memory_strategy="mw-median",
        multiword_threshold_match=0.8,
        multiword_distance_metric=distance_metrics[1],
    )
]

In [None]:
# Candidate values for `interest_threshold`.
thresholds = [0.79, 0.8]
# Cartesian product of every candidate list: distance metric x record scorer x
# memory config x multi-string memory config x interest threshold.
# With the lists as defined in this notebook that is 3 * 1 * 1 * 1 * 2 = 6 configurations.
configs = [
    bb.config(
        record_schema=record_schema_base,
        distance_metric_config=d,
        record_scorer_config=r,
        resolver_config=resolver_config,
        memory_config=m,
        multistring_memory_config=mm,
        interest_threshold=t,
        limit_no_match_streak=4,
        num_threads=17,  # NOTE(review): hard-coded thread count - adjust to the machine running the notebook
    )
    for d in distance_metrics
    for r in record_scorers
    for m in memory_configs
    for mm in multistring_memory_config
    for t in thresholds
]
# Number of configurations that will be executed.
print(len(configs))

In [None]:
def aggregate_distance_averages(
    avgs: list[list[float]], nb_features: int
) -> list[float]:
    """
    Averages per-feature values across several lists.

    Args:
        avgs (list[list[float]]): One list of per-feature values per tracker.
            Each list must not be longer than `nb_features`.
        nb_features (int): Number of features, i.e. length of the result.

    Returns:
        list[float]: For each feature, the sum of its values divided by the
        number of lists in `avgs`. When `avgs` is empty, a list of
        `nb_features` zeros is returned instead of raising ZeroDivisionError.
    """
    if not avgs:
        # Nothing to average: dividing by len(avgs) would raise.
        return [0.0] * nb_features

    totals = [0.0] * nb_features
    for avg in avgs:
        for i, val in enumerate(avg):
            totals[i] += val

    return [total / len(avgs) for total in totals]


def avg_distance_metrics(
    graph: bb.TrackingGraph,
    dataframes: list[pl.DataFrame],
    record_schema: bb.RecordSchema,
    id: bb.ID,
) -> list[float] | None:
    """
    Averages, per feature, the distances recorded for the records of a tracker chain.

    The chain of the tracker is materialized to learn which record was matched
    in each frame (the first matched frame is skipped). For every diagnostics
    frame of the tracker, the entry of that matched record is looked up and its
    non-None distances are accumulated feature by feature.

    Args:
        graph (bb.TrackingGraph): Tracking graph providing diagnostics and chains.
        dataframes (list[pl.DataFrame]): Dataframes used to materialize the chain.
        record_schema (bb.RecordSchema): Schema whose number of fields gives the
            number of features.
        id (bb.ID): ID of the tracker to process.

    Returns:
        list[float] | None: Average distance per feature; a feature for which no
        distance was recorded gets 0. None is returned when the tracker has no
        diagnostics or when its chain has no matched frame after the first one.
    """
    nb_features = len(record_schema.fields)

    diagnostics = graph.diagnostics.get_tracker(id)
    if diagnostics is None:
        return None

    chain = graph.materialize_tracking_chain(id, dataframes, record_schema, None)

    # frame index -> index of the record matched in that frame (first frame skipped)
    matched_record_of_frame = {}
    for matched in chain.matched_frames[1:]:
        matched_record_of_frame[matched.frame_idx] = matched.record_idx
    if not matched_record_of_frame:
        return None

    sums = [0] * nb_features
    hits = [0] * nb_features

    for frame in diagnostics.frames:
        wanted_idx = matched_record_of_frame.get(frame.frame_idx)
        if wanted_idx is None:
            continue

        # Only the first diagnostics entry of the matched record is used.
        candidate = next(
            (rec for rec in frame.records if rec.record_idx == wanted_idx), None
        )
        if candidate is None:
            continue

        for pos, dist in enumerate(candidate.distances):
            if dist is not None:
                sums[pos] += dist
                hits[pos] += 1

    # A feature without any recorded distance is divided by 1, which yields 0.
    return [total / (nb if nb != 0 else 1) for total, nb in zip(sums, hits)]


def print_column_scores(
    graph: bb.TrackingGraph, dataframes: list[pl.DataFrame], record_schema: bb.RecordSchema
):
    """
    Prints, for every schema field, the distance averaged over all trackers.

    Trackers for which `avg_distance_metrics` returns None are left out of the
    aggregation.

    Args:
        graph (bb.TrackingGraph): Tracking graph to inspect.
        dataframes (list[pl.DataFrame]): Dataframes used to materialize chains.
        record_schema (bb.RecordSchema): Schema giving the field names and count.
    """
    print("Column scores:")

    per_tracker = []
    for tracker_id in graph.trackers_ids:
        tracker_avg = avg_distance_metrics(graph, dataframes, record_schema, tracker_id)
        if tracker_avg is not None:
            per_tracker.append(tracker_avg)

    aggregated = aggregate_distance_averages(per_tracker, len(record_schema.fields))
    for field, value in zip(record_schema.fields, aggregated):
        print(f"{field.name}: {value}")

In [None]:
def find_chain_with_length(
    graph: bb.TrackingGraph, start_idx: int, length: int
) -> None | int:
    """
    Finds the first tracker, from a given position, whose chain is long enough.

    Args:
        graph (bb.TrackingGraph): Tracking graph to search.
        start_idx (int): Position in `graph.trackers_ids` where the search starts.
        length (int): Minimum chain length required.

    Returns:
        None | int: ID of the first tracker at or after `start_idx` whose
        tracking chain has at least `length` elements, or None if there is none.
    """
    for position in range(start_idx, len(graph.trackers_ids)):
        candidate = graph.trackers_ids[position]
        if len(graph._raw.get_tracking_chain(candidate)) >= length:
            return candidate
    return None


def take_n_random_ids(
    n: int, graph: bb.TrackingGraph, min_length: int, seed: int | None = None
) -> list[int]:
    """
    Draws up to `n` distinct tracker IDs among those with a long enough chain.

    Args:
        n (int): Number of IDs requested.
        graph (bb.TrackingGraph): Tracking graph providing the tracker IDs.
        min_length (int): Minimum tracking chain length for an ID to be eligible.
        seed (int | None): When given, the draw uses a dedicated generator
            seeded with this value, so the same IDs are returned on every call.
            When None (default), the shared `random` module state is used.

    Returns:
        list[int]: The sampled IDs; fewer than `n` when there are not enough
        eligible trackers.
    """
    # Keep only the trackers whose chain has at least min_length elements.
    valid_ids = [
        tracker_id
        for tracker_id in graph.trackers_ids
        if len(graph._raw.get_tracking_chain(tracker_id)) >= min_length
    ]

    # Never ask for more IDs than are available.
    n = min(n, len(valid_ids))

    if seed is None:
        return random.sample(valid_ids, n)
    return random.Random(seed).sample(valid_ids, n)


def print_verify_df(chain: bb.MaterializedTrackingChain):
    """
    Prints the verifier score of every column of a materialized chain.

    Args:
        chain (bb.MaterializedTrackingChain): Chain to score and print, one
            "<column>: <score>" line per column.
    """
    scores, names = verify_columns_of_chain(chain)
    for position in range(len(scores)):
        print(f"{names[position]}: {scores[position]}")


def verify_columns_of_chain(
    chain: bb.MaterializedTrackingChain,
) -> tuple[list[float], list[str]]:
    """
    Scores every column of a materialized chain with `verify_column`.

    The columns listed in `verifier_dropped_cols` are removed from the chain
    dataframe before scoring.

    Args:
        chain (bb.MaterializedTrackingChain): Chain to score.

    Returns:
        tuple[list[float], list[str]]: The score of each remaining column and
        the matching column names, in the same order.
    """
    frame = chain.as_dataframe().drop(verifier_dropped_cols)
    scores = [verify_column(frame[name].to_list()) for name in frame.columns]
    return scores, frame.columns


def print_aggregate_verifiers(
    agg_values: list[float], fields: list[str]
) -> None:
    """
    Prints aggregated verifier scores, one "<field>: <value>" line per score.

    Args:
        agg_values (list[float]): Aggregated score of each field.
        fields (list[str]): Field names, in the same order as `agg_values`;
            must hold at least as many entries.

    Returns:
        None: The function only prints.
    """
    print("Aggregated verifiers:")
    for position, value in enumerate(agg_values):
        print(f"{fields[position]}: {value}")


def aggregate_column_verifiers(l: list[list[float]]) -> list[float]:
    """
    Averages verifier scores column by column.

    Args:
        l (list[list[float]]): One list of per-column scores per sampled chain.
            The number of columns is taken from the first list.

    Returns:
        list[float]: For each column, the mean of its scores over all lists.
        An empty input yields an empty list instead of raising IndexError.
    """
    if not l:
        # No chain was sampled: there is no first row to size the result from.
        return []

    totals = [0.0] * len(l[0])
    for values in l:
        for i, v in enumerate(values):
            totals[i] += v

    return [total / len(l) for total in totals]


def verify_column(l: list[str | None]) -> float:
    ratio = 0
    count = 0

    for i, value in enumerate(l[:-1]):
        next = l[i + 1]
        if value is None:
            value = ""
        if next is None:
            next = ""
        max_len = max(len(value), len(next))
        if max_len == 0:
            continue
        ratio += 1 - distance(value, next) / max_len
        count += 1

    if count == 0:
        return 0.0
    return ratio / count


def verify_n_samples(
    tracking_graph: bb.TrackingGraph,
    dataframes: list[pl.DataFrame],
    schema: bb.RecordSchema,
    nb_samples: int,
    min_length: int,
) -> list[float]:
    """
    Averages the column verifier scores over randomly sampled chains.

    Args:
        tracking_graph (bb.TrackingGraph): Tracking graph to sample from.
        dataframes (list[pl.DataFrame]): Dataframes used to materialize chains.
        schema (bb.RecordSchema): Schema used to materialize chains.
        nb_samples (int): Maximum number of chains to sample.
        min_length (int): Minimum chain length for a tracker to be eligible.

    Returns:
        list[float]: Per-column scores averaged over the sampled chains.
    """
    per_chain_scores = []

    for tracker_id in take_n_random_ids(nb_samples, tracking_graph, min_length):
        chain = tracking_graph.materialize_tracking_chain(
            tracker_id, dataframes, schema, normalized_dataframes=None
        )
        scores, _ = verify_columns_of_chain(chain)
        per_chain_scores.append(scores)

    return aggregate_column_verifiers(per_chain_scores)


def print_verify_n_samples(
    tracking_graph: bb.TrackingGraph,
    dataframes: list[pl.DataFrame],
    schema: bb.RecordSchema,
    nb_samples: int,
    min_length: int,
):
    """
    Computes and prints the verifier scores averaged over sampled chains.

    Args:
        tracking_graph (bb.TrackingGraph): Tracking graph to sample from.
        dataframes (list[pl.DataFrame]): Dataframes used to materialize chains.
        schema (bb.RecordSchema): Schema used to materialize chains; its field
            names, minus `verifier_dropped_cols`, label the printed scores.
        nb_samples (int): Maximum number of chains to sample.
        min_length (int): Minimum chain length for a tracker to be eligible.
    """
    scores = verify_n_samples(tracking_graph, dataframes, schema, nb_samples, min_length)

    kept_names = []
    for field in schema.fields:
        if field.name not in verifier_dropped_cols:
            kept_names.append(field.name)

    print_aggregate_verifiers(scores, kept_names)

In [None]:
def aggregate_histograms(histograms: list[list[int]]) -> list[int]:
    """
    Aggregates a list of histograms into a single histogram.

    Histograms may have different lengths; shorter ones simply do not
    contribute to the higher bins.

    Args:
        histograms (list[list[int]]): Histograms to sum bin by bin.

    Returns:
        list[int]: Bin-wise sum, as long as the longest input histogram. An
        empty input yields an empty list instead of raising ValueError.
    """
    # default=0 keeps max() from failing when there is no histogram at all.
    width = max((len(h) for h in histograms), default=0)
    totals = [0] * width
    for h in histograms:
        for i, v in enumerate(h):
            totals[i] += v
    return totals


def get_start_end_years(
    graph: bb.TrackingGraph,
    dataframes: list[pl.DataFrame],
    record_schema: bb.RecordSchema,
    nb_of_years: int,
) -> tuple[list[int], list[int]]:
    """
    Counts, per frame index, how many chains start and how many end there.

    Every tracker chain is materialized; the frame index of its first matched
    frame increments the start histogram and the frame index of its last
    matched frame increments the end histogram.

    Args:
        graph (bb.TrackingGraph): Tracking graph providing the trackers.
        dataframes (list[pl.DataFrame]): Dataframes used to materialize chains.
        record_schema (bb.RecordSchema): Schema used to materialize chains.
        nb_of_years (int): Length of both histograms; every frame index must be
            smaller than this value.

    Returns:
        tuple[list[int], list[int]]: The start histogram and the end histogram.
    """
    starts = [0] * nb_of_years
    ends = [0] * nb_of_years

    for tracker_id in graph.trackers_ids:
        chain = graph.materialize_tracking_chain(tracker_id, dataframes, record_schema)
        first_frame = chain.matched_frames[0]
        starts[first_frame.frame_idx] += 1
        last_frame = chain.matched_frames[-1]
        ends[last_frame.frame_idx] += 1

    return starts, ends


def get_avg_record_tracker_match(graph: bb.TrackingGraph) -> tuple[float, float]:
    """
    Averages the per-frame record and tracker match ratios of a tracking graph.

    The first record ratio is left out of the record average; the first and
    the last tracker ratios are left out of the tracker average.
    NOTE(review): presumably because those boundary frames have nothing to be
    matched against - confirm against the blitzbeaver documentation.

    Args:
        graph (bb.TrackingGraph): Tracking graph to evaluate.

    Returns:
        tuple[float, float]: Average record match ratio and average tracker
        match ratio. An average is 0.0 when no ratio remains after removing the
        boundary frames, instead of raising ZeroDivisionError.
    """

    def _mean(values) -> float:
        # Mean of a possibly empty sequence.
        return sum(values) / len(values) if len(values) > 0 else 0.0

    graph_metrics = bb.evaluate_tracking_graph_properties(graph._raw)

    records_match_ratios = graph_metrics.records_match_ratios[1:]
    trackers_match_ratios = graph_metrics.trackers_match_ratios[1:-1]

    return _mean(records_match_ratios), _mean(trackers_match_ratios)

In [None]:
import gc

def compute_and_save_beaver(
    config: bb.TrackingConfig,
    record_schema: bb.RecordSchema,
    dataframes: list[pl.DataFrame],
    filepath: str,
) -> bb.TrackingGraph:
    """
    Runs the tracking and saves the resulting graph as a beaver file.

    Args:
        config (bb.TrackingConfig): Tracking configuration to execute.
        record_schema (bb.RecordSchema): Schema of the tracked records.
        dataframes (list[pl.DataFrame]): Dataframes to track across.
        filepath (str): Path of the output file, relative to `beaver_folder_path`.

    Returns:
        bb.TrackingGraph: The graph produced by the tracking.
    """
    destination = f"{beaver_folder_path}/{filepath}"
    tracking_graph = bb.execute_tracking(config, record_schema, dataframes, "debug")
    bb.save_beaver(destination, tracking_graph)
    return tracking_graph


def save_json(filepath: str, d: dict[str, Any] | list[dict[str, Any]]) -> None:
    """
    Writes JSON-serializable data to a file inside `json_folder_path`.

    The destination folder is created when it does not exist, so that results
    of a long computation are not lost to a missing directory.

    Args:
        filepath (str): Path of the output file, relative to `json_folder_path`.
        d (dict[str, Any] | list[dict[str, Any]]): Data to serialize, either a
            single summary or a list of summaries.
    """
    import os

    path_json = f"{json_folder_path}/{filepath}"

    parent = os.path.dirname(path_json)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path_json, "w") as f:
        json.dump(d, f, indent=2)


def build_tracking_summary(
    config: bb.TrackingConfig,
    record_schema: bb.RecordSchema,
    dataframes: list[pl.DataFrame],
    nb_samples: int,
    min_length: int,
    nb_of_years: int,
) -> dict[str, Any]:
    """
    Executes one tracking configuration and gathers its evaluation metrics.

    Args:
        config (bb.TrackingConfig): Tracking configuration to execute.
        record_schema (bb.RecordSchema): Schema of the tracked records.
        dataframes (list[pl.DataFrame]): Dataframes to track across.
        nb_samples (int): Maximum number of chains sampled for the verifier.
        min_length (int): Minimum chain length for a chain to be sampled.
        nb_of_years (int): Length of the start/end histograms.

    Returns:
        dict[str, Any]: A dictionary with two entries: "config", the serialized
        configuration, and "data", which holds the verifier scores, the averaged
        memory distances, the start/end histograms, the chain length histogram,
        the average match ratios and the aggregated resolving histograms.
    """
    serialized_config = bb.serialize_tracking_config(config)
    graph = bb.execute_tracking(config, record_schema, dataframes, "info")

    # Verifier scores over a random sample of long enough chains.
    verifier_scores = verify_n_samples(
        graph, dataframes, record_schema, nb_samples, min_length
    )
    field_names = [field.name for field in record_schema.fields]
    verified_names = [
        name for name in field_names if name not in verifier_dropped_cols
    ]

    # Per-feature distances, averaged over the trackers that yield a result.
    tracker_avgs = []
    for tracker_id in graph.trackers_ids:
        tracker_avg = avg_distance_metrics(graph, dataframes, record_schema, tracker_id)
        if tracker_avg is not None:
            tracker_avgs.append(tracker_avg)
    memory_distances = aggregate_distance_averages(tracker_avgs, len(field_names))

    chain_metrics = bb.evaluate_tracking_chain_length(graph._raw)
    start_years, end_years = get_start_end_years(
        graph, dataframes, record_schema, nb_of_years
    )

    avg_records_match, avg_trackers_match = get_avg_record_tracker_match(graph)

    # Resolving histograms summed over every resolving step.
    histogram_records = aggregate_histograms(
        [step.histogram_record_matchs for step in graph.diagnostics.resolvings]
    )
    histogram_trackers = aggregate_histograms(
        [step.histogram_tracker_matchs for step in graph.diagnostics.resolvings]
    )

    return {
        "config": serialized_config,
        "data": {
            "verifier": dict(zip(verified_names, verifier_scores)),
            "memory_distance": dict(zip(field_names, memory_distances)),
            "start_years": start_years,
            "end_years": end_years,
            "chain_lengths": chain_metrics.histogram,
            "avg_records_match": avg_records_match,
            "avg_trackers_match": avg_trackers_match,
            "histogram_records": histogram_records,
            "histogram_trackers": histogram_trackers,
        },
    }


def compute_for_all_configs(
    configs: list[bb.TrackingConfig],
    dataframes: list[pl.DataFrame],
    record_schema: bb.RecordSchema,
    nb_samples: int,
    min_length: int,
    filepath: str,
    step_before_save: int,
) -> None:
    """
    Executes every configuration and saves the summaries in JSON batches.

    Summaries are accumulated and written to "<filepath>_<n>.json" (n starting
    at 1) each time `step_before_save` of them are pending; whatever remains
    after the last configuration is written to one more file.

    Args:
        configs (list[bb.TrackingConfig]): Configurations to execute, in order.
        dataframes (list[pl.DataFrame]): Dataframes to track across; their
            number is also used as the length of the start/end histograms.
        record_schema (bb.RecordSchema): Schema of the tracked records.
        nb_samples (int): Maximum number of chains sampled for the verifier.
        min_length (int): Minimum chain length for a chain to be sampled.
        filepath (str): Prefix of the output files, relative to the JSON folder.
        step_before_save (int): Number of summaries per output file.
    """
    pending: list[dict[str, Any]] = []
    part = 1
    for i, config in enumerate(configs):
        print("execute config", i)
        pending.append(
            build_tracking_summary(
                config, record_schema, dataframes, nb_samples, min_length, len(dataframes)
            )
        )

        if len(pending) >= step_before_save:
            save_json(f"{filepath}_{part}.json", pending)
            pending = []
            part += 1

        # Collect garbage after each run, before the next configuration starts.
        gc.collect()

    # Flush the summaries that did not fill a complete batch.
    if len(pending) > 0:
        save_json(f"{filepath}_{part}.json", pending)

In [None]:
# Prefix of the result files: compute_for_all_configs writes "<filename>_<n>.json".
filename = "configs_p7"
# Shallow copy of the full list.
# NOTE(review): presumably kept as a slice so that a subset can be selected by editing the bounds.
configs = configs[:]
# Run every configuration on the 1835-1845 dataframes; the verifier samples at
# most 1 chain of length >= 7, and summaries are saved in batches of 100.
compute_for_all_configs(
    configs=configs,
    dataframes=dataframes_35_45,
    record_schema=record_schema_base,
    nb_samples=1,
    min_length=7,
    filepath=filename,
    step_before_save=100,
)

In [None]:
def merge_results_files(filenames: list[str], out_filename: str) -> None:
    """
    Concatenates the contents of several JSON result files into one file.

    Each input file is loaded and its items are appended, in order, to a single
    list that is then written out. All paths are relative to `json_folder_path`.

    Args:
        filenames (list[str]): Names of the JSON files to merge, in order.
        out_filename (str): Name of the merged JSON file to write.
    """
    merged = []
    for name in filenames:
        with open(f"{json_folder_path}/{name}", "r") as source:
            merged.extend(json.load(source))

    with open(f"{json_folder_path}/{out_filename}", "w") as target:
        json.dump(merged, target, indent=2)

In [None]:
# merge_results_files(
#     [
#         "configs_p6.json",
#         "configs_p7.json",

#     ],
#     "configs_p8.json",
# )