## Preparation: imports, global vars, mounting drive (if on Colab)

In [2]:
import pandas as pd
import amrlib
import re
from typing import Tuple, List, Sequence, Dict
from tqdm.notebook import tqdm
import openai
from datetime import datetime
from io import StringIO
import penman
from penman import Graph
from penman.exceptions import PenmanError
from smatchpp import Smatchpp, solvers, preprocess, eval_statistics
from smatchpp.formalism.amr import tools as amrtools
import os, json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## Load test dataset (AMR 3.0)

In [None]:
df_test = pd.read_parquet("./data/test-00000-of-00001.parquet")
df_train = pd.read_parquet("./data/train-00000-of-00001.parquet")

sample_test = df_test.conversations.sample(50, random_state=42)
sample_train = df_train.conversations.sample(10, random_state=42)
full_train = df_train.conversations
full_train_texts = [item[0].get("content", "").replace("Generate an Abstract Meaning Representation (AMR) graph for the following sentence: ", "") for item in full_train]

## Load AMR parsing model

After a long and exhausting battle with Homebrew and various versions, I found a specific distribution of mip that must be installed for smatchpp to function properly. Without this, as far as I understand, you are doomed to get cbclib-related errors, at least on Mac with Apple Silicon (M1 in my case). This is an important step without which running parsing models locally won't work.

Source: https://github.com/coin-or/python-mip/issues/335

In [None]:
! pip3 install mip==1.16rc0

In [None]:
stog = amrlib.load_stog_model()

graphs = stog.parse_sents(['This is a test of the system.', 'This is a second sentence.'])
for graph in tqdm(graphs):
    print(graph)

## Define functions for AMR processing and validation

In [23]:
graph_standardizer = amrtools.AMRStandardizer()
printer = eval_statistics.ResultPrinter(score_type="micro", do_bootstrap=True, output_format="json")
ilp = solvers.ILP()
measure = Smatchpp(alignmentsolver=ilp, graph_standardizer=graph_standardizer, printer=printer)

def validate_amr(amr_str):
    """
    Parse AMR with penman and return a list of validation errors.
    An empty list means the AMR passed all checks.
    """
    errors = []

    try:
        g = penman.decode(amr_str)
    except Exception as e:
        return [f'PENMAN parse error: {e}']

    return errors


# AMR variable convention: usually a short lowercase letter with optional digits (e.g., d, x3, person2).
AMR_VAR_RE = re.compile(r'^[a-z][a-z0-9]*$')

def validate_amr(amr_str):
    """
    Parse AMR with penman and return a list of validation errors.
    An empty list means the AMR passed all checks.
    """
    errors: List[str] = []

    # Parseability
    try:
        g = penman.decode(amr_str)
    except Exception as e:
        print("PENMAN parse error")
        return [f'PENMAN parse error: {e}']

    triples = g.triples

    # Root check
    if g.top is None:
        errors.append("AMR has no top/root variable (missing instance for a root).")

    # Build instance map and variable set
    instance_of = {}
    var_set = set()

    for s, r, t in triples:
        var_set.add(s)
        if r == ':instance':
            if s in instance_of and t:
                errors.append(f"Multiple :instance triples for variable '{s}' "
                              f"({instance_of[s]} and {t}).")
            instance_of[s] = t

    # Every variable must (a) look like a proper AMR variable and (b) have exactly one :instance
    for v in var_set:
        # (a) variable naming
        if not AMR_VAR_RE.fullmatch(v):
            errors.append(
                f"Illegal variable name '{v}'. "
                "Variables must be simple lowercase letters with optional digits (e.g., 'd', 'x3'). "
            )
        # (b) instance triple present
        if v not in instance_of:
            errors.append(
                f"Variable '{v}' has no ':instance' triple (missing 'v / TYPE')."
            )

    # Ensure the top variable itself has an :instance
    if g.top is not None and g.top not in instance_of:
        errors.append(f"Top variable '{g.top}' has no ':instance' triple.")

    return errors

def is_valid_amr(amr_str):
    """
    Checks if AMR is valid and returns (valid, errors).
    """
    errors = validate_amr(amr_str)
    return (len(errors) == 0, errors)


