In [None]:
import os
import zipfile
# Unpack the test-set images so the OCR stage can read them from ./test_images.
zip_path = "./input/test_img_data.zip"
extract_dir = "./test_images"

# Create the target directory up front; re-running the cell is harmless.
os.makedirs(name=extract_dir, exist_ok=True)
with zipfile.ZipFile(file=zip_path, mode="r") as zip_ref:
    zip_ref.extractall(path=extract_dir)

print(f"已解压图片到 {extract_dir}")

### Stage 1: OCR

In [None]:
#注意：安装命令运行结束后需重启内核才会更新
# 安装 paddlepaddle和paddleocr用于识别图片文字
pip install paddlepaddle==2.4.2 -i https://mirrors.aliyun.com/pypi/simple
!pip install --user paddlepaddle==2.4.2 -i https://mirrors.aliyun.com/pypi/simple
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddleocr
!pip uninstall -y numpy
!pip install --user -i https://pypi.tuna.tsinghua.edu.cn/simple numpy==1.26.4

In [None]:
import os
import json
import torch
from paddleocr import PaddleOCR


# Paths used by this stage. They are re-declared here because the install cell above
# requires a kernel restart, which clears variables defined in earlier cells.
IMAGE_DIR = "./test_images"
INPUT_JSON = "./input/test_data.json"
OCR_OUTPUT_JSON = "test_data_with_ocr.json"


def ocr_image_text(ocr_engine, image_path):
    """Run OCR on one image and return all recognised text joined by single spaces.

    Args:
        ocr_engine: an initialised ``PaddleOCR`` instance.
        image_path: path of the image file to recognise.

    Returns:
        The recognised text with surrounding whitespace stripped, or ``""`` when the
        engine returned nothing (a message is printed in that case).
    """
    result = ocr_engine.ocr(image_path, cls=True)
    if not result or result[0] is None:
        print(f"OCR结果为空: {image_path}")
        return ""
    # ``result`` is a list of pages, each a list of detections whose ``[1][0]`` is the text.
    # Empty/None pages are skipped instead of being iterated (which would raise TypeError).
    fragments = [box[1][0] for page in result if page for box in page]
    return " ".join(fragments).strip()


# Initialise OCR: Chinese model with the angle classifier enabled.
ocr = PaddleOCR(use_angle_cls=True, lang='ch')

# Read the original test_data.json.
with open(INPUT_JSON, 'r', encoding='utf-8') as f:
    original_data = json.load(f)

# Run OCR on every referenced image.
updated_data = []
for item in original_data:
    filename = item.get("path", "")
    image_path = os.path.join(IMAGE_DIR, filename)

    # isfile rather than exists: an empty "path" joins to IMAGE_DIR itself, which exists
    # as a directory and would otherwise be handed to the OCR engine.
    if os.path.isfile(image_path):
        text = ocr_image_text(ocr, image_path)
    else:
        print(f"图片不存在: {image_path}")
        # Still emit a record (with empty text) so every input item reaches the
        # correction stage and predict.json instead of silently disappearing.
        text = ""

    updated_data.append({
        "fk_homework_id": item["fk_homework_id"],
        "path": filename,
        "source_text": text,
        "predict_text": "",
        "bounding_box_list": [],
    })

# Save the file consumed by the correction stage.
with open(OCR_OUTPUT_JSON, 'w', encoding='utf-8') as f:
    json.dump(updated_data, f, ensure_ascii=False, indent=2)

print("OCR完成，处理图片数量：", len(updated_data))

### Stage 2: Grammatical Error Correction (GEC)

In [None]:
#注意：安装命令运行结束后需重启内核才会更新
# 为避免依赖冲突，卸载 OCR 相关依赖
!pip uninstall -y paddleocr paddlepaddle

# 安装 pycorrector 和 transformers（纠错模型所需）
!pip install --user -U -i https://pypi.tuna.tsinghua.edu.cn/simple pycorrector
!pip install --user transformers==4.28.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install --user kenlm -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
import json
import os

from pycorrector import Corrector

OCR_OUTPUT_JSON = "test_data_with_ocr.json"
PREDICT_DIR = "./output"
PREDICT_JSON = os.path.join(PREDICT_DIR, "predict.json")


def correct_text(corrector, text):
    """Return the corrected version of ``text`` produced by ``corrector``.

    Args:
        corrector: a pycorrector ``Corrector`` whose ``correct()`` returns a dict
            with a ``'target'`` entry.
        text: the OCR text to correct.

    Returns:
        ``corrector.correct(text)['target']``; empty input (e.g. images where OCR found
        nothing) is returned unchanged without calling the model.
    """
    if not text:
        return text
    return corrector.correct(text)['target']


# Load the pycorrector model backed by the kenlm language-model file.
model = Corrector(language_model_path='./models/people2014corpus_chars.klm')

# Load the OCR output written by the previous stage.
with open(OCR_OUTPUT_JSON, "r", encoding="utf-8") as f:
    data = json.load(f)

# Inference.
results = []
for i, item in enumerate(data):
    src = item["source_text"]
    corrected_text = correct_text(model, src)

    new_item = dict(item)
    new_item["predict_text"] = corrected_text
    results.append(new_item)

    # Spot-check every 10th sample (the 1st, 11th, 21st, ...).
    if i % 10 == 0:
        print(f"\n第 {i+1} 条样本纠错结果：")
        print("原文：", src[:100])
        print("纠错：", corrected_text[:100])


# Save the final predictions; create the output directory first, since open() does not.
os.makedirs(PREDICT_DIR, exist_ok=True)
with open(PREDICT_JSON, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("纠错完成，结果已保存为 ./output/predict.json")

In [None]:
# Compress the prediction file before submitting.
import os
import zipfile

# Build the archive with the standard library instead of shelling out to `zip`. This way:
#   - there is no dependency on an external binary;
#   - failures raise instead of being ignored as an unchecked os.system return code;
#   - the working directory is never changed.
output_dir = os.path.join(os.getcwd(), "output")
archive_path = os.path.join(output_dir, "prediction.zip")
# Mode "w" recreates the archive on every run, so it only ever contains predict.json.
with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    # arcname keeps predict.json at the archive root, matching `zip` run from inside output/.
    archive.write(os.path.join(output_dir, "predict.json"), arcname="predict.json")

print(f"已生成提交文件 {archive_path}")