# AlphaGeometry 复现实验：几何定理自动证明

本 notebook 复现了 DeepMind 在 Nature 2024 发表的 AlphaGeometry 系统中的 **DDAR** 符号推理部分。

---

## 什么是 AlphaGeometry？

AlphaGeometry 是一个能够解决国际数学奥林匹克（IMO）级别几何证明题的系统，由两部分组成：

| 组件 | 功能 | 本实验是否包含 |
|------|------|---------------|
| DDAR 符号推理引擎 | 基于规则的演绎推理 | 包含 |
| 语言模型 | 建议辅助构造点 | 不包含（需要大量GPU）|

DDAR 单独使用时，对于不需要辅助构造的题目效果较好。复杂题目可能需要语言模型辅助添加辅助点才能证明。

---

## 参考资料

- 论文：Trinh et al. [Solving olympiad geometry without human demonstrations](https://www.nature.com/articles/s41586-023-06747-5). Nature 625, 476-482 (2024)
- 代码：https://github.com/google-deepmind/alphageometry

---
# 第一部分：环境配置

这部分完成代码下载和依赖安装，通常只需运行一次。

---

In [None]:
#@title 1.1 检查 Python 环境
#@markdown ### 这一步做什么？
#@markdown 1. **检查当前Python版本**：AlphaGeometry官方推荐Python 3.10
#@markdown 2. **自动安装Python3.10**（如果当前版本不匹配）
#@markdown
#@markdown ### 为什么需要这一步？
#@markdown - DeepMind官方代码基于Python 3.10开发
#@markdown - Colab默认可能是3.9或3.11，依赖包版本可能不兼容
#@markdown - 这一步确保环境一致性，避免后续错误
#@markdown

import sys
import os
import subprocess

def run_cmd(cmd):
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    return result.returncode == 0

current_ver = f"{sys.version_info.major}.{sys.version_info.minor}"
print(f"当前 Python 版本: {current_ver}")

if current_ver == "3.10":
    print("版本符合要求")
    PYTHON_PATH = sys.executable
else:
    print(f"正在安装 Python 3.10...")
    run_cmd("apt-get update -qq")
    run_cmd("apt-get install -y software-properties-common")
    run_cmd("add-apt-repository -y ppa:deadsnakes/ppa")
    run_cmd("apt-get update -qq")
    run_cmd("apt-get install -y python3.10 python3.10-distutils python3.10-venv")
    run_cmd("curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10")
    PYTHON_PATH = "/usr/bin/python3.10"
    print("安装完成")

os.environ["PYTHON_PATH"] = PYTHON_PATH
print(f"将使用: {PYTHON_PATH}")

当前 Python 版本: 3.12
正在安装 Python 3.10...
安装完成
将使用: /usr/bin/python3.10


In [None]:
#@title 1.2 下载代码并安装依赖
#@markdown ### 这一步做什么？
#@markdown 1. 从 DeepMind 官方仓库下载 AlphaGeometry 代码
#@markdown 2. 安装它需要的 Python 依赖包
#
# 为什么要“严格/宽松”两套安装？
# - 严格安装：按官方锁定版本（更像论文复现），但 Colab 环境更新快，有时会冲突
# - 宽松安装：更容易在 Colab 装成功，但版本不一定完全一致（课堂演示通常够用）
import shutil
import os

REPO_DIR = "alphageometry"
PYTHON_PATH = os.environ.get("PYTHON_PATH", sys.executable)

if os.path.isdir(REPO_DIR):
    print(f"清理旧文件...")
    shutil.rmtree(REPO_DIR)

print("克隆 AlphaGeometry 仓库...")
os.system(f"git clone https://github.com/google-deepmind/alphageometry.git {REPO_DIR}")

print("\n安装依赖包...")
os.system(f"cd {REPO_DIR} && {PYTHON_PATH} -m pip install -U pip setuptools wheel -q")

ret = os.system(f"cd {REPO_DIR} && {PYTHON_PATH} -m pip install --require-hashes -r requirements.txt -q 2>/dev/null")
if ret != 0:
    print("使用宽松模式安装...")
    os.system(f"cd {REPO_DIR} && {PYTHON_PATH} -m pip install -r requirements.in -q")

numericals_file = os.path.join(REPO_DIR, "numericals.py")
if os.path.exists(numericals_file):
    with open(numericals_file, "r") as f:
        content = f.read()
    if "TkAgg" in content:
        content = content.replace("matplotlib.use('TkAgg')", "matplotlib.use('Agg')")
        content = content.replace('matplotlib.use("TkAgg")', 'matplotlib.use("Agg")')
        with open(numericals_file, "w") as f:
            f.write(content)

print("\n安装完成")

克隆 AlphaGeometry 仓库...

安装依赖包...
使用宽松模式安装...

安装完成


In [None]:
#@title 1.3 验证安装
#@markdown ### 这一步做什么？
#@markdown 验证重要文件是否安装就绪。
REPO_DIR = "alphageometry"

required_files = [
    ("defs.txt", "几何构造定义"),
    ("rules.txt", "推理规则"),
    ("ddar.py", "推理引擎"),
    ("problem.py", "问题解析"),
    ("graph.py", "关系图结构"),
]

print("检查核心文件：\n")
all_ok = True
for fname, desc in required_files:
    path = os.path.join(REPO_DIR, fname)
    exists = os.path.exists(path)
    status = "OK" if exists else "缺失"
    print(f"  [{status}] {fname} - {desc}")
    if not exists:
        all_ok = False

if all_ok:
    print("\n所有文件就绪")
else:
    print("\n有文件缺失，请重新运行 1.2")

检查核心文件：

  [OK] defs.txt - 几何构造定义
  [OK] rules.txt - 推理规则
  [OK] ddar.py - 推理引擎
  [OK] problem.py - 问题解析
  [OK] graph.py - 关系图结构

所有文件就绪


---
# 第二部分：数据集配置

选择要测试的几何题目来源。提供四种选项：

| 选项 | 说明 | 适用场景 |
|------|------|----------|
| 内置示例 | 5道经典几何定理 | 快速验证、教学演示 |
| 自定义上传 | 上传自己的题目文件 | 测试特定题目 |
| AG-30 | 论文官方测试集（30题）| 复现论文结果 |
| IMO-2000-2022 | 更大的IMO题目集合 | 深入研究 |

---

### 参数说明：DATASET_SOURCE

**如何选择：**
- 第一次使用或教学演示 - 选择「内置示例」
- 想测试自己编写的题目 - 选择「自定义上传」
- 想复现论文结果 - 选择「AG-30」
- 想做更全面的测试 - 选择「IMO-2000-2022」

**注意事项：**
- 自定义上传需要按照特定格式编写题目（见附录）
- AG-30 和 IMO-2000-2022 需要联网下载

In [None]:
#@title 2.1 选择数据集

DATASET_SOURCE = "AG-30"  #@param ["内置示例", "自定义上传", "AG-30", "IMO-2000-2022"]

import os
import urllib.request
from pathlib import Path

REPO_DIR = "alphageometry"
DATA_DIR = Path(REPO_DIR)

BUILTIN_PROBLEMS = """orthocenter
a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c

midpoint_theorem
a b c = triangle a b c; m = midpoint m a b; n = midpoint n a c ? para m n b c

isoceles_base
a b c = iso_triangle a b c; m = midpoint m b c ? perp a m b c

incenter_equidistant
a b c = triangle a b c; i = incenter i a b c; d = foot d i b c; e = foot e i c a ? cong i d i e

circumcenter_equidistant
a b c = triangle a b c; o = circumcenter o a b c ? cong o a o b
"""

DATASET_URLS = {
    "AG-30": "https://raw.githubusercontent.com/google-deepmind/alphageometry/main/imo_ag_30.txt",
    "IMO-2000-2022": "https://raw.githubusercontent.com/google-deepmind/alphageometry/main/imo_test_data.txt",
}

print("=" * 55)
print(f"数据集配置: {DATASET_SOURCE}")
print("=" * 55)

dataset_path = None

if DATASET_SOURCE == "内置示例":
    dataset_path = DATA_DIR / "builtin_examples.txt"
    dataset_path.write_text(BUILTIN_PROBLEMS, encoding='utf-8')
    print("\n使用内置示例（5道经典几何定理）：")
    print("  1. orthocenter - 垂心性质")
    print("  2. midpoint_theorem - 中点定理")
    print("  3. isoceles_base - 等腰三角形底边性质")
    print("  4. incenter_equidistant - 内心到各边距离相等")
    print("  5. circumcenter_equidistant - 外心到各顶点距离相等")

elif DATASET_SOURCE == "自定义上传":
    print("\n请上传题目文件（.txt格式）")
    print("格式要求：每题两行，第一行题目名，第二行题目定义\n")
    try:
        from google.colab import files
        uploaded = files.upload()
        if uploaded:
            filename = list(uploaded.keys())[0]
            dataset_path = DATA_DIR / filename
            with open(dataset_path, 'wb') as f:
                f.write(uploaded[filename])
            print(f"\n已上传: {filename}")
        else:
            print("\n未检测到上传文件，将使用内置示例")
            dataset_path = DATA_DIR / "builtin_examples.txt"
            dataset_path.write_text(BUILTIN_PROBLEMS, encoding='utf-8')
    except ImportError:
        print("\n非Colab环境，使用内置示例")
        dataset_path = DATA_DIR / "builtin_examples.txt"
        dataset_path.write_text(BUILTIN_PROBLEMS, encoding='utf-8')
    except Exception as e:
        print(f"\n上传出错: {e}，使用内置示例")
        dataset_path = DATA_DIR / "builtin_examples.txt"
        dataset_path.write_text(BUILTIN_PROBLEMS, encoding='utf-8')

elif DATASET_SOURCE in DATASET_URLS:
    url = DATASET_URLS[DATASET_SOURCE]
    filename = f"{DATASET_SOURCE.lower().replace('-', '_')}.txt"
    dataset_path = DATA_DIR / filename

    print(f"\n下载 {DATASET_SOURCE} 数据集...")
    try:
        urllib.request.urlretrieve(url, dataset_path)
        print("下载成功")
    except Exception as e:
        print(f"下载失败: {e}")
        local_file = DATA_DIR / "imo_ag_30.txt"
        if local_file.exists():
            dataset_path = local_file
            print("使用仓库自带的 imo_ag_30.txt")
        else:
            print("回退到内置示例")
            dataset_path = DATA_DIR / "builtin_examples.txt"
            dataset_path.write_text(BUILTIN_PROBLEMS, encoding='utf-8')

if dataset_path and dataset_path.exists():
    with open(dataset_path, 'r', encoding='utf-8') as f:
        lines = [l.strip() for l in f if l.strip() and not l.startswith('#')]

    problems = []
    i = 0
    while i < len(lines):
        if i + 1 < len(lines) and '=' in lines[i + 1]:
            problems.append((lines[i], lines[i + 1]))
            i += 2
        else:
            i += 1

    print(f"\n共加载 {len(problems)} 道题目")

    if len(problems) > 0:
        print("\n前3题预览：")
        print("-" * 55)
        for idx, (name, defn) in enumerate(problems[:3]):
            print(f"\n[{idx}] {name}")
            if len(defn) > 60:
                print(f"    {defn[:60]}...")
            else:
                print(f"    {defn}")
        print("-" * 55)

    os.environ["DATASET_PATH"] = str(dataset_path)
    os.environ["DATASET_NAME"] = DATASET_SOURCE
    print("\n数据集配置完成")
else:
    print("\n数据集加载失败")

数据集配置: AG-30

下载 AG-30 数据集...
下载成功

共加载 30 道题目

前3题预览：
-------------------------------------------------------

[0] translated_imo_2000_p1
    a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 ...

[1] translated_imo_2000_p6
    a b c = triangle a b c; h = orthocenter h a b c; t1 t2 t3 i ...

[2] translated_imo_2002_p2a
    b c = segment b c; o = midpoint o b c; a = on_circle a o b; ...
-------------------------------------------------------

数据集配置完成


---
# 第三部分：运行推理

## DDAR 工作原理简述

DDAR 从初始条件出发，反复应用推理规则推导新的几何关系：

```
初始条件 -> 应用规则 -> 新关系 -> 应用规则 -> ... -> 目标关系
```

推理分两个阶段：
- **DD（演绎数据库）**：应用几何定理规则
- **AR（代数推理）**：处理角度和线段的数值方程

---

In [None]:
#@title 3.1 初始化推理引擎
#@markdown 这一步相当于“加载 AlphaGeometry 的推理大脑”。

import sys
import os

REPO_DIR = "alphageometry"
sys.path.insert(0, os.path.join(os.getcwd(), REPO_DIR))

print("加载推理引擎...\n")

try:
    import ddar
    import problem as pr
    import graph as gh

    defs_path = os.path.join(REPO_DIR, "defs.txt")
    rules_path = os.path.join(REPO_DIR, "rules.txt")

    DEFINITIONS = pr.Definition.from_txt_file(defs_path, to_dict=True)
    RULES = pr.Theorem.from_txt_file(rules_path, to_dict=True)

    print(f"几何定义: {len(DEFINITIONS)} 个")
    print(f"推理规则: {len(RULES)} 条")

    globals()['DEFINITIONS'] = DEFINITIONS
    globals()['RULES'] = RULES
    globals()['pr'] = pr
    globals()['gh'] = gh
    globals()['ddar'] = ddar

    print("\n初始化完成")

except Exception as e:
    print(f"初始化失败: {e}")
    print("请确保已运行第一部分的安装步骤")

加载推理引擎...

几何定义: 68 个
推理规则: 43 条

初始化完成


### 参数说明：单题测试

**PROBLEM_INDEX（题目编号）**
- 从 0 开始计数，0 表示第一题
- 滑动选择要测试的题目

**MAX_LEVEL（推理深度）**
- 控制推理引擎最多迭代多少轮
- 值越大，搜索越深入，但耗时越长
- 推荐值：
  - 简单题目：500-1000
  - 中等题目：1000-2000
  - 复杂题目：2000-5000
- 如果某题失败，可以尝试增大此值

In [None]:
#@title 3.2 单题测试

PROBLEM_INDEX = 0  #@param {type:"slider", min:0, max:50, step:1}
MAX_LEVEL = 1000  #@param {type:"slider", min:100, max:5000, step:100}

import time

DEFINITIONS = globals().get('DEFINITIONS')
RULES = globals().get('RULES')
pr = globals().get('pr')
gh = globals().get('gh')
ddar = globals().get('ddar')

if not all([DEFINITIONS, RULES, pr, gh, ddar]):
    print("请先运行 3.1 初始化推理引擎")
else:
    dataset_path = os.environ.get("DATASET_PATH")
    if not dataset_path or not os.path.exists(dataset_path):
        print("请先运行 2.1 配置数据集")
    else:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            lines = [l.strip() for l in f if l.strip() and not l.startswith('#')]

        problems = []
        i = 0
        while i < len(lines):
            if i + 1 < len(lines) and '=' in lines[i + 1]:
                problems.append((lines[i], lines[i + 1]))
                i += 2
            else:
                i += 1

        if PROBLEM_INDEX >= len(problems):
            print(f"题号 {PROBLEM_INDEX} 超出范围（共 {len(problems)} 题），使用第 0 题")
            PROBLEM_INDEX = 0

        name, text = problems[PROBLEM_INDEX]

        print("=" * 60)
        print(f"题目 [{PROBLEM_INDEX}]: {name}")
        print("=" * 60)
        print(f"\n{text}\n")
        print("=" * 60)

        try:
            print("\n[步骤1] 解析题目...")
            p = pr.Problem.from_txt(text)
            has_goal = hasattr(p, 'goal') and p.goal is not None
            print(f"  证明目标: {p.goal if has_goal else '无（探索模式）'}")

            print("\n[步骤2] 构建几何关系图...")
            g, _ = gh.Graph.build_problem(p, DEFINITIONS)
            initial = len(g.cache)
            print(f"  初始关系数: {initial}")

            print(f"\n[步骤3] 执行DDAR推理（深度={MAX_LEVEL}）...")
            start = time.time()
            success = ddar.solve(g, RULES, p, max_level=MAX_LEVEL)
            elapsed = time.time() - start
            final = len(g.cache)

            print(f"  耗时: {elapsed:.2f} 秒")
            print(f"  最终关系数: {final}（新增 {final - initial}）")

            print(f"\n推导出的关系（部分）：")
            print("-" * 50)
            rels = list(g.cache.keys())
            for i, rel in enumerate(rels[:12]):
                print(f"  {i+1:2d}. {rel}")
            if len(rels) > 12:
                print(f"  ... 共 {len(rels)} 个关系")
            print("-" * 50)

            print("\n" + "=" * 60)
            if success:
                print("结果: 证明成功")
            else:
                print("结果: 未能证明")
                print("\n可能的原因：")
                print("  1. 题目需要辅助构造（超出DDAR能力）")
                print("  2. 推理深度不足，可尝试增大 MAX_LEVEL")
            print("=" * 60)

        except Exception as e:
            print(f"\n运行出错: {e}")
            import traceback
            traceback.print_exc()

题目 [0]: translated_imo_2000_p1

a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q


[步骤1] 解析题目...
  证明目标: <problem.Construction object at 0x7d2ab060af00>

[步骤2] 构建几何关系图...
  初始关系数: 20

[步骤3] 执行DDAR推理（深度=1000）...
  耗时: 20.67 秒
  最终关系数: 1375（新增 1355）

推导出的关系（部分）：
--------------------------------------------------
   1. ('perp', 'a', 'b', 'a', 'c')
   2. ('eqangle', 'a', 'b', 'a', 'c', 'a', 'c', 'a', 'b')
   3. ('para', 'a', 'c', 'b', 'd')
   4. ('cong', 'a', 'c', 'c', 'e')
   5. ('cong', 'b', 'd', 'd', 'e')
   6. ('cong', 'a', 'c', 'c', 'f')
   7. ('cong', 'b', 'd', 'd', 'f')
   8. ('para', 'a', 'b', 'e', 'g')
   9. ('cong', 'a', 'c', 'c', 'g')
  10. ('cyclic', 'a', 'e', 'f', 'g')
  11. ('para', 'a',

### 参数说明：批量测试

**TEST_RANGE（测试范围）**
- 选择测试多少道题目
- 建议先用「前5题」快速验证环境
- 完整测试可能需要较长时间

**TIMEOUT（超时时间）**
- 单题最长允许运行的时间（秒）
- 超时后会跳过该题继续下一题
- 推荐值：60秒（大部分题目足够）
- 如果超时率高，可增大到 120-180 秒

**MAX_LEVEL（推理深度）**
- 同单题测试，控制推理的深度
- 批量测试时建议 1000-2000

In [None]:
#@title 3.3 批量测试

TEST_RANGE = "全部"  #@param ["前5题", "前10题", "前20题", "全部"]
TIMEOUT = 60  #@param {type:"slider", min:10, max:300, step:10}
MAX_LEVEL = 1000  #@param {type:"slider", min:100, max:5000, step:100}

import time
import signal
from contextlib import contextmanager

class TimeoutError(Exception):
    pass

@contextmanager
def time_limit(seconds):
    def handler(signum, frame):
        raise TimeoutError()
    old = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old)

DEFINITIONS = globals().get('DEFINITIONS')
RULES = globals().get('RULES')
pr = globals().get('pr')
gh = globals().get('gh')
ddar = globals().get('ddar')

if not all([DEFINITIONS, RULES, pr, gh, ddar]):
    print("请先运行 3.1 初始化推理引擎")
else:
    dataset_path = os.environ.get("DATASET_PATH")
    dataset_name = os.environ.get("DATASET_NAME", "数据集")

    if not dataset_path or not os.path.exists(dataset_path):
        print("请先运行 2.1 配置数据集")
    else:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            lines = [l.strip() for l in f if l.strip() and not l.startswith('#')]

        all_problems = []
        i = 0
        while i < len(lines):
            if i + 1 < len(lines) and '=' in lines[i + 1]:
                all_problems.append((lines[i], lines[i + 1]))
                i += 2
            else:
                i += 1

        range_map = {"前5题": 5, "前10题": 10, "前20题": 20, "全部": len(all_problems)}
        count = min(range_map.get(TEST_RANGE, 5), len(all_problems))
        problems = all_problems[:count]

        print("=" * 60)
        print("批量测试")
        print("=" * 60)
        print(f"数据集: {dataset_name}")
        print(f"测试范围: {count} / {len(all_problems)} 题")
        print(f"超时设置: {TIMEOUT} 秒")
        print(f"推理深度: {MAX_LEVEL}")
        print("=" * 60)

        results = {'success': [], 'fail': [], 'timeout': [], 'error': []}
        start_all = time.time()

        for idx, (name, defn) in enumerate(problems):
            status = ""
            try:
                p = pr.Problem.from_txt(defn)
                start = time.time()

                try:
                    with time_limit(TIMEOUT):
                        g, _ = gh.Graph.build_problem(p, DEFINITIONS)
                        success = ddar.solve(g, RULES, p, max_level=MAX_LEVEL)
                        elapsed = time.time() - start

                        if success:
                            results['success'].append((name, elapsed))
                            status = f"成功 ({elapsed:.1f}s)"
                        else:
                            results['fail'].append((name, elapsed))
                            status = f"失败 ({elapsed:.1f}s)"

                except TimeoutError:
                    results['timeout'].append(name)
                    status = f"超时 (>{TIMEOUT}s)"

            except Exception as e:
                results['error'].append((name, str(e)))
                status = f"错误"

            print(f"[{idx+1:2d}/{count}] {name[:30]:<30} -> {status}")

        total_time = time.time() - start_all

        print("\n" + "=" * 60)
        print("测试结果")
        print("=" * 60)

        n_success = len(results['success'])
        n_fail = len(results['fail'])
        n_timeout = len(results['timeout'])
        n_error = len(results['error'])
        rate = n_success / count * 100 if count > 0 else 0

        print(f"\n总计测试: {count} 题")
        print(f"  成功: {n_success} ({rate:.1f}%)")
        print(f"  失败: {n_fail}")
        print(f"  超时: {n_timeout}")
        print(f"  错误: {n_error}")
        print(f"\n总耗时: {total_time:.1f} 秒 ({total_time/60:.1f} 分钟)")

        if results['success']:
            print(f"\n成功证明的题目：")
            for name, t in results['success']:
                print(f"  + {name} ({t:.1f}s)")

        if results['fail']:
            print(f"\n未能证明的题目：")
            for name, t in results['fail'][:10]:
                print(f"  - {name}")
            if len(results['fail']) > 10:
                print(f"  ... 还有 {len(results['fail'])-10} 题")

        print("\n" + "=" * 60)

批量测试
数据集: AG-30
测试范围: 30 / 30 题
超时设置: 60 秒
推理深度: 1000
[ 1/30] translated_imo_2000_p1         -> 成功 (20.8s)
[ 2/30] translated_imo_2000_p6         -> 超时 (>60s)
[ 3/30] translated_imo_2002_p2a        -> 成功 (16.4s)
[ 4/30] translated_imo_2002_p2b        -> 成功 (3.0s)
[ 5/30] translated_imo_2003_p4         -> 成功 (37.6s)
[ 6/30] translated_imo_2004_p1         -> 成功 (13.5s)
[ 7/30] translated_imo_2004_p5         -> 成功 (1.4s)
[ 8/30] translated_imo_2005_p5         -> 错误
[ 9/30] translated_imo_2007_p4         -> 成功 (7.3s)
[10/30] translated_imo_2008_p1a        -> 错误
[11/30] translated_imo_2008_p1b        -> 错误
[12/30] translated_imo_2008_p6         -> 超时 (>60s)
[13/30] translated_imo_2009_p2         -> 成功 (5.8s)
[14/30] translated_imo_2010_p2         -> 成功 (10.9s)
[15/30] translated_imo_2010_p4         -> 成功 (12.4s)
[16/30] translated_imo_2011_p6         -> 超时 (>60s)
[17/30] translated_imo_2012_p1         -> 成功 (25.9s)
[18/30] translated_imo_2012_p5         -> 成功 (1.0s)
[19/30] translated_imo_2

---
# 第四部分：可视化分析

将几何问题和推理过程以图形化方式展示，帮助理解 DDAR 的工作过程。

---

### 参数说明：可视化

**PROBLEM_INDEX（题目编号）**
- 选择要可视化的题目
- 从 0 开始计数，0 表示第一题
- 如果超出数据集范围会自动使用第 0 题

**SHOW_RELATIONS（显示关系数）**
- 控制右上区域显示多少条推导出的关系
- 值越大显示越详细，但可能超出显示区域
- 建议值：10-20

**VIS_TIMEOUT（超时时间）**
- 控制构建图和运行推理的最长等待时间（秒）
- 超时后会停止计算，但仍显示已有信息
- 建议值：30-60 秒
- 如果经常超时，可以增大此值

**图形说明：**

| 区域 | 内容 | 说明 |
|------|------|------|
| 左上 | 几何图形 | 显示题目中的点和连线。注意：坐标是示意性的近似值，不是精确位置，仅用于理解题目结构 |
| 右上 | 推导关系 | DDAR 推理过程中发现的几何关系，如 `perp`（垂直）、`para`（平行）、`cong`（相等）等 |
| 左下 | 统计图表 | 三个条形：Initial（初始关系数）、Final（最终关系数）、New（新推导的关系数） |
| 右下 | 证明结果 | 显示题目名称、证明目标、点的数量、超时设置，以及最终状态 |

**结果框颜色含义：**

| 颜色 | 状态 | 说明 |
|------|------|------|
| 绿色边框 + 浅绿背景 | PROVED | 证明成功 |
| 橙色边框 + 浅橙背景 | BUILD TIMEOUT / SOLVE TIMEOUT | 构建或推理超时，可尝试增大 VIS_TIMEOUT |
| 红色边框 + 浅红背景 | NOT PROVED / SOLVE ERROR | 未能证明或推理出错 |

**注意事项：**
- 某些复杂 IMO 题目可能需要较长时间，建议先用内置示例测试
- 超时不代表题目无法证明，可能只是需要更多时间或更大的推理深度
- 几何图形仅为示意，实际几何关系以右上角的推导关系列表为准
- 如果显示 "No relations"，说明构建阶段就失败或超时了

In [1]:
#@title 4.1 可视化几何图形与推理过程

PROBLEM_INDEX = 9  #@param {type:"slider", min:0, max:50, step:1}
SHOW_RELATIONS = 15  #@param {type:"slider", min:5, max:30, step:5}
VIS_TIMEOUT = 30  #@param {type:"slider", min:5, max:120, step:5}

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon, FancyBboxPatch
from matplotlib.gridspec import GridSpec
import matplotlib.patheffects as path_effects
import signal
from contextlib import contextmanager

# 超时控制
class TimeoutError(Exception):
    pass

@contextmanager
def time_limit(seconds):
    def handler(signum, frame):
        raise TimeoutError()
    old = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old)