def validate_and_score(
    gold_amrs,
    sys_amrs,
    report_path = "amr_validation_report.txt"
):
    """
    Writes a report to `report_path`.
    """
    assert len(gold_amrs) == len(sys_amrs), "Mismatch in AMR count"

    valid_gold = []
    valid_sys = []
    valid_count = 0

    buf = StringIO()

    for gold, sys in zip(gold_amrs, sys_amrs):
        ok, *rest = is_valid_amr(sys)
        if ok:
            valid_count += 1
            valid_gold.append(gold)
            valid_sys.append(sys)
        else:
            reason = f" | reason: {rest[0]}" if rest else ""
            buf.write(f"invalid AMR: {sys}{reason}\n")
            print(f"invalid AMR: {sys}{reason}\n")

    total = len(sys_amrs)
    valid_percent = 100.0 * valid_count / total if total else 0.0

    # Score only valid !!!
    smatch_score, optimization_status = measure.score_corpus(valid_gold, valid_sys)

    # Write summary + metrics
    buf.write(f"\nValid AMRs: {valid_percent:.2f}% ({valid_count}/{total})\n")
    buf.write("\nMetrics (valid only)\n")

    main = smatch_score.get("main", {})
    for name, d in main.items():
        val = float(d["result"])
        lo, hi = map(float, d.get("ci", (float('nan'), float('nan'))))
        buf.write(f"{name}: {val:.2f} (CI: [{lo:.2f}, {hi:.2f}])\n")

    # if optimization_status is not None:
    #     buf.write("\nOptimization status:\n")
    #     buf.write(str(optimization_status) + "\n")

    with open(report_path, "w", encoding="utf-8") as f:
        f.write(buf.getvalue())

    return valid_percent, smatch_score


# Example
gold_amrs = [
    "(a / approve-01 :ARG0 (p / person))",
    "(a / approve-01 :ARG0 (p / person))",
    "(x / person :mod (b / big))"
]
sys_amrs = [
    "(a / approve-01 :ARG0 (p / person))",  # valid and not correct
    "(a / approve-01 :ARG1 (p / person))",  # valid but not correct
    "(x / invalid"  # invalid: unbalanced parentheses
]

valid_pct, smatchpp_score = validate_and_score(gold_amrs, sys_amrs)

smatchpp_score

