In [1]:
from os import environ
from getpass import getpass

def _set_env(var: str):
    """Ensure environment variable *var* is set, prompting for it if needed.

    If *var* already holds a non-empty value it is left alone; otherwise the
    user is asked for it via ``getpass`` and the answer is stored in
    ``os.environ``.
    """
    if environ.get(var):
        return
    # getpass does not echo the typed value, so the secret is not displayed.
    environ[var] = getpass(f"{var}: ")

# Prompt for the Cohere API key only if it is not already set in the environment.
_set_env("COHERE_API_KEY")

In [2]:
from langchain_cohere import ChatCohere
# Cohere chat model handle. NOTE(review): ChatCohere is expected to pick up
# COHERE_API_KEY from the environment (set by the prompt in the first cell) — confirm.
llm = ChatCohere(model="command-a-03-2025")

In [None]:
from __future__ import annotations

from typing import List, Literal, Optional

from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated, TypedDict

# Closed set of beat labels. Used as the `beat` field of BeatPlanItem and
# QuestionObject, and as the key of the per-beat question dicts in PipelineState.
Beat = Literal["A", "B", "C", "D", "E"]

class UserInput(BaseModel):
    # Structured input supplied by the user; stored in PipelineState.user_input.
    scholarship_name: str
    # Only these three values pass validation; any other string is rejected.
    program_type: Literal[
        "Undergrad",
        "Graduate",
        "Community Leadership",
    ]
    # NOTE(review): presumably a one-sentence statement of the applicant's goal.
    goal_one_liner: str
    # NOTE(review): presumably one entry per resume bullet — confirm.
    resume_points: list[str]


class PiiSpan(BaseModel):
    # A detected PII region. NOTE(review): start/end are assumed to be
    # character offsets (end exclusive) into the scanned text — confirm.
    # Offsets must be non-negative: a negative index would slice from the end
    # of the string and redact the wrong region.
    start: int = Field(ge=0)
    end: int = Field(ge=0)
    # Free-form label for the kind of PII found.
    pii_type: str
    # Optional detector score; its range is not enforced here.
    confidence: Optional[float] = None

    @model_validator(mode="after")
    def check_span_bounds(self):
        """Reject spans whose end precedes their start.

        Such a span would silently slice to an empty string instead of
        covering the intended text, so it is refused at construction time.

        Raises:
            ValueError: if ``end < start`` (surfaced by pydantic as a
                ``ValidationError``).
        """
        if self.end < self.start:
            raise ValueError(
                f"PiiSpan end ({self.end}) must be >= start ({self.start})"
            )
        return self


class BeatPlanItem(BaseModel):
    # One entry of the beat plan (PipelineState.beat_plan holds a list of these).
    beat: Beat
    # NOTE(review): presumably the pieces of information still missing for
    # this beat — confirm against the planner that fills it.
    missing: list[str]
    # Optional free-text guidance; None when absent.
    guidance: Optional[str] = None


class QuestionObject(BaseModel):
    # A single question tied to one beat.
    beat: Beat
    question: str
    # NOTE(review): presumably what the question is meant to elicit — confirm.
    intent: str
    
class BeatPlanOut(BaseModel):
    # Wrapper around a list of BeatPlanItem. NOTE(review): presumably the
    # structured-output schema for the planning step — confirm with its caller.
    items: list[BeatPlanItem]

class QuestionsOut(BaseModel):
    # Wrapper around a list of QuestionObject. NOTE(review): presumably the
    # structured-output schema for question generation — confirm with its caller.
    items: list[QuestionObject]

class ValidationReport(BaseModel):
    # Result of a validation pass (stored in PipelineState.validation_report).
    # NOTE(review): `ok` is set independently — nothing here ties it to
    # `errors` being empty.
    ok: bool
    # Each list defaults to a fresh empty list per instance.
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
    # NOTE(review): presumably descriptions of automatic repairs applied — confirm.
    repairs_applied: list[str] = Field(default_factory=list)


def merge_questions_by_beat(
    left: Optional[dict[Beat, list[QuestionObject]]],
    right: Optional[dict[Beat, list[QuestionObject]]],
) -> dict[Beat, list[QuestionObject]]:
    """Combine two per-beat question maps into a new map.

    Beats present in only one side are carried over; for a beat present in
    both, ``right``'s questions are appended after ``left``'s. Key order is
    ``left``'s keys first, then beats that only ``right`` introduces.

    Neither input is modified: every per-beat list is copied before being
    extended, so the caller's (or the previous state's) lists are never
    mutated in place.

    NOTE(review): concatenation means a beat that is generated again
    accumulates questions rather than replacing them — confirm that is the
    intended semantics for regeneration.

    Args:
        left: Existing map, or None/empty.
        right: Incoming map to merge in, or None/empty.

    Returns:
        A freshly built dict with freshly built lists.
    """
    # Copy each list: a plain dict(left) would share the list objects, and
    # the extend() below would then mutate left's lists in place.
    merged: dict[Beat, list[QuestionObject]] = {
        beat: list(qs) for beat, qs in (left or {}).items()
    }
    for beat, qs in (right or {}).items():
        merged.setdefault(beat, []).extend(qs)
    return merged

class PipelineState(TypedDict, total=False):
    """Shared state dict passed between pipeline stages.

    total=False makes every key optional, so a partially filled dict is a
    valid state at any stage. NOTE(review): the Annotated reducer on
    questions_by_beat follows the LangGraph state-schema convention; this
    assumes the class is used as a LangGraph graph state — confirm.
    """

    # Inputs
    user_input: UserInput

    # Governance front gate
    # NOTE(review): presumably canonical_input is a text rendering of
    # user_input, pii_spans locate PII within it, and redacted_input is that
    # text with the spans masked — confirm against the gate node.
    canonical_input: str
    pii_spans: list[PiiSpan]
    redacted_input: str

    # Planning
    beat_plan: list[BeatPlanItem]

    # Map outputs (per beat)

    # Annotated with merge_questions_by_beat so updates to this key are
    # combined by that reducer (per-beat lists concatenated) rather than the
    # last writer replacing the whole dict — presumably so that parallel
    # per-beat nodes can each contribute their questions.
    questions_by_beat: Annotated[dict[Beat, list[QuestionObject]], merge_questions_by_beat]

    # Reduce outputs
    final_questions_by_beat: dict[Beat, list[QuestionObject]]

    # Validation outputs
    # NOTE(review): presumably the beats that failed validation and, per
    # beat, the reasons why — confirm.
    failed_beats: list[Beat]
    failed_reasons: dict[Beat, list[str]]


    # Reliability / repair
    validation_report: ValidationReport
    # NOTE(review): presumably the number of generation/repair attempts so far.
    attempt_count: int

    # NOTE(review): presumably the beats queued for regeneration — confirm.
    regen_request: list[Beat]