# In [3]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Union

# ---------- Term data types ----------
@dataclass(frozen=True)
class Term:
    """Common base class shared by Var, Const and Func.

    It declares no fields of its own; it exists so every kind of term has a
    single type that helpers such as apply_subst and unify can accept.
    """

@dataclass(frozen=True)
class Var(Term):
    """A logic variable, identified solely by its name.

    Instances are immutable and hashable, which is what lets them serve as
    keys of a substitution.
    """

    name: str

    def __str__(self):
        # Variables render as their bare name.
        return self.name

@dataclass(frozen=True)
class Const(Term):
    """A constant symbol such as ``Apple`` or ``Riya``.

    Two constants are equal exactly when their names are equal; a constant
    never equals a Var of the same name (dataclass equality also compares
    the class).
    """

    name: str

    def __str__(self):
        # Constants render as their bare name.
        return self.name

@dataclass(frozen=True)
class Func(Term):
    """A compound term: a function or predicate symbol applied to arguments.

    ``args`` is kept as a list for backward compatibility, but it should be
    treated as read-only: mutating it after the term has been hashed (e.g.
    placed in a set or used as a dict key) breaks that container.

    With no arguments the term prints as the bare symbol, exactly like a
    Const of the same name, although the two never compare equal (dataclass
    equality also requires the same class).
    """

    symbol: str
    args: List[Term]

    def __hash__(self) -> int:
        # The dataclass-generated hash for a frozen class would hash the
        # ``args`` list directly and raise TypeError. Hash a tuple snapshot
        # instead so Func, like Var and Const, can live in sets / dict keys.
        # An explicitly defined __hash__ is kept by @dataclass(frozen=True),
        # and it agrees with the generated __eq__, which compares
        # (symbol, args).
        return hash((self.symbol, tuple(self.args)))

    def __str__(self):
        if not self.args:
            return self.symbol
        return f"{self.symbol}(" + ", ".join(map(str, self.args)) + ")"

# A substitution: maps each bound variable to the term it stands for.
# apply_subst follows chains of bindings (x -> y -> ...) when resolving a term.
Subst = Dict[Var, Term]


# ---------- Helpers ----------
def is_var(t: Term) -> bool:
    """Return True when *t* is a logic variable (a Var instance)."""
    return isinstance(t, Var)

def occurs_check(v: Var, t: Term, s: Subst) -> bool:
    """Return True if variable *v* occurs in *t* once *t* is resolved under *s*.

    Binding *v* to such a term would create an infinite (cyclic) term, so
    unify() refuses it.

    apply_subst resolves the entire term in one call: it follows variable
    chains and rebuilds every Func argument, so any Var left in the result is
    unbound in *s*. The resolved structure is therefore resolved only once
    here and then scanned with an explicit stack, without re-resolving each
    subterm.
    """
    pending: List[Term] = [apply_subst(s, t)]
    while pending:
        node = pending.pop()
        if node == v:
            return True
        if isinstance(node, Func):
            # Arguments are already resolved; just inspect them.
            pending.extend(node.args)
    return False

def apply_subst(s: Subst, t: Term) -> Term:
    """Return *t* with every bound variable replaced according to *s*.

    Compound terms are rebuilt argument by argument; a bound variable is
    replaced by its binding, which is itself resolved in turn so that chains
    such as x -> y -> Apple collapse to Apple. Constants and unbound
    variables come back unchanged.
    """
    if isinstance(t, Func):
        resolved_args = [apply_subst(s, arg) for arg in t.args]
        return Func(t.symbol, resolved_args)
    if isinstance(t, Var) and t in s:
        # Resolve the binding too, following variable chains.
        return apply_subst(s, s[t])
    return t