PENMAN parse error
invalid AMR: (x / invalid | reason: ['PENMAN parse error: \n  line 1\n    (x / invalid\n                ^\nDecodeError: Unexpected end of input']

Cbc3007W No integer variables
Cbc3007W No integer variables


{'main': {'F1': {'result': np.float64(87.5),
   'ci': (np.float64(75.0), np.float64(100.0))},
  'Precision': {'result': np.float64(87.5),
   'ci': (np.float64(75.0), np.float64(100.0))},
  'Recall': {'result': np.float64(87.5),
   'ci': (np.float64(75.0), np.float64(100.0))}}}

In [None]:
def get_golds(sample):
    golds = []
    for prompt, gold in tqdm(sample):
        gold_parse = gold['content']
        golds += [gold_parse]
    return golds

def get_amrs_from_model(stog, sample):
    data = []

    instr = "Generate an Abstract Meaning Representation (AMR) graph for the following sentence: "
    gens = []
    for prompt, gold in tqdm(sample):
        sent = prompt['content'].replace(instr, "")
        gold_parse = gold['content']
        gen_parse = stog.parse_sents([sent])[0]
        data.append({
            'sentence': sent,
            'gold_amr': gold_parse,
            'generated_amr': gen_parse
        })
        gens += [gen_parse]

    results_df = pd.DataFrame(data)
    results_df.to_csv("res_parse.tsv", sep="\t")
    return gens

In [None]:
PRICING = {
    "o3-2025-04-16": {"input": 2, "output": 8},
    "gpt-4o": {"input": 2.5, "output": 10},
    "gpt-5-2025-08-07": {"input": 1.25, "output": 10}
}
MODEL_NAME = "BAAI/bge-base-en-v1.5"
QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "


def load_encoder(model_name):
    return SentenceTransformer(model_name)

# Caching
_ENCODER = None

def load_encoder(name="BAAI/bge-base-en-v1.5", device="cuda"):
    global _ENCODER
    if _ENCODER is None:
        _ENCODER = SentenceTransformer(name, device=device)
        _ENCODER.max_seq_length = 384
    return _ENCODER

def build_or_load_index(
    sentences,
    model,
    cache_dir = ".rag_cache",
    cache_name = "bge_base_en_v15_50k",
    force_rebuild = False,
):
    """
    sentences: list of short strings
    returns: (faiss_index, id2doc) where id2doc[i] = {"text": str}
    """
    os.makedirs(cache_dir, exist_ok=True)
    emb_path   = os.path.join(cache_dir, f"{cache_name}.npy")
    meta_path  = os.path.join(cache_dir, f"{cache_name}.meta.json")
    faiss_path = os.path.join(cache_dir, f"{cache_name}.faiss")

    id2doc = [{"text": s} for s in sentences]

    # Simple load-if-present cache
    if (
        not force_rebuild
        and os.path.exists(emb_path)
        and os.path.exists(meta_path)
        and os.path.exists(faiss_path)
    ):
        with open(meta_path, "r") as f:
            cached_meta = json.load(f)
        if len(cached_meta) == len(id2doc):
            index = faiss.read_index(faiss_path)
            return index, cached_meta

    # Encode all sentences
    emb = model.encode(
        [d["text"] for d in id2doc],
        batch_size=512,
        show_progress_bar=True,
        normalize_embeddings=True,
    ).astype(np.float32)

    np.save(emb_path, emb)
    with open(meta_path, "w") as f:
        json.dump(id2doc, f, ensure_ascii=False)

    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(emb)
    faiss.write_index(index, faiss_path)
    return index, id2doc

def retrieve_top_k(
    query_text,
    model,
    index,
    id2doc,
    top_k = 10,
):
    """
    returns: list of {"text", "score"} sorted by similarity (cosine)
    """
    q_text = QUERY_INSTRUCTION + query_text
    q = model.encode([q_text], normalize_embeddings=True).astype(np.float32)
    scores, idxs = index.search(q, top_k)
    scores, idxs = scores[0], idxs[0]

    results = []
    for s, i in zip(scores, idxs):
        if i < 0:
            continue
        results.append({"text": id2doc[int(i)]["text"], "score": float(s)})
    return results


def obtain_few_shot_examples(query, index, id2doc, model, top_k=10):
    top = retrieve_top_k(query, model, index, id2doc, top_k=top_k)
    top_texts = [r['text'] for r in top]
    texts_and_amrs = [full_train[full_train_texts.index(text)] for text in top_texts]
    return texts_and_amrs


def get_amrs_from_chatgpt(sample, system_prompt_init, dir_path, model="gpt-3.5-turbo", prepend_to_message="", append_to_message="", form_few_shot_prompt=False, **kwargs):
    """
    Generate AMRs from ChatGPT for a given sample of sentences.
    """
    client = openai.OpenAI(api_key=OPENAI_API_KEY)

    data, gens, stats = [], [], []

    for prompt, gold in tqdm(sample):
        user_message =  prepend_to_message + prompt['content'] + append_to_message
        gold_parse = gold['content']
        if form_few_shot_prompt:
            print("Loading encoder...")
            model_faiss = load_encoder()
            print("Building/loading index...")
            index, id2doc = build_or_load_index(full_train_texts, model_faiss)
            query = prompt['content'].replace("Generate an Abstract Meaning Representation (AMR) graph for the following sentence: ", "")
            few_shot_examples = obtain_few_shot_examples(query, index, id2doc, model_faiss)
            formatted_examples = [f"{item[0]['content']}\nAMR graph:{item[1]['content']}" for item in few_shot_examples]
            system_prompt = system_prompt_init + " Examples:" + "\n".join(formatted_examples)
        else:
            system_prompt = system_prompt_init

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.1,
                **kwargs
            )

            gen_parse = response.choices[0].message.content.strip()

            if gen_parse.startswith("```"):
                gen_parse = gen_parse.split("```")[1]
            if gen_parse.startswith("`"):
                gen_parse = gen_parse.split("`")[1]
            if gen_parse.startswith("amr"):
                gen_parse = gen_parse[3:].strip()

            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens  = response.usage.total_tokens
            print(f"total_tokens: {total_tokens}")

        except Exception as e:
            print(f"Error processing sentence '{user_message}': {e}")
            gen_parse = 'API call failed'
            prompt_tokens, completion_tokens, total_tokens = 0, 0, 0

        data.append({
            'sentence': user_message,
            'gold_amr': gold_parse,
            'generated_amr': gen_parse,
        })
        stats.append({
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens
        })
        gens.append(gen_parse)

    prompt_tokens_final = sum(s.get("prompt_tokens", 0) or 0 for s in stats)
    completion_tokens_final = sum(s.get("completion_tokens", 0) or 0 for s in stats)

    final = {
        "prompt_tokens_final": prompt_tokens_final,
        "completion_tokens_final": completion_tokens_final,
        "total_tokens_final": sum(s.get("total_tokens", 0) or 0 for s in stats),
        "price": (prompt_tokens_final * PRICING[model]["input"] + completion_tokens_final * PRICING[model]["output"]) / 1000000
    }
    stats.append(final)
    results_df = pd.DataFrame(data)
    results_df.to_csv(f"{dir_path}/amr_parses.tsv", sep="\t")

    return stats, gens


def generate_and_eval(
        sample,
        system_prompt, 
        model, 
        comment="test", 
        prepend_to_message="", 
        append_to_message="", 
        form_few_shot_prompt=False, 
        **kwargs
    ):
    current_time = datetime.now()
    filename_time = current_time.strftime("%Y-%m-%d_%H-%M-%S")
    dir_path = f"/content/gdrive/My Drive/amr_parsing/results/{model}-{filename_time}_{comment}"
    os.makedirs(dir_path, exist_ok=True)
    stats, gens = get_amrs_from_chatgpt(sample, system_prompt, dir_path, model=model, prepend_to_message=prepend_to_message, append_to_message=append_to_message, form_few_shot_prompt=form_few_shot_prompt, **kwargs)
    with open(f"{dir_path}/stats.txt", "w", encoding="utf-8") as f:
        f.write("\n".join([str(d) for d in stats]))
    golds = get_golds(sample)
    valid_pct, smatchpp_score = validate_and_score(golds, gens, report_path=f"{dir_path}/evals.txt")
    return valid_pct, smatchpp_score