DEFINITIONS = globals().get('DEFINITIONS')
RULES = globals().get('RULES')
pr = globals().get('pr')
gh = globals().get('gh')
ddar = globals().get('ddar')

if not all([DEFINITIONS, RULES, pr, gh, ddar]):
    print("请先运行 3.1 初始化推理引擎")
else:
    dataset_path = os.environ.get("DATASET_PATH")
    if not dataset_path or not os.path.exists(dataset_path):
        print("请先运行 2.1 配置数据集")
    else:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            lines = [l.strip() for l in f if l.strip() and not l.startswith('#')]

        problems = []
        i = 0
        while i < len(lines):
            if i + 1 < len(lines) and '=' in lines[i + 1]:
                problems.append((lines[i], lines[i + 1]))
                i += 2
            else:
                i += 1

        if PROBLEM_INDEX >= len(problems):
            PROBLEM_INDEX = 0

        name, text = problems[PROBLEM_INDEX]
        print(f"Problem [{PROBLEM_INDEX}]: {name}")
        print(f"Definition: {text[:70]}{'...' if len(text)>70 else ''}\n")

        # 初始化变量
        g = None
        initial_facts = 0
        final_facts = 0
        success = False
        build_ok = False
        solve_status = "NOT RUN"

        try:
            p = pr.Problem.from_txt(text)

            # 提取点（在构建图之前就可以做）
            points = set()
            for clause in p.clauses:
                if hasattr(clause, 'points'):
                    points.update(clause.points)

            point_names = set()
            for pt in points:
                pt_name = pt.name if hasattr(pt, 'name') else str(pt)
                point_names.add(pt_name)

            # 构建图（带超时）
            print(f"Building graph (timeout: {VIS_TIMEOUT}s)...")
            try:
                with time_limit(VIS_TIMEOUT):
                    try:
                        g, _ = gh.Graph.build_problem(p, DEFINITIONS)
                        build_ok = True
                    except Exception as build_err:
                        print(f"  Warning: {type(build_err).__name__}, trying alternative...")
                        g = gh.Graph()
                        for clause in p.clauses:
                            try:
                                g.add_clause(clause, DEFINITIONS)
                            except:
                                pass
                        build_ok = True
            except TimeoutError:
                print(f"  Build timed out!")
                solve_status = "BUILD TIMEOUT"

            if build_ok and g is not None:
                initial_facts = len(g.cache)

                # 运行推理（带超时）
                print(f"Running DDAR (timeout: {VIS_TIMEOUT}s)...")
                try:
                    with time_limit(VIS_TIMEOUT):
                        success = ddar.solve(g, RULES, p, max_level=1000)
                        solve_status = "PROVED" if success else "NOT PROVED"
                except TimeoutError:
                    print(f"  Solve timed out!")
                    solve_status = "SOLVE TIMEOUT"
                except Exception as solve_err:
                    print(f"  Solve error: {type(solve_err).__name__}")
                    solve_status = "SOLVE ERROR"

                final_facts = len(g.cache)

            print(f"Points: {len(point_names)}, Initial: {initial_facts}, Final: {final_facts}")
            print(f"Status: {solve_status}\n")

            # ========== 开始绑制图形（无论成功与否都绑制）==========
            fig = plt.figure(figsize=(14, 10))
            gs = GridSpec(2, 2, figure=fig, height_ratios=[1.2, 1], hspace=0.25, wspace=0.2)

            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])
            ax3 = fig.add_subplot(gs[1, 0])
            ax4 = fig.add_subplot(gs[1, 1])

            # === 左上：几何图形 ===
            ax1.set_aspect('equal')
            ax1.set_title(f'Geometry: {name}', fontsize=12, fontweight='bold', pad=10)
            ax1.grid(True, alpha=0.3, linestyle='--', color='gray')
            ax1.set_facecolor('#fafafa')

            np.random.seed(42)
            base_coords = {
                'a': (0, 0), 'b': (5, 0), 'c': (2.5, 4.3),
                'o': (2.5, 1.5), 'i': (2.2, 1.2), 'm': (2.5, 0),
                'n': (1.25, 2.15), 'd': (2.5, 2.5), 'e': (1.5, 1.8),
                'f': (3.5, 1.8), 'g': (1.0, 0.5), 'h': (4.0, 0.5),
            }

            coords = {}
            for pt_name in point_names:
                if pt_name in base_coords:
                    coords[pt_name] = base_coords[pt_name]
                else:
                    angle = np.random.uniform(0, 2*np.pi)
                    r = np.random.uniform(1, 2.5)
                    coords[pt_name] = (2.5 + r*np.cos(angle), 2 + r*np.sin(angle))

            if all(pt in point_names for pt in ['a', 'b', 'c']):
                tri = Polygon(
                    [coords['a'], coords['b'], coords['c']],
                    fill=True, alpha=0.15, facecolor='#3498db',
                    edgecolor='#2c3e50', linewidth=2.5
                )
                ax1.add_patch(tri)

            point_colors = {
                'a': '#e74c3c', 'b': '#27ae60', 'c': '#3498db',
                'o': '#9b59b6', 'i': '#f39c12', 'm': '#1abc9c', 'n': '#1abc9c',
                'd': '#e67e22', 'e': '#16a085', 'f': '#8e44ad', 'g': '#2c3e50', 'h': '#c0392b'
            }

            for pt_name, (x, y) in coords.items():
                color = point_colors.get(pt_name, '#34495e')
                ax1.plot(x, y, 'o', markersize=12, color=color,
                        markeredgecolor='white', markeredgewidth=2, zorder=10)
                txt = ax1.text(x + 0.25, y + 0.25, pt_name.upper(), fontsize=13,
                              fontweight='bold', color=color)
                txt.set_path_effects([path_effects.withStroke(linewidth=3, foreground='white')])

            line_pairs = [('a', 'b'), ('b', 'c'), ('a', 'c'),
                         ('a', 'm'), ('m', 'b'), ('a', 'n'), ('n', 'c'), ('m', 'n'),
                         ('a', 'd'), ('b', 'd'), ('c', 'd'), ('a', 'o'), ('b', 'o'), ('c', 'o'),
                         ('i', 'd'), ('i', 'e'), ('a', 'i'), ('b', 'i'), ('c', 'i')]
            for p1, p2 in line_pairs:
                if p1 in coords and p2 in coords:
                    ax1.plot([coords[p1][0], coords[p2][0]],
                            [coords[p1][1], coords[p2][1]],
                            'k-', alpha=0.4, linewidth=1.5, zorder=1)

            xs = [c[0] for c in coords.values()]
            ys = [c[1] for c in coords.values()]
            if xs and ys:
                margin = 1.5
                ax1.set_xlim(min(xs)-margin, max(xs)+margin)
                ax1.set_ylim(min(ys)-margin, max(ys)+margin)

            # === 右上：推导关系 ===
            ax2.axis('off')
            ax2.set_title('Derived Relations', fontsize=12, fontweight='bold', pad=10)

            if g is not None and hasattr(g, 'cache') and len(g.cache) > 0:
                rels = list(g.cache.keys())[:SHOW_RELATIONS]
                rel_text = ""
                for idx, rel in enumerate(rels, 1):
                    rel_text += f"{idx:2d}. {rel}\n"
                if len(g.cache) > SHOW_RELATIONS:
                    rel_text += f"\n... +{len(g.cache) - SHOW_RELATIONS} more"
            else:
                rel_text = "(No relations - build failed or timed out)"

            ax2.text(0.02, 0.98, rel_text, transform=ax2.transAxes,
                    fontsize=9, verticalalignment='top', family='monospace',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='#ecf0f1',
                             edgecolor='#bdc3c7', alpha=0.9))

            # === 左下：统计条形图 ===
            ax3.set_title('Statistics', fontsize=12, fontweight='bold', pad=10)
            categories = ['Initial', 'Final', 'New']
            values = [initial_facts, final_facts, final_facts - initial_facts]
            colors_bar = ['#3498db', '#2ecc71', '#f39c12']

            bars = ax3.barh(categories, values, color=colors_bar, height=0.6)
            ax3.set_xlabel('Count', fontsize=10)
            max_val = max(values) if max(values) > 0 else 10
            ax3.set_xlim(0, max_val * 1.3)

            for bar, val in zip(bars, values):
                ax3.text(val + max_val*0.03, bar.get_y() + bar.get_height()/2,
                        str(val), va='center', fontsize=11, fontweight='bold')

            ax3.spines['top'].set_visible(False)
            ax3.spines['right'].set_visible(False)

            # === 右下：结果信息 ===
            ax4.axis('off')
            ax4.set_title('Result', fontsize=12, fontweight='bold', pad=10)

            goal_text = str(p.goal) if hasattr(p, 'goal') and p.goal else 'None'

            # 根据状态选择颜色
            if solve_status == "PROVED":
                result_color = '#27ae60'
                bg_color = '#f0fff0'
            elif "TIMEOUT" in solve_status:
                result_color = '#f39c12'
                bg_color = '#fffbf0'
            else:
                result_color = '#e74c3c'
                bg_color = '#fff0f0'

            bbox_patch = FancyBboxPatch(
                (0.05, 0.1), 0.9, 0.8,
                boxstyle="round,pad=0.02,rounding_size=0.03",
                facecolor=bg_color,
                edgecolor=result_color,
                linewidth=2.5,
                transform=ax4.transAxes,
                clip_on=False
            )
            ax4.add_patch(bbox_patch)

            info_text = f"Problem: {name}\n\nGoal: {goal_text}\n\nPoints: {len(point_names)}\nTimeout: {VIS_TIMEOUT}s\n\nStatus: {solve_status}"

            ax4.text(0.5, 0.5, info_text, transform=ax4.transAxes,
                    fontsize=11, verticalalignment='center', horizontalalignment='center')

            plt.show()

        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()