def extend(s: Subst, v: Var, t: Term) -> Subst:
    """Return a new substitution: *s* plus the binding v ↦ t.

    *t* is first resolved under *s*. Every existing binding then has *v*
    replaced by that resolved term, so no older binding still refers to *v*.
    Key order follows *s*, with *v* appended last (or overwritten in place
    if it was already a key). *s* itself is not modified.

    Callers are expected to have ruled out *v* occurring in *t* (unify()
    does so via occurs_check). Otherwise {v ↦ t} is cyclic, and rewriting an
    existing binding that mentions *v* recurses without bound.
    """
    t = apply_subst(s, t)
    binding: Subst = {v: t}
    s2: Subst = {x: apply_subst(binding, y) for x, y in s.items()}
    s2[v] = t
    return s2

class UnificationError(Exception):
    """Raised by unify() when two terms have no unifier.

    The message names the reason: occurs-check failure, symbol mismatch,
    arity mismatch, or otherwise incompatible terms.
    """


# ---------- Unification (Algorithm: Unify(Ψ1, Ψ2)) ----------
def unify(t1: Term, t2: Term, s: Subst | None = None) -> Subst:
    """
    Compute the most general unifier (MGU) of *t1* and *t2*.

    *s* holds bindings accumulated so far (an empty substitution when
    omitted); the returned substitution extends it. If the terms are already
    identical under *s*, *s* itself is returned. Raises UnificationError when
    no unifier exists.
    """
    subst: Subst = {} if s is None else s

    # Compare both sides as they look under the bindings gathered so far.
    left = apply_subst(subst, t1)
    right = apply_subst(subst, t2)

    if left == right:
        return subst

    # A variable on either side is bound to the other side. The left side is
    # examined first; binding is refused if it would build an infinite term.
    for var, other in ((left, right), (right, left)):
        if isinstance(var, Var):
            if occurs_check(var, other, subst):
                raise UnificationError(f"Occurs check failed: {var} in {other}")
            return extend(subst, var, other)

    # Neither side is a variable: only two compound terms can still unify.
    if not (isinstance(left, Func) and isinstance(right, Func)):
        raise UnificationError(f"Cannot unify {left} with {right}")
    if left.symbol != right.symbol:
        raise UnificationError(
            "Predicate/function symbol mismatch: "
            f"{left.symbol} vs {right.symbol}"
        )
    if len(left.args) != len(right.args):
        raise UnificationError("Different number of arguments")

    # Thread the growing substitution through each argument pair, left to right.
    for arg_left, arg_right in zip(left.args, right.args):
        subst = unify(arg_left, arg_right, subst)
    return subst


# ---------- Convenience constructors ----------
def V(name: str) -> Var:
    """Shorthand constructor for a Var named *name*."""
    return Var(name=name)
def C(name: str) -> Const:
    """Shorthand constructor for a Const named *name*."""
    return Const(name=name)
def F(sym: str, *args: Term) -> Func:
    """Shorthand constructor: F("f", a, b) builds Func("f", [a, b])."""
    return Func(symbol=sym, args=[*args])


# ---------- Demo ----------
if __name__ == "__main__":
    # Example: Eats(x, Apple) with Eats(Riya, y)  ->  { x/Riya, y/Apple }
    x = V("x")
    y = V("y")
    Apple = C("Apple")
    Riya = C("Riya")

    # Expressions to unify
    t1 = F("Eats", x, Apple)
    t2 = F("Eats", Riya, y)

    # Only unify() can raise UnificationError; keep the try body to that call.
    try:
        mgu = unify(t1, t2)
    except UnificationError as e:
        print("Failure:", e)
    else:
        # Format the MGU as "var/term" pairs, in binding order.
        formatted_mgu = "{ " + ", ".join(f"{k}/{v}" for k, v in mgu.items()) + " }"
        print("MGU (Most General Unifier):", formatted_mgu)
        print("After substitution:")
        print("t1 σ =", apply_subst(mgu, t1))
        print("t2 σ =", apply_subst(mgu, t2))



# Output:
# MGU (Most General Unifier): { x/Riya, y/Apple }
# After substitution:
# t1 σ = Eats(Riya, Apple)
# t2 σ = Eats(Riya, Apple)