## Running inference and evaluation

Full results can be found in `results/` folder. Each run of each model has its own dedicated folder, formatted as `{model_name}-{run_timestamp}_{prompt_info}`. Each folder contains 1) `amr_parses.tsv` file with sentences, gold parses and generated parses and 2) `evals.txt` file with metrics and information on invalid graphs.

### gpt-4o

In [None]:
model="gpt-4o"

system_prompt_baseline = """You are an expert in Abstract Meaning Representation (AMR) parsing. Concept names must always have variables. Variables are lowercase letters with optional digits, e.g. (c1 / concept ...). Different concepts always get different variables, e.g. (c1 / concept ...) and (c2 / concept), even if the concept name is the same. To refer to the before-mentioned concept, you may use just the variable without brackets. Comments are absolutely not allowed. You only generate AMR parses and nothing else."""

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt_baseline, model=model, comment="add_instr", append_to_message="\nAMR graph:")

In [None]:
system_prompt = """You are an expert in Abstract Meaning Representation (AMR) parsing. Concept names must always have variables. Variables are lowercase letters with optional digits, e.g. (c1 / concept ...). Different concepts always get different variables, e.g. (c1 / concept ...) and (c2 / concept), even if the concept name is the same. To refer to the before-mentioned concept, you may use just the variable without brackets. Comments are absolutely not allowed. You only generate AMR parses and nothing else."""

examples = [f"{item[0]['content']}\nAMR graph:{item[1]['content']}" for item in sample_train]
system_prompt_10_shot = system_prompt + " Examples:" + "\n".join(examples)

valid_pct_10_shot, smatchpp_score_10_shot = generate_and_eval(sample_test, system_prompt_10_shot, model=model, comment="10_shot", append_to_message="\nAMR graph:")

In [None]:
with open("./AMR_detailed_instruction.txt", 'r') as f:
    system_prompt = f.read()

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt, model=model, comment="huge_prompt", append_to_message="\nAMR graph:")

### o3-2025-04-16

In [None]:
model="o3-2025-04-16"
system_prompt_baseline = """You are an expert in Abstract Meaning Representation (AMR) parsing. Concept names must always have variables. Variables are lowercase letters with optional digits, e.g. (c1 / concept ...). Different concepts always get different variables, e.g. (c1 / concept ...) and (c2 / concept), even if the concept name is the same. To refer to the before-mentioned concept, you may use just the variable without brackets. Comments are absolutely not allowed. You only generate AMR parses and nothing else."""

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt_baseline, model=model, comment="add_instr", append_to_message="\nAMR graph:")

In [None]:
examples = [f"{item[0]['content']}\nAMR graph:{item[1]['content']}" for item in sample_train]
system_prompt_10_shot = system_prompt_baseline + " Examples:" + "\n".join(examples)

valid_pct_10_shot, smatchpp_score_10_shot = generate_and_eval(sample_test, system_prompt_10_shot, model=model, comment="10_shot", append_to_message="\nAMR graph:")

In [None]:
with open("./data/AMR_detailed_instruction.txt", 'r') as f:
    system_prompt = f.read()

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt, model=model, comment="huge_prompt", append_to_message="\nAMR graph:")

### gpt-5-2025-08-07

In [None]:
model="gpt-5-2025-08-07"
system_prompt_baseline = """You are an expert in Abstract Meaning Representation (AMR) parsing. Concept names must always have variables. Variables are lowercase letters with optional digits, e.g. (c1 / concept ...). Different concepts always get different variables, e.g. (c1 / concept ...) and (c2 / concept), even if the concept name is the same. To refer to the before-mentioned concept, you may use just the variable without brackets. Comments are absolutely not allowed. You only generate AMR parses and nothing else."""

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt_baseline, model=model, comment="add_instr", append_to_message="\nAMR graph:", max_completion_tokens=5000)

In [None]:
examples = [f"{item[0]['content']}\nAMR graph:{item[1]['content']}" for item in sample_train]
system_prompt_10_shot = system_prompt_baseline + " Examples:" + "\n".join(examples)

valid_pct_10_shot, smatchpp_score_10_shot = generate_and_eval(sample_test, system_prompt_10_shot, model=model, comment="10_shot", append_to_message="\nAMR graph:")

In [None]:
with open("./data/AMR_detailed_instruction.txt", 'r') as f:
    system_prompt = f.read()

valid_pct, smatchpp_score = generate_and_eval(sample_test, system_prompt, model=model, comment="huge_prompt", append_to_message="\nAMR graph:")

## Offline parse evaluation

If the parses were saved successfully to `{some_folder}/amr_parses.tsv`, but were not evaluated, we can conduct evaluation post-hoc using the following function.

Also, here are the functions that let you get the best and the worst AMR parses and visualise them.

In [65]:
from amrlib.graph_processing.amr_plot import AMRPlot


