In [None]:
from transformers import BertTokenizer, BertForTokenClassification
import torch
import util
import json
import re

In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from zhpr.predict import DocumentDataset, merge_stride, decode_pred
from torch.utils.data import DataLoader
import torch

def restore_punctuation(text: str,
                        model_name: str = "p208p2002/zh-wiki-punctuation-restore",
                        window_size: int = 256,
                        step: int = 200,
                        batch_size: int = 4,
                        device: str = None) -> str:
    """
    Restore Chinese punctuation in `text` with a token-classification model.

    The text is wrapped in a `DocumentDataset` (windowed by `window_size` /
    `step`), every batch is labelled by the model, the per-window predictions
    are combined with `merge_stride`, and `decode_pred` produces the output
    token sequence. Each "[UNK]" token in that sequence is then replaced with
    the character found at the tracked position in the original `text`.

    Args:
        text: Input text to punctuate.
        model_name: Name or path given to `from_pretrained` for both the model
            and the tokenizer.
        window_size: Window size forwarded to `DocumentDataset`.
        step: Step forwarded to `DocumentDataset` and `merge_stride`.
        batch_size: Number of dataset items per forward pass.
        device: Torch device string. When None, "cuda" is used if available,
            otherwise "cpu".

    Returns:
        The decoded tokens joined into a single string.
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Note: the model and tokenizer are loaded on every call.
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(device)
    model.eval()

    dataset = DocumentDataset(text, window_size=window_size, step=step)
    dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size)

    all_predictions = []

    def predict_batch(batch_input_ids):
        # Inference only: disable autograd so no computation graph is built
        # (and kept alive) for each batch.
        with torch.no_grad():
            outputs = model(input_ids=batch_input_ids.to(device))
        logits = outputs.logits
        pred_ids = torch.argmax(logits, dim=-1).detach().cpu()
        return pred_ids

    for batch in dataloader:
        # Accept both dict-style batches and plain batches of input ids.
        if isinstance(batch, dict) and "input_ids" in batch:
            input_ids = batch["input_ids"]
        else:
            input_ids = batch
        pred_ids = predict_batch(input_ids)

        batch_out = []
        for batch_pred_ids, batch_input in zip(pred_ids, input_ids):
            tokens = tokenizer.convert_ids_to_tokens(batch_input)
            input_ids_list = batch_input.tolist()
            # Cut the sequence at the first padding token, if there is one.
            try:
                pad_index = input_ids_list.index(tokenizer.pad_token_id)
            except ValueError:
                pad_index = len(input_ids_list)
            tokens = tokens[:pad_index]
            preds = batch_pred_ids[:pad_index]
            pred_labels = [model.config.id2label[p.item()] for p in preds]
            out = list(zip(tokens, pred_labels))
            batch_out.append(out)
        all_predictions.extend(batch_out)

    merged = merge_stride(all_predictions, step)
    decoded = decode_pred(merged)

    # Replace [UNK] with the original character at the tracked position.
    # NOTE(review): orig_index advances by one for every decoded token that is
    # not one of the punctuation marks below, so the lookup assumes each such
    # token corresponds to exactly one character of `text` -- TODO confirm this
    # holds for multi-character tokens.
    inserted_punctuation = ["，", "。", "？", "！", "；", "、"]
    result_chars = []
    orig_index = 0
    for token in decoded:
        if token == "[UNK]" and orig_index < len(text):
            result_chars.append(text[orig_index])
        else:
            result_chars.append(token)
        if token not in inserted_punctuation:  # only advance for non-inserted punctuation
            orig_index += 1

    result = ''.join(result_chars)
    return result


In [None]:
def object_list_to_string(object_list):
    """Concatenate the "content" value of every object in `object_list`, in order."""
    return "".join(item["content"] for item in object_list)

def normalize_punctuation(text: str) -> str:
    """Replace the half-width marks , ; ! ? . with their full-width counterparts."""
    half_to_full_width = str.maketrans({
        ",": "，",
        ";": "；",
        "!": "！",
        "?": "？",
        ".": "。",
    })
    return text.translate(half_to_full_width)

def strip_punctuation(text: str) -> str:
    """Delete every full-width ， 。 、 ！ ？ ； character from `text`."""
    punctuation_pattern = re.compile(r"[，。、！？；]")
    return punctuation_pattern.sub("", text)

In [None]:
def content_object_list_to_punctuated_string(content_object_list):
    """
    Build one punctuated string from a list of content objects.

    The objects' contents are concatenated, existing punctuation is normalized
    to full-width and then stripped, and the bare text is passed to
    restore_punctuation.
    """
    raw_text = object_list_to_string(content_object_list)
    bare_text = strip_punctuation(normalize_punctuation(raw_text))
    return restore_punctuation(bare_text)

In [None]:
def punctuate_content_file(start_index, end_index):
    """
    Punctuate the content JSON files numbered start_index .. end_index - 1.

    For each index, the JSON list under "text_from_audio/" is loaded, converted
    to a punctuated string, and written to the matching .txt file under
    "punctuated_content/".
    """
    for file_index in range(start_index, end_index):
        json_path = util.file_name_builder(
            "." + util.FOLDER_PATH + "text_from_audio/", "Eph", "json", file_index
        )
        text_path = util.file_name_builder(
            "." + util.FOLDER_PATH + "punctuated_content/", "Eph", "txt", file_index
        )

        with open(json_path, "r", encoding="utf-8") as json_file:
            loaded_objects = json.load(json_file)

        punctuated_text = content_object_list_to_punctuated_string(loaded_objects)

        with open(text_path, "w", encoding="utf-8") as text_file:
            text_file.write(punctuated_text)

In [None]:
# Punctuate the content files with indices 4 and 5 (the end index is exclusive).
punctuate_content_file(4, 6)

In [None]:
# LaTeX part: turn the titles and the punctuated content into a .tex document
#
#
#

In [None]:
def handle_latex_special_characters(c):
  match c:
    case "&" | "%" | "$" | "#" | "_" | "{" | "}":
      return "\\" + c
    case "~":
      return "\\textasciitilde"
    case "^":
      return "\\textasciicircum"
    case "\\":
      return "\\textbackslash"
    case _:
      return c

def modify_string_for_latex(s):
  """Pass every character of `s` through handle_latex_special_characters and join the results."""
  return "".join(handle_latex_special_characters(ch) for ch in s)

def string_title_to_latex(title_object_list, content_string, output_file):
    r"""
    Write a LaTeX document: preamble, one heading per title, then the content.

    Args:
        title_object_list: Iterable of dicts with a "content" string and a
            "level" key. Levels 0, 1 and 2 map to \section, \subsection and
            \subsubsection; any other level falls back to \subsubsection.
        content_string: Body text; broken into lines by format_content_string.
        output_file: Path of the .tex file to write (UTF-8).
    """
    level_to_cmd = {
        0: "\\section",
        1: "\\subsection",
        2: "\\subsubsection",
    }

    # Document setup. \usepackage is only allowed in the preamble, so
    # \begin{document} must come after all of these lines.
    preamble = [
        r"\documentclass[lang=cn,newtx,10pt,scheme=chinese]{elegantbook}",
        r"\usepackage{graphicx}",
        r"\usepackage{fontspec}",
        # NOTE(review): Script is set twice (Latin, then Greek) -- confirm which
        # one is intended.
        r"\setmainfont{Times New Roman}[",
        "  Ligatures=TeX,",
        "  Script=Latin,",
        "  Script=Greek",
        r"]",
        "",
        r"\tcbset{",
        r"  mybox/.style={",
        "    colframe=black,",
        "    colback=white,",
        "    boxrule=0.8pt,",
        "    arc=0mm,",
        "    left=6pt,",
        "    right=6pt,",
        "    top=6pt,",
        "    bottom=6pt",
        r"  }",
        r"}",
        "",
        r"\begin{document}",
        "",
    ]
    # The file is written with "".join(...), so every entry carries its own
    # newline; otherwise the whole preamble would end up on a single line.
    latex_lines = [line + "\n" for line in preamble]

    # list all titles
    for title_object in title_object_list:
        title_text = modify_string_for_latex(title_object["content"].strip())
        latex_command = level_to_cmd.get(title_object["level"], "\\subsubsection")
        latex_lines.append(latex_command + "{" + title_text + "}" + "\n")
    latex_lines.append("\n")

    # cut the content into lines
    # NOTE(review): content_string is written without LaTeX escaping -- confirm
    # it cannot contain characters such as % or &.
    latex_lines = format_content_string(content_string, latex_lines)

    latex_lines.append("\\end{document}\n")

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("".join(latex_lines))

In [None]:
def string_to_segment_list(text):
    """
    Split `text` into segments that each end at one of ， 、 。 ？ ！ ；.

    A trailing run without any of those marks becomes its own segment.
    Segments are whitespace-stripped and empty ones are dropped.
    """
    segments = []
    for raw_piece in re.findall(r'.*?[，、。？！；]|[^，、。？！；]+', text):
        piece = raw_piece.strip()
        if piece:
            segments.append(piece)
    return segments

In [None]:
# 1 sentence or k characters
def format_content_string(content_string, latex_lines, line_limit=40):
    """
    Break `content_string` into lines and append them to `latex_lines`.

    The text is cut into segments by string_to_segment_list and the segments
    are accumulated into a line. The line is emitted (with a trailing newline)
    as soon as a segment ends with "。" or the accumulated line becomes longer
    than `line_limit` characters. Text still pending after the last segment is
    emitted as a final line instead of being dropped.

    Args:
        content_string: Text to split.
        latex_lines: List that receives the new lines; it is modified in place.
        line_limit: Maximum accumulated length before a line is forced out.

    Returns:
        The same `latex_lines` list.
    """
    print(f"start length of latex_lines is {len(latex_lines)}")
    current_line = ""
    for segment in string_to_segment_list(content_string):
        current_line += segment
        # Emit on a sentence end (full stop) or once the line exceeds the limit.
        if segment.endswith("。") or len(current_line) > line_limit:
            latex_lines.append(current_line + "\n")
            current_line = ""
    # Flush the remainder: content that neither ends with "。" nor overflows the
    # limit would otherwise be lost.
    if current_line:
        latex_lines.append(current_line + "\n")
    print(f"end length of latex_lines is {len(latex_lines)}")
    return latex_lines

In [None]:
def string_file_to_latex(start_index, end_index):
    """
    Generate a .tex file for every index in start_index .. end_index - 1.

    For each index, the title JSON under "title_from_ppt/" and the text under
    "punctuated_content/" are read and handed to string_title_to_latex, which
    writes the result under "latex/".
    """
    for file_index in range(start_index, end_index):
        title_path = util.file_name_builder("." + util.FOLDER_PATH + "title_from_ppt/", "Eph", "json", file_index)
        content_path = util.file_name_builder("." + util.FOLDER_PATH + "punctuated_content/", "Eph", "txt", file_index)
        tex_path = util.file_name_builder("." + util.FOLDER_PATH + "latex/", "Eph", "tex", file_index)

        with open(title_path, "r", encoding="utf-8") as title_handle:
            titles = json.load(title_handle)
        with open(content_path, "r", encoding="utf-8") as content_handle:
            body_text = content_handle.read()

        string_title_to_latex(titles, body_text, tex_path)

In [None]:
def fix_common_misspell(content_string):
    """
    Apply the replacements stored in the "replace_misspelled_words" JSON file.

    The file is loaded as a mapping of old word -> new word, and every
    occurrence of each old word in `content_string` is replaced, in the
    mapping's iteration order. Returns the resulting string.
    """
    replacement_path = util.file_name_builder(
        "." + util.FOLDER_PATH + "project_data/", "replace_misspelled_words", "json"
    )
    with open(replacement_path, "r", encoding="utf-8") as replacement_file:
        replacements = json.load(replacement_file)

    corrected = content_string
    for misspelled, correction in replacements.items():
        corrected = corrected.replace(misspelled, correction)
    return corrected

In [None]:
# Generate the .tex files for indices 4 and 5 (the end index is exclusive).
string_file_to_latex(4, 6)