In [None]:
!pip -q install "transformers>=4.41.0" "peft>=0.11.1" "accelerate>=0.30.0" "bitsandbytes>=0.43.1" pypdf

In [None]:
# Authenticate with the Hugging Face Hub before downloading checkpoints.
# NOTE(review): presumably required because the base model repo is gated — confirm.
from huggingface_hub import login
# new_session=False: reuse a token already saved on this machine rather than
# always prompting for a new one.
login(new_session=False)

In [None]:
# --- Configuration -----------------------------------------------------------
# Hub id of the base instruct checkpoint; an optional LoRA adapter is applied on top.
BASE_MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.2"

# Directory of a LoRA adapter, or None to run the plain base model.
# For an adapter stored in Google Drive (Colab), uncomment and adjust:
#   from google.colab import drive
#   drive.mount('/content/drive')
#   ADAPTER_PATH = "/content/drive/MyDrive/path/to/your_adapter"
ADAPTER_PATH = None

# Generation budget per document; raise to 1024 if outputs come back clipped.
MAX_NEW_TOKENS: int = 768


In [None]:
#load model with 4-bit (CUDA) + tokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

use_cuda = torch.cuda.is_available()

# Compute dtype: bfloat16 needs compute capability >= 8.0 (Ampere or newer).
# Older GPUs such as the T4 (capability 7.5) lack native bf16, so use float16
# there; float32 is only used for the CPU fallback below.
if use_cuda:
    major, _minor = torch.cuda.get_device_capability()
    dtype = torch.bfloat16 if major >= 8 else torch.float16
else:
    dtype = torch.float32

# 4-bit NF4 quantization (bitsandbytes) is what lets the 7B model fit on a T4.
# With the pinned bitsandbytes versions it needs a CUDA GPU, so on CPU the model
# is loaded unquantized instead (roughly 28 GB of RAM in float32 for 7B params).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
) if use_cuda else None
if not use_cuda:
    print("No CUDA GPU detected: loading the base model unquantized in float32 on CPU (slow, memory-hungry).")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS (generate_json also pads with EOS).
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=dtype,
)

if ADAPTER_PATH:
    print("Loading LoRA adapter:", ADAPTER_PATH)
    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
else:
    print("No adapter provided; using base model only.")
    model = base_model

# Inference mode; the trailing ';' keeps the notebook from printing the model repr.
model.eval();


In [None]:
#PDF upload + text extraction
from google.colab import files
from pypdf import PdfReader

def read_pdf_text(path: str) -> str:
    """Concatenate the extracted text of every page of the PDF at *path*.

    Pages are joined with a single newline. A page for which pypdf returns no
    text (``None``/empty, e.g. an image-only scan) contributes an empty string.
    """
    page_texts = []
    for page in PdfReader(path).pages:
        page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)

print("Upload a PDF (invoice or SOW)…")
uploaded = files.upload()
if not uploaded:
    # Nothing came back from the upload widget (e.g. the picker was cancelled).
    raise RuntimeError("No file uploaded; re-run this cell and select a PDF.")
# Keys are the uploaded file names and are used as the path to read; if several
# files were selected only the first one is processed.
pdf_path = next(iter(uploaded))
doc_text_raw = read_pdf_text(pdf_path)
print("Extracted chars:", len(doc_text_raw))
if not doc_text_raw.strip():
    # No text layer at all; likely a scanned PDF — extraction below will find nothing.
    print("Warning: no text could be extracted from this PDF (is it a scanned image?).")
print(doc_text_raw[:1200])


In [None]:
#Text sanitization + chunking helpers
import re, json