def score_file(filepath, draft=False):
    df = pd.read_csv(filepath, sep='\t')

    golds = df['gold_amr'].tolist()
    gens = df['generated_amr'].tolist()
    if draft:
        gens = df['draft_amr'].tolist()


    gens_clean = ["(" + "(".join(text.split("(")[1:]).replace("AMR graph:", "") for text in gens]

    valid_pct, smatchpp_score = validate_and_score(golds, gens_clean)
    print(f"Valid AMRs: {valid_pct:.1f}%")
    print(f"Smatch++ score: {smatchpp_score}")
    return valid_pct, smatchpp_score


def get_sorted_amr_from_file(filepath, draft=False):
    """
    Returns a list of amr tuples sorted by F1 score.
    """
    df = pd.read_csv(filepath, sep="\t")
    sents =  df["sentence"].tolist()
    golds = df["gold_amr"].tolist()
    gens = (df["draft_amr"] if draft else df["generated_amr"]).tolist()
    gens_clean = ["(" + "(".join(text.split("(")[1:]).replace("AMR graph:", "") for text in gens]

    scored = []
    for i, (gold, gen, sent) in enumerate(zip(golds, gens_clean, sents)):
        valid_pct, smatchpp_score = validate_and_score([gold], [gen])
        f1 = smatchpp_score.get("main", {}).get("F1", {}).get("result", None)
        if f1 is None:
            f1 = 0.0
        scored.append((float(f1), sent, gen, gold, i))

    scored.sort(key=lambda x: x[0])

    return scored


def visualise_amr(graph, save_to):
    plot = AMRPlot(render_fn=f"{save_to}.gv", format="pdf")
    plot.build_from_graph(graph, debug=False)
    plot.render()

In [None]:
sorted_amrs = get_sorted_amr_from_file('./results/fin/gpt-5-2025-08-07-2025-09-29_21-27-34_langchain/amr_parses.tsv')

In [73]:
# get BEST AMR parses; plots saves to `pics/` folder

for i in range(1, 10):
    print(sorted_amrs[-i][0], sorted_amrs[-i][1])
    visualise_amr(sorted_amrs[i][2], f"pics/best/{i}_gen")
    visualise_amr(sorted_amrs[i][3], f"pics/best/{i}_gold")

100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: 09/02/2010 13:25
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: 2008-10-08
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Crime; weapons; international; money
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Xinhua News Agency , Seoul , August 31st , by reporter Shuifu Tang
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: 2007-06-18
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Senior Fellow at the International Institute for Strategic Studies mark Fitzpatrick stated that --
100.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: 08/02/2010 13:52
97.14 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Premier Peng Li and K

In [74]:
# get WORST AMR parses; plots saves to `pics/` folder

for i in range(5):
    print(sorted_amrs[i][0], sorted_amrs[i][1])
    visualise_amr(sorted_amrs[i][2], f"pics/worst/{i}_gen")
    visualise_amr(sorted_amrs[i][3], f"pics/worst/{i}_gold")

0.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: m1456
46.81 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Maritime officials in Kenya stated that critical details have yet to be agreed upon.
48.78 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: The list was made available on condition that neither the diplomat nor the diplomat's country be identified.
49.12 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Some of these flights were innocent violations by ranchers in the Amazon flying between plantations.
50.0 Generate an Abstract Meaning Representation (AMR) graph for the following sentence: SUPPLIES WE NEED: comfortable nursing chairs and rockers, nursing foot stool, boppie, SLINGS and baby clothes, diapers.


In [None]:
score_file('./results/fin/gpt-5-2025-08-07-2025-09-29_21-27-34_langchain/amr_parses.tsv')

# LangChain self-correcting agent

### Imports

In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import ValidationError,  BaseModel, Field
import json
from langchain_core.callbacks import BaseCallbackHandler
from typing import TypedDict, Optional, Union
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END

text_llm_json = ChatOpenAI(
    model="gpt-4o",
    api_key=OPENAI_API_KEY,
    response_format={"type": "json_object"},
)

### Response classes (define the form of the model output)

In [None]:
class GenerateAMR(BaseModel):
    """Generate a valid AMR."""
    amr: str = Field(description="Valid AMR graph.")


class Reflection(BaseModel):
    score: int = Field(description="Score out of 5. 0 is for invalid or completely incorrect AMRs, 5 is for perfectly valid and sematically correct AMRs.")
    semantic_correctness: Union[str, None] = Field(description="Critique of AMR semantic correctness, e.g. roles, if any. Can be None.")
    structure: Union[str, None] = Field(description="Critique of AMR structure, if any. Can be None.")


class ReviseAnswer(BaseModel):
    """Revise your original answer to your question. Provide a well-formed, semantically faithful AMR and nothing else."""

    final_amr: str = Field(
        description="Well-formed, semantically faithful AMR, improved based on feedback."
    )


class AutoEvalOut(BaseModel):
    score: Union[str, int, float] = Field(...)
    comments: str = Field(..., min_length=1)

### Prompts