请先运行 3.1 初始化推理引擎


---
# 第五部分：自定义题目测试

可以直接输入几何题目定义，测试 DDAR 能否证明。

---

### 参数说明：自定义题目

**CUSTOM_PROBLEM（题目定义）**
- 按照 AlphaGeometry 格式输入题目
- 格式：`点 = 构造; 更多构造 ? 目标`

**格式示例：**
```
a b c = triangle a b c; m = midpoint m a b ? coll a m b
```

**常用构造函数：**
- `triangle a b c` - 三角形
- `midpoint m a b` - 中点
- `on_line p a b` - 点在直线上
- `on_tline p a b c` - 点在过a垂直于bc的线上
- `incenter i a b c` - 内心
- `circumcenter o a b c` - 外心

**常用目标：**
- `perp a b c d` - AB 垂直于 CD
- `para a b c d` - AB 平行于 CD
- `cong a b c d` - AB = CD
- `coll a b c` - 三点共线

In [None]:
#@title 5.1 测试自定义题目

CUSTOM_PROBLEM = "a b c = triangle a b c; m = midpoint m a b; n = midpoint n a c ? para m n b c"  #@param {type:"string"}
CUSTOM_MAX_LEVEL = 1000  #@param {type:"slider", min:100, max:5000, step:100}

