From 0e0e36d2724a2b7f54ca65a3bd26f798bcb99a8e Mon Sep 17 00:00:00 2001
From: Meghana Sistla
Date: Thu, 6 Mar 2025 03:08:53 +0000
Subject: [PATCH 1/2] Filtering data while merging datasets

---
 src/itp_interface/main/merge_dataset.py | 122 +++++++++++++++++++++---
 1 file changed, 111 insertions(+), 11 deletions(-)

diff --git a/src/itp_interface/main/merge_dataset.py b/src/itp_interface/main/merge_dataset.py
index 96d8aa4..754f35c 100644
--- a/src/itp_interface/main/merge_dataset.py
+++ b/src/itp_interface/main/merge_dataset.py
@@ -11,10 +11,93 @@ import typing
 import copy
 from itp_interface.tools.log_utils import setup_logger
-from itp_interface.tools.training_data import TrainingData
+from itp_interface.tools.training_data import TrainingData, TrainingDataFormat
 
+def filter_training_data(training_data: TrainingData, max_distance_to_good: int = 10):
 
-def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=None):
+    def _reconstruct_proof_tree(prev_state_id_map: typing.Dict[int, typing.Set[int]],
+                                good_state_ids: typing.Set[int],
+                                bad_state_ids: typing.Set[int],
+                                distance_map: typing.Dict[int, int]):
+        distance = 1
+        while True:
+            # The idea is that at some point no new good state ids will be added.
+            # This is a level-order traversal, but in reverse.
+            new_good_state_ids = set()
+            for end_state, start_states in prev_state_id_map.items():
+                if end_state in good_state_ids:
+                    for start_state in start_states:
+                        if start_state not in good_state_ids:
+                            # To avoid an infinite loop, do not add start_state if it is already in good_state_ids
+                            new_good_state_ids.add(start_state)
+                            dist = distance_map.get(start_state, distance)
+                            if dist >= distance:
+                                distance_map[start_state] = distance
+            distance += 1
+            if len(new_good_state_ids) == 0:
+                break
+            good_state_ids.update(new_good_state_ids)
+        # Now identify the states which are not good
+        for end_state in prev_state_id_map.keys():
+            if end_state not in good_state_ids:
+                bad_state_ids.add(end_state)
+
+    def _reconstruct_prev_state_id_map(training_datas: typing.List[TrainingDataFormat]) -> typing.Tuple[typing.Dict[int, typing.Set[int]], typing.Optional[int]]:
+        prev_state_id_map = {}
+        done_state = None
+        for training_data in training_datas:
+            if training_data.addition_state_info is not None and len(training_data.addition_state_info) > 0:
+                start_state = training_data.addition_state_info.get("start_goal_id", None)
+                end_state = training_data.addition_state_info.get("end_goal_id", None)
+                if start_state is not None and end_state is not None:
+                    prev_to_end_state = prev_state_id_map.get(end_state, set())
+                    prev_to_end_state.add(start_state)
+                    prev_state_id_map[end_state] = prev_to_end_state
+                if done_state is None and training_data.addition_state_info.get("done", False):
+                    done_state = end_state
+        return prev_state_id_map, done_state
+
+    filtered_training_data: typing.List[TrainingDataFormat] = []
+    proof_id_maps: typing.Dict[str, typing.List[TrainingDataFormat]] = {}
+    for idx in range(len(training_data)):
+        example = training_data[idx]
+        training_datas: typing.List[TrainingDataFormat] = proof_id_maps.get(example.proof_id, [])
+        training_datas.append(example)
+        proof_id_maps[example.proof_id] = training_datas
+    for proof_id, training_datas in proof_id_maps.items():
+        prev_state_id_map, done_state = _reconstruct_prev_state_id_map(training_datas)
+        # Now we have the prev_state_id_map and done_state.
+        # Every state from which we can reach done_state should have a good value;
+        # every other state should have a bad value.
+        # First figure out all state ids from which we can reach done_state.
+        good_state_ids = set()
+        good_state_ids.add(done_state)
+        bad_state_ids = set()
+        distance_map = {done_state: 0}
+        _reconstruct_proof_tree(prev_state_id_map, good_state_ids, bad_state_ids, distance_map)
+        # Now we have the good_state_ids and bad_state_ids;
+        # annotate the training data with the value function.
+        for training_data in training_datas:
+            if training_data.addition_state_info is not None and len(training_data.addition_state_info) > 0:
+                end_state_id = training_data.addition_state_info.get("end_goal_id", None)
+                if end_state_id is not None:
+                    progress = training_data.addition_state_info.get("progress", "")
+                    if end_state_id in good_state_ids and (progress == "StateChanged" or progress == "Done"):
+                        distance = distance_map.get(end_state_id, max_distance_to_good)
+                        distance = min(distance, max_distance_to_good)
+                        progress = f"[GOOD] [{distance}] {progress}"
+                    else:
+                        progress = f"[BAD] {progress}"
+                    training_data.addition_state_info["progress"] = progress
+                    if "GOOD" in progress:
+                        filtered_training_data.append(training_data)
+            elif training_data.addition_state_info is None:
+                filtered_training_data.append(training_data)
+
+    return filtered_training_data
+
+
+def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=None, should_filter_data=False):
     """
     Merge datasets
     """
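For intuition, here is a minimal, self-contained sketch of the reverse level-order traversal that `_reconstruct_proof_tree` performs. The state ids and `prev_state_id_map` below are hypothetical toy data, not values produced by this patch:

```python
# Toy illustration of the backward traversal: 1 -> 2 -> 3 reaches the done
# state, while 1 -> 5 is a dead end. Keys are end_goal_ids, values are the
# start_goal_ids that led to them (as built by _reconstruct_prev_state_id_map).
prev_state_id_map = {3: {2}, 2: {1}, 5: {1}}
done_state = 3

good_state_ids = {done_state}
distance_map = {done_state: 0}
distance = 1
while True:
    # Walk one level backwards from the states already known to be good.
    new_good_state_ids = {
        start
        for end, starts in prev_state_id_map.items()
        if end in good_state_ids
        for start in starts
        if start not in good_state_ids
    }
    if not new_good_state_ids:
        break
    for state in new_good_state_ids:
        distance_map.setdefault(state, distance)  # keep the first (shortest) distance
    good_state_ids.update(new_good_state_ids)
    distance += 1

bad_state_ids = set(prev_state_id_map) - good_state_ids
print(good_state_ids, bad_state_ids, distance_map)
# {1, 2, 3} {5} {3: 0, 2: 1, 1: 2}
```

States 1 and 2 are labelled good with their distances to the done state, while state 5 is labelled bad, mirroring the `[GOOD]`/`[BAD]` annotations the patch writes into `addition_state_info["progress"]`.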
@@ -52,10 +135,23 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
     for td in tds:
         logger.info(f"Start loading {td.folder} ...")
         td.load()
-        logger.info(f"Loaded {td.folder}.")
-        logger.info(f"Start merging {td.folder} ...")
-        merged_td.merge(td)
-        logger.info(f"Merged {td.folder}")
+
+        filtered_training_data_points: typing.Optional[typing.List[TrainingDataFormat]] = None
+        if should_filter_data and td.folder != datasets[0]:
+            # TODO: move to the right location
+            filtered_training_data_points = filter_training_data(td)
+
+        if filtered_training_data_points is not None:
+            logger.info(f"Filtered training data points to {len(filtered_training_data_points)}")
+            logger.info(f"Start merging {td.folder} after filtering ...")
+            for data in filtered_training_data_points:
+                merged_td.merge(data)
+        else:
+            logger.info(f"Loaded {td.folder}.")
+            logger.info(f"Start merging {td.folder} ...")
+            merged_td.merge(td)
+            logger.info(f"Merged {td.folder}")
+
     logger.info("Finished merging datasets.")
     logger.info(f"Start saving merged dataset to {output} ...")
     merged_td.save()
@@ -68,11 +164,15 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
         max_parallelism=max_parallelism,
         logger=logger
     )
-    new_merged_td.load()
-    assert len(new_merged_td) == metadata.total_proof_step_cnt, "Merged dataset is not correct"
-    assert new_merged_td.meta.last_training_data == metadata.last_training_data, "Merged dataset is not correct"
-    assert new_merged_td.meta.last_proof_id == metadata.last_proof_id, "Merged dataset is not correct"
-    logger.info("Merged dataset is correct.")
+    if not should_filter_data:
+        new_merged_td.load()
+        assert len(new_merged_td) == metadata.total_proof_step_cnt, "Merged dataset is not correct"
+        assert new_merged_td.meta.last_training_data == metadata.last_training_data, "Merged dataset is not correct"
+        assert new_merged_td.meta.last_proof_id == metadata.last_proof_id, "Merged dataset is not correct"
+        logger.info("Merged dataset is correct.")
+    else:
+        logger.info("Filtered and merged data, skipping verification.")
+
     logger.info("Finished verifying the merged dataset.")
     pass
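Given the signature introduced above, a hypothetical invocation of the filtered merge might look like the following. The dataset paths, metafile names, and logger setup are illustrative, not taken from the repository:

```python
import logging

from itp_interface.main.merge_dataset import merge_datasets

# Hypothetical paths and metafile names, for illustration only.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("merge")

merge_datasets(
    datasets=["data/run1", "data/run2"],      # filtering is skipped for datasets[0]
    metafilenames=["meta.json", "meta.json"],
    output="data/merged",
    max_parallelism=8,
    logger=logger,
    should_filter_data=True,  # keep only [GOOD] steps (and steps without state info)
)
```

Note that with `should_filter_data=True` the post-merge count assertions are skipped, since filtering deliberately drops proof steps.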
From 4b0eaca487d3e090a382e00e2fc1daaa64a87f58 Mon Sep 17 00:00:00 2001
From: Amitayush Thakur
Date: Wed, 12 Mar 2025 19:14:04 -0500
Subject: [PATCH 2/2] Fixed install script.

---
 README.md                         | 7 +++++--
 pyproject.toml                    | 2 +-
 src/itp_interface/main/install.py | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index a0e0f76..71e14f0 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,14 @@ Generic interface for hooking up to any Interactive Theorem Prover (ITP) and col
 pip install itp-interface
 ```
 
-2. Run the following command to prepare the REPL for Lean 4. (The default version is 4.7.0-rc2. You can change the version by setting the `LEAN_VERSION` environment variable. If no version is set, then 4.7.0-rc2 is used.)
+2. Run the following command to prepare the REPL for Lean 4. The default version is 4.7.0-rc2; you can change it by setting the `LEAN_VERSION` environment variable. Note that itp-interface supports versions only up to Lean 4.17.0.
 >NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with.
 ```bash
-export LEAN_VERSION="4.15.0"
+export LEAN_VERSION="4.7.0-rc2"
 install-lean-repl
+# ^^ Change LEAN_VERSION to the version of Lean 4 you are working with.
+# ^^ Example: export LEAN_VERSION="4.15.0" to use Lean 4.15.0.
+# ^^ itp-interface supports versions up to Lean 4.17.0.
 ```
 
 3. Run the following command to build the REPL for Lean 4. Make sure that `lean --version` returns the correct version before running the command below. If not then check if `$HOME/.elan/bin` is in your path. Recommended to run `source $HOME/.elan/env` before running the command below.
diff --git a/pyproject.toml b/pyproject.toml
index 1550254..8d03d97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
 build-backend = "hatchling.build"
 
 [project]
 name = "itp_interface"
-version = "1.1.7"
+version = "1.1.8"
 authors = [
   { name="Amitayush Thakur", email="amitayush@utexas.edu" },
 ]
diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py
index d52c5fd..2fc8b5b 100644
--- a/src/itp_interface/main/install.py
+++ b/src/itp_interface/main/install.py
@@ -64,7 +64,7 @@ def install_lean_repl():
         elan_url = "https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh"
         os.system(f"curl -sSL {elan_url} | sh")
         print("[OK] .elan installed")
-    lean_repo = "leanprover-community/lean"
+    lean_repo = "leanprover/lean4"
     # Create a temporary script to run
     shell_code = f"""
 source $HOME/.elan/env
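For context on the `lean_repo` fix: `leanprover-community/lean` is the Lean 3 repository, while elan resolves Lean 4 toolchains from `leanprover/lean4`. Below is an illustrative sketch of how the corrected name could combine with the `LEAN_VERSION` variable from the README into an elan-style toolchain identifier; it assumes the script interpolates `lean_repo` this way, since the actual shell snippet is truncated above:

```python
import os

# Illustrative only: combine the corrected repo name with LEAN_VERSION
# (as described in the README) into an elan-style toolchain identifier.
lean_repo = "leanprover/lean4"  # was "leanprover-community/lean" (Lean 3)
lean_version = os.environ.get("LEAN_VERSION", "4.7.0-rc2")
toolchain = f"{lean_repo}:{lean_version}"  # e.g. "leanprover/lean4:4.15.0"
print(toolchain)
```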