In [None]:
amr_system_prompt = """Rules of Abstract Meaning Representation (AMR) parsing. Concept names must always have variables. Variables are lowercase letters with optional digits, e.g. (c1 / concept ...). Different concepts always get different variables, e.g. (c1 / concept ...) and (c2 / concept), even if the concept name is the same. To refer to the before-mentioned concept, just the variable without brackets may be used."""

with open("./data/AMR_detailed_instruction.txt", 'r') as f:
    amr_detailed_prompt = f.read()

draft_prompt = ChatPromptTemplate.from_messages([
    ("system", amr_system_prompt),
    ("user", """Generate AMR parse for the following sentence: {question}

Here are examples of correct AMR parses: {examples}
Comments are absolutely not allowed. You only generate AMR parses and nothing else.""")
])

reflect_prompt = ChatPromptTemplate.from_messages([
    ("system", f"AMR parsong rules: {amr_detailed_prompt}"),
    ("user", """Here are examples of correct AMR parses: {examples}

Sentence: {question}
AMR draft: {draft}

Results of automatic evaluation: {auto_eval_comments}

Provide a score (from 0 to 5) for the given AMR draft and very concise critique of crucial errors in structure and semantic correctness if any.""")
])

revise_prompt = ChatPromptTemplate.from_messages([
    ("system", amr_system_prompt),
    ("user", """Here are examples of correct AMR parses: {examples}

Sentence: {question}
AMR draft: {draft}

Results of automatic evaluation: {auto_eval_comments}
Feedback:
{feedback}

Revise your previous answer using the new information.
- It is crucially important to fix issues reported in automatic evaluation, if any.
- You should use the previous critique to ensure that AMR is well-formed.
- You should use the previous critique to ensure that AMR is semantically faithful.
Provide the new correct AMR and absolutely nothing else. No comments allowed. Prioritise making AMR well-formed according to Results of automatic evaluation, if any.""")
])

In [None]:
draft_model = text_llm_json.with_structured_output(GenerateAMR)
reflect_model = text_llm_json.with_structured_output(Reflection)
revise_model = text_llm_json.with_structured_output(ReviseAnswer)

draft_chain   = draft_prompt   | draft_model
reflect_chain = reflect_prompt | reflect_model
revise_chain  = revise_prompt  | revise_model

### State, nodes, token counting helper function

In [None]:
class State(TypedDict, total=False):
    question: str
    draft: str
    auto_eval_comments: str
    feedback: str
    final_amr: str
    score: Union[str, int]
    semantic_correctness: str
    structure: str
    examples: str
    prompt_tokens_draft: Union[str, int, float]
    completion_tokens_draft: Union[str, int, float]
    total_tokens_draft: Union[str, int, float]
    prompt_tokens_reflect: Union[str, int, float]
    completion_tokens_reflect: Union[str, int, float]
    total_tokens_reflect: Union[str, int, float]
    prompt_tokens_revise: Union[str, int, float]
    completion_tokens_revise: Union[str, int, float]
    total_tokens_revise: Union[str, int, float]
    draft_obj: GenerateAMR
    reflect_obj: Reflection
    revise_obj: ReviseAnswer
    revise_count: int


class TokenCounter(BaseCallbackHandler):
    def __init__(self):
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def on_llm_end(self, response, **kwargs):
        llm_output = getattr(response, "llm_output", {}) or {}
        usage = llm_output.get("token_usage", {}) or {}
        # OpenAI-style keys; fallbacks for other providers
        pt = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
        ct = usage.get("completion_tokens") or usage.get("output_tokens") or 0
        tt = usage.get("total_tokens") or (pt + ct)
        self.prompt_tokens += int(pt)
        self.completion_tokens += int(ct)
        self.total_tokens += int(tt)


def draft_node(state):
    print("Loading encoder...")
    model_faiss = load_encoder()
    print("Building/loading index...")
    index, id2doc = build_or_load_index(full_train_texts, model_faiss)
    query = state["question"].replace("Generate an Abstract Meaning Representation (AMR) graph for the following sentence: ", "")
    few_shot_examples = obtain_few_shot_examples(query, index, id2doc, model_faiss)
    formatted_examples = [f"{item[0]['content']}\nAMR graph:{item[1]['content']}" for item in few_shot_examples]
    tc = TokenCounter()
    obj: GenerateAMR = draft_chain.invoke({
        "question": state["question"],
        "examples": formatted_examples
    }, config={"callbacks": [tc]})
    return {
        "draft_obj": obj,
        "draft": obj.amr,
        "examples": formatted_examples,
        "prompt_tokens_draft": tc.prompt_tokens,
        "completion_tokens_draft": tc.completion_tokens,
        "total_tokens_draft": tc.total_tokens
        }