def clean_pdf_text(raw: str, limit=16000) -> str:
    """Strip PDF-extraction noise from *raw* and cap the result at *limit* chars.

    Steps, in order:
      1. Replace literal ``[INST]`` / ``[/INST]`` tags (any case/spacing) so the
         document cannot open or close a Mistral instruction block in the prompt.
      2. Drop short bracketed tags such as ``[TOC]`` or ``[FOOTER]`` (1-12 chars).
      3. Remove page markers ("Page 1", "Page: 1", "Page 3 of 5") together with
         the rest of the line they appear on.
      4. Drop short (2-40 char) all-caps header/footer lines. A line must contain
         at least one letter to be dropped, so standalone numbers, amounts, dates
         ("01/15/2024") and codes ("2024-001") are kept.
      5. Collapse runs of spaces/tabs to one space and 3+ newlines to two.

    Args:
        raw: Text extracted from the PDF; ``None`` is treated as empty.
        limit: Maximum number of characters returned.

    Returns:
        The cleaned, stripped text, truncated to ``limit`` characters.
    """
    s = raw or ""
    s = re.sub(r"\[\s*/?\s*INST\s*\]", " ", s, flags=re.IGNORECASE)
    s = re.sub(r"\[[A-Za-z0-9_/:\- ]{1,12}\]", " ", s)
    # '.*' (no DOTALL) also discards whatever follows the marker on that line.
    s = re.sub(r"\bPage\s*[:#]?\s*\d+(\s*of\s*\d+)?\b.*", " ", s, flags=re.IGNORECASE)
    # The lookahead requires an uppercase letter somewhere in the line.
    # NOTE(review): all-caps labels such as "INVOICE" or "TOTAL" on their own line
    # are still dropped; choose_invoice_chunk uses such words as anchors — confirm
    # that is intended.
    caps_line = re.compile(r"(?=.*[A-Z])[A-Z0-9 \-_/]{2,40}")
    s = "\n".join(
        line for line in s.splitlines()
        if not caps_line.fullmatch(line.strip())
    )
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()[:limit]