import time

DEFINITIONS = globals().get('DEFINITIONS')
RULES = globals().get('RULES')
pr = globals().get('pr')
gh = globals().get('gh')
ddar = globals().get('ddar')

if not all([DEFINITIONS, RULES, pr, gh, ddar]):
    print("请先运行 3.1 初始化推理引擎")
elif not CUSTOM_PROBLEM.strip():
    print("请输入题目定义")
else:
    print("=" * 60)
    print("自定义题目测试")
    print("=" * 60)
    print(f"\n输入: {CUSTOM_PROBLEM}\n")
    print("=" * 60)

    try:
        p = pr.Problem.from_txt(CUSTOM_PROBLEM)
        has_goal = hasattr(p, 'goal') and p.goal is not None

        print(f"\n解析成功")
        print(f"证明目标: {p.goal if has_goal else '无'}")

        g, _ = gh.Graph.build_problem(p, DEFINITIONS)
        initial = len(g.cache)
        print(f"初始关系数: {initial}")

        print(f"\n执行推理（深度={CUSTOM_MAX_LEVEL}）...")
        start = time.time()
        success = ddar.solve(g, RULES, p, max_level=CUSTOM_MAX_LEVEL)
        elapsed = time.time() - start

        final = len(g.cache)
        print(f"耗时: {elapsed:.2f} 秒")
        print(f"最终关系数: {final}（+{final-initial}）")

        print("\n" + "=" * 60)
        if success:
            print("结果: 证明成功")
        else:
            print("结果: 未能证明")
            print("\n可能原因：")
            print("  1. 题目需要辅助构造")
            print("  2. 推理深度不足")
            print("  3. 题目格式有误")
        print("=" * 60)

    except Exception as e:
        print(f"\n出错: {e}")
        print("\n请检查题目格式是否正确")
        print("正确格式: 点 = 构造; ... ? 目标")

