# PyTorch GTN-Compatible Scratch Implementation

This notebook rewrites the original prototype as a cleaner, documented implementation whose **API mirrors GTN's**.

## Goals
- Keep a GTN-like API (`Graph`, `linear_graph`, `compose`, `forward_score`, `backward`, ...).
- Preserve PyTorch autograd end-to-end.
- Keep code readable and easy to extend.

## Scope and caveats
- This is a pure-Python educational implementation, not a C++/CUDA replacement for speed.
- It implements a substantial GTN-compatible subset used by the examples in this repo.
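- `compose` handles epsilon arcs by letting either graph advance on its own, but it applies no epsilon filter, so epsilon moves that can interleave in more than one order produce redundant paths; full production-grade epsilon filtering is more involved.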
- `forward_score` / `viterbi_score` require acyclic graphs (matching common GTN constraints for these ops).


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from collections import defaultdict, deque
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

import torch
import graphviz

# Match GTN epsilon semantics: -1 is the label reserved for epsilon arcs.
epsilon = -1
# Upper-case alias bound to the same value as `epsilon`.
EPSILON = epsilon

# Anything accepted as an arc weight: a Python number or a torch tensor.
TensorLike = Union[float, int, torch.Tensor]


def _to_weight_tensor(value: TensorLike, calc_grad: bool) -> torch.Tensor:
    """Convert ``value`` to a 0-dim weight tensor without breaking autograd links.

    Args:
        value: A Python number or a single-element tensor.
        calc_grad: Whether the graph that will own the weight tracks gradients.

    Returns:
        A 0-dim tensor. Python numbers become a fresh float32 tensor with
        ``requires_grad=calc_grad``. A 0-dim tensor is returned as the very
        same object (detached only when ``calc_grad`` is false and it requires
        grad); other single-element tensors are reshaped to 0-dim.

    Note:
        When ``calc_grad`` is true and the resulting tensor is a non-leaf that
        requires grad, ``retain_grad()`` is called on it so that its ``.grad``
        is populated by ``backward`` and can be read back later.
    """
    if not isinstance(value, torch.Tensor):
        return torch.tensor(float(value), dtype=torch.float32, requires_grad=calc_grad)

    # Reuse 0-dim tensors untouched: ``reshape`` hands back a view, i.e. a
    # non-leaf tensor, and autograd would then fill ``.grad`` on the caller's
    # leaf instead of on the tensor stored in the graph.
    t = value if value.dim() == 0 else value.reshape(())

    if not calc_grad:
        return t.detach() if t.requires_grad else t

    if t.requires_grad and not t.is_leaf:
        # Non-leaf tensors drop their gradient after backward unless asked to
        # keep it; the graph reads ``weight.grad`` afterwards.
        t.retain_grad()
    return t


@dataclass
class Arc:
    """A single weighted, labelled transition between two graph nodes."""

    src: int  # index of the node the arc leaves
    dst: int  # index of the node the arc enters
    ilabel: int  # input label
    olabel: int  # output label
    weight: torch.Tensor  # arc weight (Graph.add_arc stores the result of _to_weight_tensor here)


class Graph:
    """
    GTN-like weighted graph/transducer.

    Key compatibility points:
    - Graph(calc_grad=True)
    - add_node(start=False, accept=False)
    - add_arc(src, dst, ilabel, olabel=ilabel, weight=0.0)
    - num_nodes(), num_arcs(), num_start(), num_accept()
    - labels_to_list(input=True), weights_to_list(), set_weights(...)
    - item() for scalar graphs
    - grad() returns a graph with same topology and gradient weights
    - zero_grad()

    Nodes are identified by consecutive integer indices starting at 0; arcs are
    kept in insertion order, which is also the order used by the ``*_to_list``
    helpers and by ``set_weights``.
    """

    def __init__(self, calc_grad: bool = True):
        """Create an empty graph; ``calc_grad`` is forwarded to every weight conversion."""
        self.calc_grad = bool(calc_grad)
        self._num_nodes = 0
        self._start_nodes: Set[int] = set()
        self._accept_nodes: Set[int] = set()
        self._arcs: List[Arc] = []

    # -------- Node/arc construction --------
    def add_node(self, start: bool = False, accept: bool = False) -> int:
        """Append a node, optionally flagged as start and/or accept, and return its index."""
        idx = self._num_nodes
        self._num_nodes += 1
        if start:
            self._start_nodes.add(idx)
        if accept:
            self._accept_nodes.add(idx)
        return idx

    def add_arc(
        self,
        src: int,
        dst: int,
        ilabel: int,
        olabel: Optional[int] = None,
        weight: TensorLike = 0.0,
    ) -> None:
        """Append an arc from ``src`` to ``dst``.

        ``olabel`` defaults to ``ilabel``. The weight is converted with
        ``_to_weight_tensor`` using this graph's ``calc_grad`` flag.
        """
        if olabel is None:
            olabel = ilabel
        w = _to_weight_tensor(weight, self.calc_grad)
        self._arcs.append(Arc(src, dst, int(ilabel), int(olabel), w))
        # Endpoints beyond the current node count silently grow the graph.
        self._num_nodes = max(self._num_nodes, src + 1, dst + 1)

    # -------- Introspection --------
    def num_nodes(self) -> int:
        """Return the number of nodes."""
        return self._num_nodes

    def num_arcs(self) -> int:
        """Return the number of arcs."""
        return len(self._arcs)

    def num_start(self) -> int:
        """Return the number of start nodes."""
        return len(self._start_nodes)

    def num_accept(self) -> int:
        """Return the number of accept nodes."""
        return len(self._accept_nodes)

    @property
    def start_nodes(self) -> Set[int]:
        """A copy of the start-node index set (mutating it does not affect the graph)."""
        return set(self._start_nodes)

    @property
    def accept_nodes(self) -> Set[int]:
        """A copy of the accept-node index set (mutating it does not affect the graph)."""
        return set(self._accept_nodes)

    def labels_to_list(self, input: bool = True) -> List[int]:
        """Return the input labels (or output labels when ``input`` is false) in arc order."""
        return [a.ilabel if input else a.olabel for a in self._arcs]

    def weights_to_list(self) -> List[float]:
        """Return the arc weights as Python floats in arc order."""
        return [float(a.weight.detach().cpu().item()) for a in self._arcs]

    def weights_to_numpy(self):
        """Return the arc weights as a float32 NumPy array in arc order."""
        import numpy as np
        return np.array(self.weights_to_list(), dtype=np.float32)

    def set_weights(self, values: Union[Sequence[float], torch.Tensor]) -> None:
        """Replace every arc weight, in arc order.

        Args:
            values: Either a tensor with exactly one element per arc (each
                element is indexed out, so autograd links to the tensor are
                kept) or a sequence of numbers.

        Raises:
            ValueError: If the number of values differs from the number of arcs.

        Plain numbers are converted one by one so that, exactly as in
        ``add_arc``, they become weights with ``requires_grad == calc_grad``.
        """
        if isinstance(values, torch.Tensor):
            flat = values.reshape(-1)
            new_weights = [flat[i] for i in range(flat.numel())]
        else:
            # Going through float() (rather than building one constant tensor)
            # lets _to_weight_tensor create gradient-tracking weights for
            # calc_grad graphs.
            new_weights = [float(v) for v in values]

        if len(new_weights) != len(self._arcs):
            raise ValueError(f"Expected {len(self._arcs)} weights, got {len(new_weights)}")
        for arc, w in zip(self._arcs, new_weights):
            arc.weight = _to_weight_tensor(w, self.calc_grad)

    def weight(self, i: int) -> float:
        """Return the weight of the ``i``-th arc as a Python float."""
        return float(self._arcs[i].weight.detach().cpu().item())

    # -------- Autograd --------
    def _weight_tensors(self) -> List[torch.Tensor]:
        """Return the weight tensors themselves (not copies) in arc order."""
        return [a.weight for a in self._arcs]

    def grad(self) -> "Graph":
        """Return a ``calc_grad=False`` graph with the same nodes/arcs and gradients as weights.

        An arc whose weight tensor has no populated ``.grad`` gets weight 0.0.
        """
        g = Graph(calc_grad=False)
        for i in range(self._num_nodes):
            g.add_node(start=(i in self._start_nodes), accept=(i in self._accept_nodes))
        for arc in self._arcs:
            gw = 0.0 if arc.weight.grad is None else float(arc.weight.grad.detach().cpu().item())
            g.add_arc(arc.src, arc.dst, arc.ilabel, arc.olabel, gw)
        return g

    def zero_grad(self) -> None:
        """Zero, in place, every weight gradient that is currently populated."""
        for w in self._weight_tensors():
            if w.grad is not None:
                w.grad.zero_()

    # -------- Scalar utility --------
    def is_scalar_graph(self) -> bool:
        """Return True for the exact shape built by ``scalar_graph``.

        That shape is: two nodes, node 0 start, node 1 accept, and a single
        epsilon:epsilon arc from node 0 to node 1.
        """
        if self._num_nodes != 2 or len(self._arcs) != 1:
            return False
        a = self._arcs[0]
        return (0 in self._start_nodes) and (1 in self._accept_nodes) and a.src == 0 and a.dst == 1 and a.ilabel == epsilon and a.olabel == epsilon

    def item(self) -> float:
        """Return the single arc weight of a scalar graph as a Python float.

        Raises:
            ValueError: If the graph is not a scalar graph.
        """
        if not self.is_scalar_graph():
            raise ValueError("item() is only valid for scalar graphs")
        return float(self._arcs[0].weight.detach().cpu().item())

    # -------- Debug/visualization --------
    def __repr__(self) -> str:
        return f"Graph(nodes={self._num_nodes}, arcs={len(self._arcs)}, start={sorted(self._start_nodes)}, accept={sorted(self._accept_nodes)})"

    def draw(self, filename: str, format: str = "svg", view: bool = False, label_map: Optional[Dict[str, int]] = None):
        """Render the graph with graphviz and return the ``Digraph`` object.

        Args:
            filename: Passed to ``graphviz.Digraph`` and to ``render``.
            format: Output format passed to ``render``.
            view: Passed to ``render``.
            label_map: Optional symbol -> integer label mapping; it is inverted
                so arc labels are shown as symbols where a mapping exists.

        Accept nodes are drawn as double circles and start nodes in blue.
        """
        inv = None
        if label_map is not None:
            inv = {v: k for k, v in label_map.items()}

        def fmt_label(x: int) -> str:
            # Epsilon always wins over any user-supplied symbol for the same id.
            if x == epsilon:
                return "ε"
            if inv is not None and x in inv:
                return str(inv[x])
            return str(x)

        dot = graphviz.Digraph(filename)
        dot.attr(rankdir="LR")
        for i in range(self._num_nodes):
            shape = "doublecircle" if i in self._accept_nodes else "circle"
            color = "blue" if i in self._start_nodes else "black"
            dot.node(str(i), str(i), shape=shape, color=color)

        for a in self._arcs:
            lbl = f"{fmt_label(a.ilabel)}:{fmt_label(a.olabel)} / {float(a.weight.detach().cpu().item()):.3f}"
            dot.edge(str(a.src), str(a.dst), label=lbl)

        dot.render(filename, format=format, view=view, cleanup=True)
        return dot


# ---------------- Constructors ----------------
def scalar_graph(weight: TensorLike, calc_grad: bool = True) -> Graph:
    """Wrap a single value in the two-node graph used to represent scalars.

    The result has one start node, one accept node and a single
    epsilon:epsilon arc between them that carries ``weight``.
    """
    scalar = Graph(calc_grad=calc_grad)
    first = scalar.add_node(start=True)
    last = scalar.add_node(accept=True)
    scalar.add_arc(src=first, dst=last, ilabel=epsilon, olabel=epsilon, weight=weight)
    return scalar


def linear_graph(m: int, n: int, calc_grad: bool = True) -> Graph:
    """Create a length-m linear acceptor with n labels per timestep.

    Nodes 0..m form a chain (node 0 is the start, node m the accept node) and
    each consecutive pair is joined by ``n`` arcs labelled 0..n-1 with weight 0.0.
    """
    chain = Graph(calc_grad=calc_grad)
    for position in range(m + 1):
        chain.add_node(start=position == 0, accept=position == m)
    for step in range(m):
        following = step + 1
        for label in range(n):
            chain.add_arc(step, following, label, label, 0.0)
    return chain


# ---------------- Graph transforms ----------------
def clone(g: Graph) -> Graph:
    """Return a copy of ``g`` with the same nodes, flags and arcs.

    Each arc weight is duplicated with ``Tensor.clone()`` before being added to
    the copy.
    """
    starts, accepts = g.start_nodes, g.accept_nodes
    duplicate = Graph(calc_grad=g.calc_grad)
    for node in range(g.num_nodes()):
        duplicate.add_node(start=node in starts, accept=node in accepts)
    for arc in g._arcs:
        duplicate.add_arc(arc.src, arc.dst, arc.ilabel, arc.olabel, arc.weight.clone())
    return duplicate


def project_input(g: Graph) -> Graph:
    """Return a copy of ``g`` in which every arc's output label equals its input label.

    Node flags are preserved and the original weight tensors are passed on
    unchanged.
    """
    starts, accepts = g.start_nodes, g.accept_nodes
    projected = Graph(calc_grad=g.calc_grad)
    for node in range(g.num_nodes()):
        projected.add_node(start=node in starts, accept=node in accepts)
    for arc in g._arcs:
        label = arc.ilabel
        projected.add_arc(arc.src, arc.dst, label, label, arc.weight)
    return projected


def project_output(g: Graph) -> Graph:
    """Return a copy of ``g`` in which every arc's input label equals its output label.

    Node flags are preserved and the original weight tensors are passed on
    unchanged.
    """
    starts, accepts = g.start_nodes, g.accept_nodes
    projected = Graph(calc_grad=g.calc_grad)
    for node in range(g.num_nodes()):
        projected.add_node(start=node in starts, accept=node in accepts)
    for arc in g._arcs:
        label = arc.olabel
        projected.add_arc(arc.src, arc.dst, label, label, arc.weight)
    return projected


def remove(g: Graph, ilabel: int = epsilon, olabel: Optional[int] = None) -> Graph:
    """Return a copy of ``g`` without the arcs labelled exactly ``ilabel:olabel``.

    ``olabel`` defaults to ``ilabel``. All nodes and their start/accept flags
    are kept; only matching arcs are left out.

    NOTE(review): paths that went through a dropped arc are not reconnected
    here - confirm whether callers expect path-preserving (epsilon-removal
    style) behavior.
    """
    unwanted = (ilabel, ilabel if olabel is None else olabel)
    starts, accepts = g.start_nodes, g.accept_nodes

    kept = Graph(calc_grad=g.calc_grad)
    for node in range(g.num_nodes()):
        kept.add_node(start=node in starts, accept=node in accepts)
    for arc in g._arcs:
        if (arc.ilabel, arc.olabel) == unwanted:
            continue
        kept.add_arc(arc.src, arc.dst, arc.ilabel, arc.olabel, arc.weight)
    return kept


def union(graphs: Sequence[Graph]) -> Graph:
    """Place all ``graphs`` side by side in one graph.

    Every input graph is copied with freshly numbered nodes; start/accept
    flags, labels and weight tensors are carried over. The result tracks
    gradients if any input does.
    """
    merged = Graph(calc_grad=any(g.calc_grad for g in graphs))
    for g in graphs:
        starts, accepts = g.start_nodes, g.accept_nodes
        # old node index -> index in the merged graph
        renumbered = {
            node: merged.add_node(start=node in starts, accept=node in accepts)
            for node in range(g.num_nodes())
        }
        for arc in g._arcs:
            merged.add_arc(renumbered[arc.src], renumbered[arc.dst], arc.ilabel, arc.olabel, arc.weight)
    return merged


def concat(g1: Graph, g2: Graph) -> Graph:
    """Concatenate ``g1`` and ``g2``.

    The result starts where ``g1`` starts and accepts where ``g2`` accepts.
    Every accept node of ``g1`` is linked to every start node of ``g2`` by an
    epsilon:epsilon arc of weight 0.0.
    """
    joined = Graph(calc_grad=(g1.calc_grad or g2.calc_grad))
    g1_starts, g1_accepts = g1.start_nodes, g1.accept_nodes
    g2_starts, g2_accepts = g2.start_nodes, g2.accept_nodes

    # Copies of g1's nodes come first and keep only the start flag ...
    left = {
        node: joined.add_node(start=node in g1_starts, accept=False)
        for node in range(g1.num_nodes())
    }
    # ... followed by copies of g2's nodes, which keep only the accept flag.
    right = {
        node: joined.add_node(start=False, accept=node in g2_accepts)
        for node in range(g2.num_nodes())
    }

    for graph, mapping in ((g1, left), (g2, right)):
        for arc in graph._arcs:
            joined.add_arc(mapping[arc.src], mapping[arc.dst], arc.ilabel, arc.olabel, arc.weight)

    # Bridge the two halves.
    for end in g1_accepts:
        for begin in g2_starts:
            joined.add_arc(left[end], right[begin], epsilon, epsilon, 0.0)

    return joined


def closure(g: Graph) -> Graph:
    """Return the closure (zero or more repetitions) of ``g``.

    ``g`` is copied with all start/accept flags cleared, then one extra node
    that is both start and accept is appended. Epsilon:epsilon arcs of weight
    0.0 lead from that node to every former start node and from every former
    accept node back to it.
    """
    starred = Graph(calc_grad=g.calc_grad)
    inner = {
        node: starred.add_node(start=False, accept=False)
        for node in range(g.num_nodes())
    }
    for arc in g._arcs:
        starred.add_arc(inner[arc.src], inner[arc.dst], arc.ilabel, arc.olabel, arc.weight)

    hub = starred.add_node(start=True, accept=True)
    for entry in g.start_nodes:
        starred.add_arc(hub, inner[entry], epsilon, epsilon, 0.0)
    for leave in g.accept_nodes:
        starred.add_arc(inner[leave], hub, epsilon, epsilon, 0.0)
    return starred


def _out_arcs_by_src(g: Graph) -> Dict[int, List[Arc]]:
    """Group the arcs of ``g`` by source node, keeping arc order within each group.

    Returns a ``defaultdict(list)``; nodes without outgoing arcs have no entry
    until they are looked up with ``[]``.
    """
    grouped: Dict[int, List[Arc]] = defaultdict(list)
    for arc in g._arcs:
        bucket = grouped[arc.src]
        bucket.append(arc)
    return grouped


def compose(g1: Graph, g2: Graph) -> Graph:
    """
    Simplified epsilon-aware composition suitable for many GTN notebook examples.

    Output nodes stand for pairs ``(node of g1, node of g2)`` and are created
    breadth-first from every pair of start nodes; a pair is accepting when both
    members are. From each pair three kinds of arcs are emitted, in this order:

    1. g1 arcs with epsilon output advance g1 only (``ilabel:epsilon``).
    2. g2 arcs with epsilon input advance g2 only (``epsilon:olabel``).
    3. g1/g2 arcs whose output/input labels match advance both; the weight is
       the sum of the two arc weights.

    No epsilon filter is applied between steps 1 and 2.
    """
    out = Graph(calc_grad=(g1.calc_grad or g2.calc_grad))
    arcs1 = _out_arcs_by_src(g1)
    arcs2 = _out_arcs_by_src(g2)
    starts1, accepts1 = g1.start_nodes, g1.accept_nodes
    starts2, accepts2 = g2.start_nodes, g2.accept_nodes

    state_map: Dict[Tuple[int, int], int] = {}
    pending = deque()

    def _both_accept(pair: Tuple[int, int]) -> bool:
        return pair[0] in accepts1 and pair[1] in accepts2

    def _index_of(pair: Tuple[int, int]) -> int:
        # Create (and schedule) the output node for ``pair`` on first sight.
        if pair not in state_map:
            state_map[pair] = out.add_node(start=False, accept=_both_accept(pair))
            pending.append(pair)
        return state_map[pair]

    for s1 in starts1:
        for s2 in starts2:
            seed = (s1, s2)
            state_map[seed] = out.add_node(start=True, accept=_both_accept(seed))
            pending.append(seed)

    while pending:
        u1, u2 = pending.popleft()
        src_idx = state_map[(u1, u2)]
        from_u1 = arcs1.get(u1, [])
        from_u2 = arcs2.get(u2, [])

        # 1) consume epsilon-output arcs in g1
        for a1 in from_u1:
            if a1.olabel == epsilon:
                out.add_arc(src_idx, _index_of((a1.dst, u2)), a1.ilabel, epsilon, a1.weight)

        # 2) consume epsilon-input arcs in g2
        for a2 in from_u2:
            if a2.ilabel == epsilon:
                out.add_arc(src_idx, _index_of((u1, a2.dst)), epsilon, a2.olabel, a2.weight)

        # 3) regular label matches
        for a1 in from_u1:
            if a1.olabel == epsilon:
                continue
            for a2 in from_u2:
                if a2.ilabel != epsilon and a1.olabel == a2.ilabel:
                    out.add_arc(
                        src_idx,
                        _index_of((a1.dst, a2.dst)),
                        a1.ilabel,
                        a2.olabel,
                        a1.weight + a2.weight,
                    )

    return out


def intersect(g1: Graph, g2: Graph) -> Graph:
    """Intersection for acceptors (matches input labels).

    Both operands are input-projected, composed, and the composition is
    input-projected once more.
    """
    left = project_input(g1)
    right = project_input(g2)
    composed = compose(left, right)
    return project_input(composed)


# ---------------- Scoring ----------------
def _topological_order_or_raise(g: Graph) -> List[int]:
    """Return the nodes of ``g`` in a topological order.

    Nodes are released once all of their incoming arcs have been accounted
    for, starting from the nodes without incoming arcs.

    Raises:
        ValueError: If some node is never released, i.e. ``g`` has a cycle.
    """
    total = g.num_nodes()
    remaining_in = [0] * total
    successors = defaultdict(list)
    for arc in g._arcs:
        successors[arc.src].append(arc.dst)
        remaining_in[arc.dst] += 1

    ready = deque(node for node in range(total) if remaining_in[node] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            remaining_in[nxt] -= 1
            if not remaining_in[nxt]:
                ready.append(nxt)

    if len(order) == total:
        return order
    raise ValueError("forward/viterbi currently require acyclic graphs in this implementation")


def _logaddexp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``log(exp(a) + exp(b))`` by delegating to ``torch.logaddexp``."""
    combined = torch.logaddexp(a, b)
    return combined


def forward_score(g: Graph) -> Graph:
    """Return, as a scalar graph, the log-sum-exp of the scores of all accepting paths.

    A path score is the sum of its arc weights. Nodes are visited in
    topological order, so ``g`` must be acyclic (``ValueError`` otherwise).
    An empty graph, or one with no reachable accept node, scores ``-inf``.
    """
    neg_inf = torch.tensor(float("-inf"), dtype=torch.float32)
    node_count = g.num_nodes()
    if node_count == 0:
        return scalar_graph(neg_inf, calc_grad=g.calc_grad)

    order = _topological_order_or_raise(g)
    outgoing = _out_arcs_by_src(g)

    # alpha[v]: log-sum-exp over all paths from a start node to v.
    alpha = [neg_inf] * node_count
    for s in g.start_nodes:
        alpha[s] = torch.tensor(0.0, dtype=torch.float32)

    for u in order:
        reached = alpha[u]
        if torch.isneginf(reached):
            # Unreachable node: nothing to propagate.
            continue
        for arc in outgoing.get(u, []):
            cand = reached + arc.weight
            previous = alpha[arc.dst]
            if torch.isneginf(previous):
                alpha[arc.dst] = cand
            else:
                alpha[arc.dst] = _logaddexp(previous, cand)

    finals = [alpha[t] for t in g.accept_nodes if not torch.isneginf(alpha[t])]
    if finals:
        total = finals[0]
        for term in finals[1:]:
            total = _logaddexp(total, term)
    else:
        total = neg_inf

    return scalar_graph(total, calc_grad=g.calc_grad)


def viterbi_score(g: Graph) -> Graph:
    """Return, as a scalar graph, the score of the best accepting path.

    A path score is the sum of its arc weights. Nodes are visited in
    topological order, so ``g`` must be acyclic (``ValueError`` otherwise).
    An empty graph, or one with no reachable accept node, scores ``-inf``.
    """
    neg_inf = torch.tensor(float("-inf"), dtype=torch.float32)
    node_count = g.num_nodes()
    if node_count == 0:
        return scalar_graph(neg_inf, calc_grad=g.calc_grad)

    order = _topological_order_or_raise(g)
    outgoing = _out_arcs_by_src(g)

    # best[v]: maximum score over all paths from a start node to v.
    best = [neg_inf] * node_count
    for s in g.start_nodes:
        best[s] = torch.tensor(0.0, dtype=torch.float32)

    for u in order:
        reached = best[u]
        if torch.isneginf(reached):
            # Unreachable node: nothing to propagate.
            continue
        for arc in outgoing.get(u, []):
            cand = reached + arc.weight
            previous = best[arc.dst]
            if torch.isneginf(previous):
                best[arc.dst] = cand
            else:
                best[arc.dst] = torch.maximum(previous, cand)

    finals = [best[t] for t in g.accept_nodes if not torch.isneginf(best[t])]
    winner = torch.stack(finals).max() if finals else neg_inf
    return scalar_graph(winner, calc_grad=g.calc_grad)


def viterbi_path(g: Graph) -> Graph:
    """Returns one best path as a new graph.

    The path is the highest-scoring sequence of arcs from a start node to an
    accept node, where a path score is the sum of its arc weights. The result
    is a linear chain whose first node is the start and whose last node is the
    accept node; labels and weight tensors are taken from the chosen arcs.

    If no accept node is reachable the result is an empty graph. If the best
    path has no arcs (a start node that is also accepting), the result is a
    single node flagged as both start and accept.

    Raises:
        ValueError: If ``g`` contains a cycle.
    """
    order = _topological_order_or_raise(g)
    out_arcs = _out_arcs_by_src(g)

    neg_inf = torch.tensor(float("-inf"), dtype=torch.float32)
    score = [neg_inf for _ in range(g.num_nodes())]
    # bp[v] = (predecessor node, arc taken) on the best path into v, if any.
    bp: List[Optional[Tuple[int, Arc]]] = [None for _ in range(g.num_nodes())]

    for s in g.start_nodes:
        score[s] = torch.tensor(0.0, dtype=torch.float32)

    for u in order:
        if torch.isneginf(score[u]):
            continue
        for a in out_arcs.get(u, []):
            cand = score[u] + a.weight
            if torch.isneginf(score[a.dst]) or (cand > score[a.dst]):
                score[a.dst] = cand
                bp[a.dst] = (u, a)

    best_accept = None
    best_score = neg_inf
    for t in g.accept_nodes:
        if not torch.isneginf(score[t]) and (best_accept is None or score[t] > best_score):
            best_accept = t
            best_score = score[t]

    if best_accept is None:
        return Graph(calc_grad=g.calc_grad)

    # Walk the back-pointers from the best accept node to the path's first node.
    arcs_rev: List[Arc] = []
    cur = best_accept
    while bp[cur] is not None:
        prev, a = bp[cur]
        arcs_rev.append(a)
        cur = prev

    path_arcs = list(reversed(arcs_rev))
    p = Graph(calc_grad=g.calc_grad)
    # With no arcs the first node is also the last one, so it must accept;
    # otherwise the returned graph would accept nothing.
    p.add_node(start=True, accept=not path_arcs)
    for i, a in enumerate(path_arcs, start=1):
        p.add_node(accept=(i == len(path_arcs)))
        p.add_arc(i - 1, i, a.ilabel, a.olabel, a.weight)
    return p


# ---------------- Scalar ops and autograd entry ----------------
def _scalar_tensor(g: Graph) -> torch.Tensor:
    """Return the weight tensor of the single arc of a scalar graph.

    Raises:
        ValueError: If ``g`` is not a scalar graph.
    """
    if g.is_scalar_graph():
        return g._arcs[0].weight
    raise ValueError("Expected scalar graph")


def negate(g: Graph) -> Graph:
    """Return a scalar graph holding the negated value of scalar graph ``g``."""
    value = _scalar_tensor(g)
    return scalar_graph(-value, calc_grad=g.calc_grad)


def add(g1: Graph, g2: Graph) -> Graph:
    """Return a scalar graph holding the sum of two scalar graphs.

    Raises:
        NotImplementedError: If either operand is not a scalar graph.
    """
    if not (g1.is_scalar_graph() and g2.is_scalar_graph()):
        raise NotImplementedError("Graph-wise add is not implemented; only scalar add is supported.")
    total = _scalar_tensor(g1) + _scalar_tensor(g2)
    return scalar_graph(total, calc_grad=(g1.calc_grad or g2.calc_grad))


def subtract(g1: Graph, g2: Graph) -> Graph:
    """Return a scalar graph holding ``g1 - g2`` for two scalar graphs.

    Raises:
        NotImplementedError: If either operand is not a scalar graph.
    """
    if not (g1.is_scalar_graph() and g2.is_scalar_graph()):
        raise NotImplementedError("Graph-wise subtract is not implemented; only scalar subtract is supported.")
    difference = _scalar_tensor(g1) - _scalar_tensor(g2)
    return scalar_graph(difference, calc_grad=(g1.calc_grad or g2.calc_grad))


def backward(g: Graph, retain_graph: bool = False) -> None:
    """Run PyTorch autograd from the value held by scalar graph ``g``.

    ``retain_graph`` is forwarded to ``Tensor.backward``.
    """
    root = _scalar_tensor(g)
    root.backward(retain_graph=retain_graph)


# ---------------- Optional namespace shim ----------------
class gtn:
    """Namespace shim exposing this module's API as ``gtn.<name>``.

    Functions are wrapped in ``staticmethod`` so they can be reached through
    the class without being treated as methods.
    """

    # Core type and the epsilon label constant.
    Graph = Graph
    epsilon = epsilon

    # Constructors.
    scalar_graph = staticmethod(scalar_graph)
    linear_graph = staticmethod(linear_graph)

    # Structural graph operations.
    clone = staticmethod(clone)
    union = staticmethod(union)
    concat = staticmethod(concat)
    closure = staticmethod(closure)
    compose = staticmethod(compose)
    intersect = staticmethod(intersect)

    # Label projections and arc removal.
    project_input = staticmethod(project_input)
    project_output = staticmethod(project_output)
    remove = staticmethod(remove)

    # Path scoring.
    forward_score = staticmethod(forward_score)
    viterbi_score = staticmethod(viterbi_score)
    viterbi_path = staticmethod(viterbi_path)

    # Scalar arithmetic and the autograd entry point.
    negate = staticmethod(negate)
    add = staticmethod(add)
    subtract = staticmethod(subtract)
    backward = staticmethod(backward)


## Quick API sanity checks

The next cells verify behavior with GTN-style usage patterns.


In [None]:
# Basic graph creation + score + backward

# First acceptor: a single arc labelled 1 whose weight is a tensor that requires grad.
g1 = gtn.Graph()
s0 = g1.add_node(start=True)
s1 = g1.add_node(accept=True)
g1.add_arc(s0, s1, 1, 1, torch.tensor(0.5, requires_grad=True))

# Second acceptor with the same label, so the two arcs match under composition.
g2 = gtn.Graph()
t0 = g2.add_node(start=True)
t1 = g2.add_node(accept=True)
g2.add_arc(t0, t1, 1, 1, torch.tensor(-0.2, requires_grad=True))

# Matching arcs are merged into one arc whose weight is the sum of both weights.
c = gtn.compose(g1, g2)
print("compose:", c)

# The forward score comes back as a scalar graph; item() extracts the float.
score = gtn.forward_score(c)
print("score scalar graph:", score, "item=", score.item())

# Backpropagate from the negated score.
loss = gtn.negate(score)
gtn.backward(loss)

# grad() reads the gradient stored on each arc's weight tensor.
print("g1 grad graph weights:", g1.grad().weights_to_list())
print("g2 grad graph weights:", g2.grad().weights_to_list())


In [None]:
# Optional-diacritic WFST (GTN semantics: epsilon = -1)

# Take the epsilon label from the shim namespace so this cell only uses `gtn.*`.
EPSILON = gtn.epsilon

def build_arabic_char_with_diacritic_wfst(arabic_chars, diacritics):
    """Build a three-node acceptor for "one character, then an optional diacritic".

    Args:
        arabic_chars: Symbols allowed on the first transition.
        diacritics: Symbols allowed on the second transition.

    Returns:
        ``(graph, label_map)`` where ``label_map`` maps each symbol (plus
        ``"<eps>"``) to its integer label. The graph is built with
        ``calc_grad=False`` and all arc weights are 0.0.
    """
    # Keep epsilon separate from regular symbols.
    label_map = {"<eps>": EPSILON}
    next_idx = 0
    for symbol in arabic_chars + diacritics:
        if symbol in label_map:
            continue
        # Never hand out the id reserved for epsilon.
        while next_idx == EPSILON:
            next_idx += 1
        label_map[symbol] = next_idx
        next_idx += 1

    wfst = gtn.Graph(calc_grad=False)
    start_node = wfst.add_node(start=True)
    middle_node = wfst.add_node()
    accept_node = wfst.add_node(accept=True)

    # Characters go start -> middle, diacritics go middle -> accept.
    stages = (
        (start_node, middle_node, arabic_chars),
        (middle_node, accept_node, diacritics),
    )
    for src, dst, symbols in stages:
        for symbol in symbols:
            lab = label_map[symbol]
            wfst.add_arc(src, dst, lab, lab, 0.0)

    # Optional skip of diacritic
    wfst.add_arc(middle_node, accept_node, EPSILON, EPSILON, 0.0)

    return wfst, label_map

# Symbol inventories for the demo: four letters and five diacritic names.
arabic_chars = ["ا", "ب", "ت", "ث"]
diacritics = ["fatha", "damma", "kasra", "shada", "sukun"]

# Build the acceptor, show its label assignment and summary, then render it
# to a PNG using the symbol names as arc labels.
arabic_wfst, label_map = build_arabic_char_with_diacritic_wfst(arabic_chars, diacritics)
print("label_map:", label_map)
print(arabic_wfst)
arabic_wfst.draw("arabic_wfst_corrected", format="png", view=False, label_map=label_map)


## Notes for extending toward full parity

To move from this notebook implementation to a production-complete GTN alternative, next priorities are:

1. Full epsilon-filter composition identical to GTN behavior.
2. Batched operations (`parallel_for`, map variants).
3. Full criterion module parity (`ctc_loss`, ASG helpers, etc.).
4. CUDA kernels / fused semiring operations for performance.
5. Serialization helpers and strict test parity against `bindings/python/test`.

This notebook now provides a clean, documented, autograd-correct base to build on.