def choose_invoice_chunk(text: str, window_chars=2200) -> str:
    anchors = [
        r"invoice\s*(number|#|no\.?)", r"\binvoice\b", r"\bremit\s*to\b",
        r"\bbill\s*to\b", r"\bdue\s*date\b", r"\btotal\b", r"\bamount\b"
    ]
    low = text.lower()
    for pat in anchors:
        m = re.search(pat, low, flags=re.IGNORECASE)
        if m:
            i = m.start()
            start = max(0, i - window_chars // 2)
            end = min(len(text), start + window_chars)
            return text[start:end]
    return text[:window_chars]

def extract_sow_mainline(text: str) -> str:
    """Return the SOW's main "Statement of Work ... between X and Y." sentence.

    Whitespace (including newlines) is first collapsed to single spaces so a
    sentence wrapped across PDF lines can still match. Patterns are tried from
    most to least specific (case-insensitively). The sentence ends at the first
    period followed by whitespace plus an uppercase letter or digit, or at the
    end of the text. Abbreviations such as "Inc." or "Corp." are normally
    followed by lowercase words ("and"), so they no longer cut the sentence off
    before the second party is named.

    Args:
        text: Document text (typically the output of ``clean_pdf_text``).

    Returns:
        The matched sentence, or the first 300 characters of the
        whitespace-normalised text when no pattern matches.
    """
    txt = " ".join(text.split())
    # Sentence-final period. '(?-i:...)' keeps the uppercase/digit check
    # case-sensitive even though the search itself uses IGNORECASE.
    end = r"\.(?=\s+(?-i:[A-Z0-9])|\s*$)"
    # A former third pattern "(This.*?Statement of Work.*?between...)" was dropped:
    # whenever it matched, the second pattern below already matched first.
    patterns = [
        r"(Statement of Work.*?entered into.*?between.*?" + end + ")",
        r"(Statement of Work.*?between.*?" + end + ")",
    ]
    for pat in patterns:
        m = re.search(pat, txt, flags=re.IGNORECASE)
        if m:
            return m.group(1).strip()
    return txt[:300]


In [None]:
#Prompt builders (Mistral [INST]…[/INST] format)
def wrap_inst(s: str) -> str:
    """Wrap *s* (stripped of surrounding whitespace) in Mistral ``[INST]`` tags.

    The result starts with the literal ``<s>`` BOS marker.
    """
    body = s.strip()
    return "<s>[INST] " + body + " [/INST]"

def build_invoice_prompt(doc_text: str) -> str:
    """Build the Mistral-wrapped JSON-extraction prompt for an invoice excerpt.

    Args:
        doc_text: Invoice text (in this notebook, the window returned by
            ``choose_invoice_chunk``); inserted verbatim after "DOCUMENT:".

    Returns:
        Instruction + document + output template, wrapped by ``wrap_inst``.
    """
    # NOTE(review): wording is kept byte-for-byte in case an adapter was fine-tuned
    # on it. Two things worth confirming:
    #   * the rules ask for JSON "between <json> and </json>" yet also say the output
    #     MUST start with '{' — contradictory (extract_first_json accepts either);
    #   * the trailing "<json>\n{}\n</json>" template could be echoed verbatim, which
    #     parses as an empty object rather than as a failure.
    instr = (
        "You are a strict JSON extractor for INVOICE documents.\n"
        "Return ONLY one minified JSON object between <json> and </json>. No markdown, no explanations.\n"
        "Fields: title, invoice_id, bill_to, vendor, date, amount, invoice_title, "
        "supplier_information, period, terms, tax, insurance, due_date, payment_method, "
        "additional_text, line_items.\n"
        "Rules:\n"
        "- If a field is not present, use \"Missing\".\n"
        "- If invoice_title is missing, set it to \"Invoice\".\n"
        "- Use \"Remit To\" as vendor if available; if vendor name missing, use business_id or client_id.\n"
        "- Split vendor and bill_to into name, address, phone, email if possible.\n"
        "- Use additional_text for anything not covered.\n"
        "- Output MUST start with '{' and end with '}'."
    )
    # '{{}}' in the f-string renders as a literal '{}' placeholder.
    return wrap_inst(
        f"{instr}\n\nDOCUMENT:\n{doc_text}\n\nPlease output:\n<json>\n{{}}\n</json>"
    )

def build_sow_prompt(mainline_text: str) -> str:
    """Build the Mistral-wrapped prompt that extracts the two SOW parties.

    Args:
        mainline_text: The SOW's main sentence (in this notebook, the output of
            ``extract_sow_mainline``); inserted verbatim after "SENTENCE:".

    Returns:
        Instruction + sentence + output template, wrapped by ``wrap_inst``. The
        instruction asks for exactly ``contracting_company`` and ``vendor`` keys,
        with vendor set to "vendor as legal contract is missing" when absent.
    """
    # NOTE(review): wording is kept byte-for-byte in case an adapter was fine-tuned
    # on it; the trailing "<json>\n{}\n</json>" template could be echoed verbatim,
    # which parses as an empty object rather than as a failure.
    instr = (
        "Extract contracting company and vendor from the SOW main sentence.\n"
        "Return ONLY a minified JSON object between <json> and </json> with exactly: "
        "{\"contracting_company\":..., \"vendor\":...}.\n"
        "The vendor MUST appear in the main line containing the word 'between'.\n"
        "If not present there, set vendor to \"vendor as legal contract is missing\"."
    )
    # '{{}}' in the f-string renders as a literal '{}' placeholder.
    return wrap_inst(
        f"{instr}\n\nSENTENCE:\n{mainline_text}\n\nPlease output:\n<json>\n{{}}\n</json>"
    )


In [None]:
#Generation (token-sliced completion + JSON fallback)
def _truncated_json_candidates(blob: str) -> list:
    """Return repair candidates for a JSON object that was cut off mid-stream.

    *blob* is expected to start at an opening ``{``. The scan tracks string and
    escape state and which ``}``/``]`` closers are still owed. If the outermost
    object closes normally, nothing was truncated and ``[]`` is returned.
    Otherwise the candidates are, in order:
      1. the whole prefix with any open string closed, trailing whitespace and
         commas dropped, and all owed closers appended;
      2. the prefix cut back to each comma outside strings (latest first) with
         the closers owed at that point appended. This drops a partial key or
         value that candidate 1 cannot salvage.
    """
    owed = []        # closers for currently open containers, innermost last
    cut_points = []  # (index of a ',' outside strings, closers owed at that point)
    in_string = False
    escaped = False
    for i, ch in enumerate(blob):
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            owed.append("}")
        elif ch == "[":
            owed.append("]")
        elif ch in "}]":
            if owed:
                owed.pop()
            if not owed:
                return []  # outermost object closed: not a truncation case
        elif ch == ",":
            cut_points.append((i, owed[:]))
    if not owed:
        return []
    head = blob
    if in_string:
        if escaped:
            head = head[:-1]  # drop a dangling backslash so the closing quote isn't escaped
        head += '"'
    candidates = [head.rstrip().rstrip(",") + "".join(reversed(owed))]
    for idx, owed_then in reversed(cut_points):
        candidates.append(blob[:idx] + "".join(reversed(owed_then)))
    return candidates


def extract_first_json(s: str):
    """Best-effort extraction of the first JSON object from model output *s*.

    Strategies, in order (the first successful parse wins):
      1. a ``{...}`` fenced by ``<json>`` ... ``</json>``;
      2. the span from the first ``{`` to the last ``}``;
      3. repair of an object cut off mid-output (e.g. by ``max_new_tokens``):
         close an open string and any open arrays/objects, or cut back to the
         last complete member (see ``_truncated_json_candidates``);
      4. the first ``{...}`` with at most one level of nesting (regex).

    Strategy 3 runs before 4 so that a truncated document is not replaced by one
    of its complete inner objects (e.g. a single line item).

    Returns:
        The parsed object, or ``None`` if no strategy yields valid JSON
        (including when *s* is empty).
    """
    if not s:
        return None

    def _loads(text):
        try:
            return json.loads(text)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            return None

    fenced = re.search(r"<json>\s*(\{.*?\})\s*</json>", s, flags=re.S | re.I)
    if fenced:
        parsed = _loads(fenced.group(1))
        if parsed is not None:
            return parsed

    start = s.find("{")
    if start == -1:
        return None

    end = s.rfind("}")
    if end > start:
        parsed = _loads(s[start:end + 1])
        if parsed is not None:
            return parsed

    for candidate in _truncated_json_candidates(s[start:]):
        parsed = _loads(candidate)
        if parsed is not None:
            return parsed

    nested = re.search(r"\{(?:[^{}]|\{[^{}]*\})*\}", s, flags=re.S)
    if nested:
        parsed = _loads(nested.group(0))
        if parsed is not None:
            return parsed
    return None

@torch.no_grad()
def generate_json(model, tokenizer, prompt_wrapped: str, max_new_tokens=768, max_input_tokens: int = 4096):
    """Greedy-decode a completion for *prompt_wrapped* and parse JSON out of it.

    Args:
        model: Causal LM exposing ``.generate`` and ``.device``.
        tokenizer: The matching tokenizer.
        prompt_wrapped: Full prompt text, e.g. produced by ``wrap_inst``.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_input_tokens: Prompt length cap. Longer prompts are cut from the
            right (as before), which loses the closing ``[/INST]``, so a warning
            is emitted when that happens.

    Returns:
        ``(parsed, completion)``. ``completion`` is the decoded text generated
        after the prompt. ``parsed`` is the result of ``extract_first_json``, or
        ``{"error": "Failed to parse JSON"}`` if nothing could be parsed.
    """
    import warnings

    # wrap_inst-style prompts already spell out the BOS token ("<s>"); letting the
    # tokenizer prepend its own as well would feed the model a duplicated BOS.
    bos = tokenizer.bos_token
    prompt_has_bos = bool(bos) and prompt_wrapped.startswith(bos)
    enc = tokenizer(prompt_wrapped, return_tensors="pt", add_special_tokens=not prompt_has_bos)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask
    if input_ids.shape[-1] > max_input_tokens:
        warnings.warn(
            f"Prompt is {input_ids.shape[-1]} tokens; truncating to {max_input_tokens}. "
            "The end of the prompt (including [/INST]) is lost - shorten the document chunk."
        )
        input_ids = input_ids[:, :max_input_tokens]
        attention_mask = attention_mask[:, :max_input_tokens]
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)
    input_len = input_ids.shape[-1]

    out = model.generate(
        input_ids=input_ids,
        # Pass the mask explicitly: pad == eos here, so it cannot be inferred reliably.
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        # Greedy decoding is deterministic; sampling knobs (temperature/top_p/top_k)
        # are ignored in this mode and only produced warnings, so they are omitted.
        do_sample=False,
        repetition_penalty=1.02,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Keep only the newly generated tokens (everything after the prompt).
    tail_ids = out[0][input_len:]
    completion = tokenizer.decode(tail_ids, skip_special_tokens=True).strip()

    js = extract_first_json(completion)
    if js is not None:
        return js, completion
    return {"error": "Failed to parse JSON"}, completion


In [None]:
#Run inference (choose Invoice or SOW)
doc_text = clean_pdf_text(doc_text_raw)

# ----- pick one -----
DOC_TYPE = "Invoice"   # or "SOW"

# Normalise so "invoice", "SOW " etc. work, and fail loudly on anything else
# instead of silently falling through to the SOW path.
doc_kind = DOC_TYPE.strip().lower()
if doc_kind == "invoice":
    # Invoices: send a window of text around the first invoice anchor.
    doc_chunk = choose_invoice_chunk(doc_text)
    prompt = build_invoice_prompt(doc_chunk)
elif doc_kind == "sow":
    # SOWs: send only the main "... between X and Y." sentence.
    sow_mainline = extract_sow_mainline(doc_text)
    prompt = build_sow_prompt(sow_mainline)
else:
    raise ValueError(f'DOC_TYPE must be "Invoice" or "SOW", got {DOC_TYPE!r}')

result, raw = generate_json(model, tokenizer, prompt, max_new_tokens=MAX_NEW_TOKENS)
print("=== Parsed JSON ===")
print(json.dumps(result, indent=2))
print("\n=== Raw Model Completion (tail) ===")
print(raw[:2000])