自定义题目测试

输入: a b c = triangle a b c; m = midpoint m a b; n = midpoint n a c ? para m n b c


解析成功
证明目标: <problem.Construction object at 0x7d2ab04f5a30>
初始关系数: 4

执行推理（深度=1000）...
耗时: 0.03 秒
最终关系数: 12（+8）

结果: 证明成功


---
# 附录

---

## A. 题目格式详解

AlphaGeometry 使用领域特定语言描述几何问题：

```
点列表 = 构造函数 参数; 更多构造 ? 证明目标
```

**分号** `;` 分隔多个构造语句

**问号** `?` 引入证明目标

---

## B. 常用构造函数

| 函数 | 参数 | 含义 |
|------|------|------|
| triangle | a b c | 三角形ABC |
| iso_triangle | a b c | 等腰三角形（AB=AC）|
| r_triangle | c a b | 直角三角形（角C=90度）|
| midpoint | m a b | M是AB的中点 |
| on_line | p a b | P在直线AB上 |
| on_circle | p o a | P在圆O上（O为圆心，过A）|
| on_tline | p a b c | P在过A垂直于BC的直线上 |
| on_pline | p a b c | P在过A平行于BC的直线上 |
| incenter | i a b c | I是三角形ABC的内心 |
| circumcenter | o a b c | O是三角形ABC的外心 |
| centroid | g a b c | G是三角形ABC的重心 |
| orthocenter | h a b c | H是三角形ABC的垂心 |
| foot | d p a b | D是P到直线AB的垂足 |

---

## C. 常用证明目标

| 谓词 | 参数 | 含义 |
|------|------|------|
| perp | a b c d | AB 垂直于 CD |
| para | a b c d | AB 平行于 CD |
| cong | a b c d | AB = CD（线段相等）|
| coll | a b c | A、B、C三点共线 |
| cyclic | a b c d | A、B、C、D四点共圆 |
| eqangle | a b c d e f g h | 角ABC = 角DEF |

---

## D. 示例题目

**1. 中点定理**
```
a b c = triangle a b c; m = midpoint m a b; n = midpoint n a c ? para m n b c
```

**2. 垂心性质**
```
a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c
```

**3. 外心性质**
```
a b c = triangle a b c; o = circumcenter o a b c ? cong o a o b
```

**4. 等腰三角形**
```
a b c = iso_triangle a b c; m = midpoint m b c ? perp a m b c
```