def reflect_node(state: State) -> State:
    tc = TokenCounter()
    obj: Reflection = reflect_chain.invoke({
        "question": state["question"],
        "examples": state["examples"],
        "draft": state["draft"],
        "auto_eval": state["score"],
        "auto_eval_comments": state["auto_eval_comments"]
    }, config={"callbacks": [tc]})
    return {
        "reflect_obj": obj,
        "score": obj.score,
        "feedback": f"Score: {obj.score} Semantic_correctness: {obj.semantic_correctness} Structure: {obj.structure}",
        "prompt_tokens_reflect": tc.prompt_tokens,
        "completion_tokens_reflect": tc.completion_tokens,
        "total_tokens_reflect": tc.total_tokens
        }

def revise_node(state: State) -> State:
    c = state.get("revise_count", 0) + 1
    if int(state.get("score", 0)) == 5:
        return {
            "revise_count": c,
            "final_amr": state["draft"]
            }

    else:
        tc = TokenCounter()
        obj: ReviseAnswer = revise_chain.invoke({
            "question": state["question"],
            "draft": state["draft"],
            "feedback": state["feedback"],
            "examples": state["examples"],
            "auto_eval_comments": state["auto_eval_comments"]
        }, config={"callbacks": [tc]})
        return {
            "revise_obj": obj,
            "revise_count": c,
            "final_amr": obj.final_amr,
            "prompt_tokens_revise": tc.prompt_tokens,
            "completion_tokens_revise": tc.completion_tokens,
            "total_tokens_revise": tc.total_tokens
        }


def auto_eval(amr):
    """Evaluate generated AMRs (PENMAN and basic graph validation)"""
    ok, comment = is_valid_amr(amr)
    if not ok:
        score = 0
        comment = f"\nAMR is invalid: {comment}"
    else:
        score = "to be determined"
        comment = "No structural errors found."
    return AutoEvalOut(score=score, comments=comment)


def auto_eval_node(state: State) -> State:
    if state.get("final_amr"):
        result = auto_eval(state.get("final_amr"))
        result_dict = {
            "draft": state.get("final_amr"),
            "old_draft": state.get("draft")
        }
    else:
        result = auto_eval(state.get("draft"))
        result_dict = {}
    result_dict["auto_eval"] = result.model_dump()
    result_dict["score"] = result.score
    result_dict["auto_eval_comments"] = result.comments
    return result_dict

### Building the final LangChain graph

In [None]:
def route_after_revise(state):
    result = auto_eval(state.get("final_amr"))
    c = state.get("revise_count", 0)

    return "eval" if result.score == 0 and c <= 3 else "end"


builder = StateGraph(State)
builder.add_node("draft", draft_node)
builder.add_node("eval", auto_eval_node)
builder.add_node("reflect", reflect_node)
builder.add_node("revise", revise_node)

builder.add_edge(START, "draft")
builder.add_edge("draft", "eval")
builder.add_edge("eval", "reflect")
builder.add_edge("reflect", "revise")
builder.add_conditional_edges(
    "revise",
    route_after_revise,
    {"end": END, "eval": "eval"}
)

graph = builder.compile()

In [None]:
# Test example to demonstrate the output format

out = graph.invoke({"question": "Thank you"})

out

{'question': 'Thank you',
 'draft': '(t / thank-01\n      :ARG1 (y / you))',
 'auto_eval_comments': 'No structural errors found.',
 'feedback': 'Score: 5 Semantic_correctness: None Structure: None',
 'final_amr': '(t / thank-01\n      :ARG1 (y / you))',
 'score': 5,
 'examples': ['Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Thank you\nAMR graph:(s / say-01\n      :ARG1 (t / thank-01\n            :ARG1 (y / you))\n      :ARG2 y)',
  'Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Thank you\nAMR graph:(s / say-01\n      :ARG1 (t / thank-01\n            :ARG1 (y / you))\n      :ARG2 y)',
  'Generate an Abstract Meaning Representation (AMR) graph for the following sentence: thank you\nAMR graph:(t / thank-01\n      :ARG0 (i / i)\n      :ARG1 (y / you))',
  'Generate an Abstract Meaning Representation (AMR) graph for the following sentence: Thanking You.\nAMR graph:(t / thank-01\n      :ARG1 (y / you))',
  'Generat

### Running LangChain agent on test data

In [None]:
def get_amrs_from_langchain(sample, model, dir_path, prepend_to_message="", append_to_message="", **kwargs):
    """
    Generate AMRs from ChatGPT for a given sample of sentences.
    """
    data, gens, stats, drafts, full_logs = [], [], [], [], []

    for prompt, gold in tqdm(sample):
        user_message =  prepend_to_message + prompt['content'] + append_to_message
        gold_parse = gold['content']
        out = None
        try:

            out = graph.invoke({"question": user_message})

            gen_parse = out["final_amr"]

            if gen_parse.startswith("```"):
                gen_parse = gen_parse.split("```")[1]
            if gen_parse.startswith("`"):
                gen_parse = gen_parse.split("`")[1]
            if gen_parse.startswith("amr"):
                gen_parse = gen_parse[3:].strip()

            prompt_tokens_draft = out.get("prompt_tokens_draft", 0)
            completion_tokens_draft = out.get("completion_tokens_draft", 0)
            total_tokens_draft = out.get("total_tokens_draft", 0)
            prompt_tokens_reflect = out.get("prompt_tokens_reflect", 0)
            completion_tokens_reflect = out.get("completion_tokens_reflect", 0)
            total_tokens_reflect = out.get("total_tokens_reflect", 0)
            prompt_tokens_revise = out.get("prompt_tokens_revise", 0)
            completion_tokens_revise = out.get("completion_tokens_revise", 0)
            total_tokens_revise = out.get("total_tokens_revise", 0)
            print(f"total_tokens_draft, total_tokens_reflect, total_tokens_revise: {total_tokens_draft}, {total_tokens_reflect}, {total_tokens_revise}")

        except Exception as e:
            print(f"Error processing sentence '{user_message}': {e}")
            gen_parse = 'API call failed'
            prompt_tokens_draft, completion_tokens_draft, total_tokens_draft = 0, 0, 0
            prompt_tokens_reflect, completion_tokens_reflect, total_tokens_reflect = 0, 0, 0
            prompt_tokens_revise, completion_tokens_revise, total_tokens_revise = 0, 0, 0


        data_entry = {
            'sentence': user_message,
            'gold_amr': gold_parse,
            'generated_amr': gen_parse,
        }
        if out and "draft" in out:
            data_entry['draft_amr'] = out["draft"]
        else:
            data_entry['draft_amr'] = "N/A"


        data.append(data_entry)

        stats.append({
            "prompt_tokens_draft": prompt_tokens_draft,
            "completion_tokens_draft": completion_tokens_draft,
            "total_tokens_draft": total_tokens_draft,

            "prompt_tokens_reflect": prompt_tokens_reflect,
            "completion_tokens_reflect": completion_tokens_reflect,
            "total_tokens_reflect": total_tokens_reflect,

            "prompt_tokens_revise": prompt_tokens_revise,
            "completion_tokens_revise": completion_tokens_revise,
            "total_tokens_revise": total_tokens_revise
        })
        gens.append(gen_parse)
        if out and "draft" in out:
            drafts.append(out["draft"])
        else:
            drafts.append("N/A")

        full_logs.append(out)

    prompt_tokens_final = sum([s.get("prompt_tokens_draft", 0) or 0 for s in stats] + [s.get("prompt_tokens_reflect", 0) or 0 for s in stats] + [s.get("prompt_tokens_revise", 0) or 0 for s in stats])
    completion_tokens_final = sum([s.get("completion_tokens_draft", 0) or 0 for s in stats] + [s.get("completion_tokens_reflect", 0) or 0 for s in stats] + [s.get("completion_tokens_revise", 0) or 0 for s in stats])

    final = {
        "prompt_tokens_final": prompt_tokens_final,
        "completion_tokens_final": completion_tokens_final,
        "total_tokens_final": prompt_tokens_final + completion_tokens_final,
        "price": (prompt_tokens_final * PRICING[model]["input"] + completion_tokens_final * PRICING[model]["output"]) / 1000000
    }
    stats.append(final)
    results_df = pd.DataFrame(data)
    results_df.to_csv(f"{dir_path}/amr_parses.tsv", sep="\t")
    with open(f"{dir_path}/full_logs.txt", "w", encoding="utf-8") as f:
        f.write("\n".join([str(log) for log in full_logs]))

    return stats, gens, drafts


def generate_and_eval_langchain(sample, model, comment="test", prepend_to_message="", append_to_message="", **kwargs):
    current_time = datetime.now()
    filename_time = current_time.strftime("%Y-%m-%d_%H-%M-%S")
    dir_path = f"/content/gdrive/My Drive/amr_parsing/results/{model}-{filename_time}_{comment}"
    os.makedirs(dir_path, exist_ok=True)
    stats, gens, drafts = get_amrs_from_langchain(sample, model, dir_path, prepend_to_message=prepend_to_message, append_to_message=append_to_message, **kwargs)
    with open(f"{dir_path}/stats.txt", "w", encoding="utf-8") as f:
        f.write("\n".join([str(d) for d in stats]))
    golds = get_golds(sample)
    valid_pct, smatchpp_score = validate_and_score(golds, gens, report_path=f"{dir_path}/evals.txt")
    valid_pct_d, smatchpp_score_d = validate_and_score(golds, drafts, report_path=f"{dir_path}/draft_evals.txt")
    return valid_pct, smatchpp_score, valid_pct_d, smatchpp_score_d

### gpt-4o

In [None]:
generate_and_eval_langchain(sample_test, "gpt-4o", comment="langchain", prepend_to_message="", append_to_message="")

### gpt-5-2025-08-07

In [None]:
generate_and_eval_langchain(sample_test, "gpt-5-2025-08-07", comment="langchain", prepend_to_message="", append_to_message="")