# In [None]:  (notebook cell marker left over from the export)

# ---------------------------------------------------------------
# 1 · Imports & configuration
# ---------------------------------------------------------------
import os, re, logging, datetime as dt, getpass
from pathlib import Path
from collections import defaultdict

import psycopg2
from psycopg2.extras import execute_values, DictCursor
from lxml import etree as ET
from tqdm.notebook import tqdm
from dotenv import load_dotenv

# --- logging set‑up -------------------------------------------------------
# Root logger at INFO; timestamps show the time of day only (HH:MM:SS).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)

# --- locate filings -------------------------------------------------------
# Directory (relative to the working directory) that run_loader() scans for
# *.txt filings.
FILINGS_DIR = Path("npx_filings")   # change if necessary

# --- SEC date / decimal helpers ------------------------------------------
# strptime formats tried in this order by parse_date(); the first match wins.
DATE_FMTS    = ("%Y-%m-%d", "%m/%d/%Y", "%m-%d-%Y", "%Y%m%d")
# Matches every character that is NOT a digit, a dot or a minus sign;
# clean_decimal() deletes those characters before calling float().
DEC_CLEAN_RE = re.compile(r"[^\d\.\-]")
# Captures the accession number (rest of its line) and the 8-digit
# "FILED AS OF DATE" value; case-insensitive, and '.' also matches newlines
# so the two labels may be separated by other header lines.
SEC_HEADER_RE = re.compile(
    r"ACCESSION\s+NUMBER:\s*(?P<acc>[^\r\n]+).*?"
    r"FILED\s+AS\s+OF\s+DATE:\s*(?P<filed>\d{8})",
    re.I | re.S,
)

# Single parser instance shared by every fragment; recover=True asks lxml to
# keep parsing through malformed markup instead of failing on the first error.
XML_PARSER = ET.XMLParser(recover=True, encoding="utf-8")

# --- load credentials -----------------------------------------------------
load_dotenv()  # pulls from .env if present

# If they were not in the env, ask interactively (good for local dev).
# PGPASSWORD is read through getpass so it is not echoed; the other values use
# plain input().  The answers are stored in os.environ; pg_conn() passes no
# connection parameters, so presumably libpq picks these variables up.
for var in ("PGHOST", "PGPORT", "PGUSER", "PGPASSWORD", "PGDATABASE"):
    if not os.getenv(var):
        os.environ[var] = getpass.getpass(f"{var}: ") if var == "PGPASSWORD" else input(f"{var}: ")

# ---------------------------------------------------------------
# 2 · Helper functions
# ---------------------------------------------------------------

def parse_date(s: str | None):
    """Convert a date string to a ``datetime.date``.

    The stripped text is matched against each format in ``DATE_FMTS`` in
    order.  Returns ``None`` for falsy input, and also (after logging a
    warning) when no format matches.
    """
    if s:
        text = s.strip()
        for pattern in DATE_FMTS:
            try:
                parsed = dt.datetime.strptime(text, pattern)
            except ValueError:
                # Wrong format – try the next one.
                pass
            else:
                return parsed.date()
        logging.warning("Un‑parsable date: %s", text)
    return None


def clean_decimal(s: str | None):
    """Convert a loosely formatted number string to ``float``.

    Every character matched by ``DEC_CLEAN_RE`` is removed first.  Returns
    ``None`` when the input is falsy, when nothing is left after cleaning, or
    when the remaining text is not a valid float literal.
    """
    if not s:
        return None
    digits = DEC_CLEAN_RE.sub("", s)
    if not digits:
        return None
    try:
        number = float(digits)
    except ValueError:
        return None
    return number


def get(node, xp, sl=None):
    """Return the stripped text of the first XPath match, or "" if none.

    node: object exposing ``xpath(expr)``.
    xp:   XPath expression evaluated against *node*.
    sl:   optional maximum length; a falsy value means "no truncation".

    A string result is used directly; any other result contributes its
    ``.text`` attribute (``None`` is treated as the empty string).
    """
    matches = node.xpath(xp)
    if not matches:
        return ""

    first = matches[0]
    if isinstance(first, str):
        text = first
    else:
        text = first.text or ""

    text = text.strip()
    if sl:
        return text[:sl]
    return text

# ---------------------------------------------------------------
# 3 · SEC header & XML extraction
# ---------------------------------------------------------------

def extract_sec_header(path: Path):
    """Read the accession number and filing date from a filing's SEC header.

    Only the first 4000 characters of the file are inspected, so only that
    much is read (filings can be very large).  The file is decoded as UTF-8
    with undecodable bytes replaced.

    Returns a dict with:
      * ``accession_number`` – stripped header value, or "" if the header
        pattern (``SEC_HEADER_RE``) is not found;
      * ``date_filed`` – result of ``parse_date`` on the 8-digit filing date,
        or ``None`` if the header pattern is not found.
    """
    with path.open("r", encoding="utf-8", errors="replace") as fh:
        head = fh.read(4000)

    m = SEC_HEADER_RE.search(head)
    if m is None:
        return {"accession_number": "", "date_filed": None}
    return {
        "accession_number": m.group("acc").strip(),
        "date_filed": parse_date(m.group("filed")),
    }


def xml_blocks(path: Path):
    """Return the inner text of every ``<XML>…</XML>`` section in *path*.

    The file is decoded as UTF-8 (undecodable bytes replaced).  Tag matching
    is case-insensitive and sections may span multiple lines; the result is a
    list of strings in document order.
    """
    content = path.read_text(encoding="utf-8", errors="replace")
    section = re.compile(r"<XML>(.*?)</XML>", re.IGNORECASE | re.DOTALL)
    return [hit.group(1) for hit in section.finditer(content)]

# ---------------------------------------------------------------
# 4 · Parsing functions  (identical business logic, patched)
# ---------------------------------------------------------------

def parse_edgar_submission(root, hdr):
    """Build the ``form_npx`` row for one ``edgarSubmission`` element.

    root: element whose descendants are searched (namespace-agnostic, via
          ``local-name()``).
    hdr:  mapping with ``accession_number`` (str) and ``date_filed`` as
          produced by ``extract_sec_header``.

    Returns a dict keyed by column name.  Text values are truncated to the
    limits used below; missing elements yield "" (or the documented default).
    A non-numeric ``otherIncludedManagersCount`` is logged and stored as 0
    instead of raising, so one malformed field does not abort the filing.
    """
    truthy = {"Y", "YES", "TRUE", "1"}

    def text(*tags, limit=None):
        # ".//a/b/c" expressed with local-name() steps so namespaces are ignored.
        steps = "/".join(f"*[local-name()='{tag}']" for tag in tags)
        return get(root, ".//" + steps, limit)

    def flag(*tags):
        return text(*tags).upper() in truthy

    # reportType feeds two columns; look it up once.
    raw_report_type = text("reportInfo", "reportType")

    raw_mgr_count = text("summaryPage", "otherIncludedManagersCount")
    try:
        other_mgr_count = int(raw_mgr_count) if raw_mgr_count else 0
    except ValueError:
        logging.warning(
            "%s - non-numeric otherIncludedManagersCount %r treated as 0",
            hdr["accession_number"], raw_mgr_count,
        )
        other_mgr_count = 0

    raw_amendment_no = text("amendmentInfo", "amendmentNo")
    amendment_no = int(raw_amendment_no) if raw_amendment_no and raw_amendment_no.isdigit() else None

    data = {
        # --- filer info ---
        "reporting_person_name": text("reportingPerson", "name", limit=250),
        "phone_number": text("reportingPerson", "phoneNumber", limit=50),
        "address_street1": text("reportingPerson", "address", "street1", limit=250),
        "address_street2": text("reportingPerson", "address", "street2", limit=250),
        "address_city": text("reportingPerson", "address", "city", limit=100),
        "address_state": text("reportingPerson", "address", "stateOrCountry", limit=100),
        "address_zip": text("reportingPerson", "address", "zipCode", limit=30),
        # --- filing info ---
        "accession_number": hdr["accession_number"][:30],
        "cik": text("issuerCredentials", "cik", limit=15),
        "conformed_period": parse_date(text("periodOfReport")),
        "date_filed": hdr["date_filed"],
        "report_type": (raw_report_type or "FUND VOTING REPORT")[:100],
        "form_type": (text("submissionType") or "N-PX")[:10],
        "sec_file_number": text("fileNumber", limit=20),
        "crd_number": text("reportingCrdNumber", limit=20),
        "sec_file_number_other": text("reportingSecFileNumber", limit=20),
        "lei_number": text("leiNumber", limit=40),
        "investment_company_type": text("investmentCompanyType", limit=20),
        "confidential_treatment": "Y" if flag("reportInfo", "confidentialTreatment") else "N",
        "is_notice_report": "NOTICE" in raw_report_type.upper(),
        "explanatory_choice": "Y" if flag("explanatoryInformation", "explanatoryChoice") else "N",
        "other_included_managers_count": other_mgr_count,
        "series_count": 0,  # patched by the loader once the series rows are inserted
        # --- amendment ---
        "is_amendment": flag("amendmentInfo", "isAmendment"),
        "amendment_no": amendment_no,
        "amendment_type": text("amendmentInfo", "amendmentType", limit=20),
        "notice_explanation": text("reportInfo", "noticeExplanation", limit=200),
        # --- signature ---
        "signatory_name": text("signaturePage", "txSignature", limit=250),
        "signatory_name_printed": text("signaturePage", "txPrintedSignature", limit=250),
        "signatory_title": text("signaturePage", "txTitle", limit=100),
        "signatory_date": parse_date(text("signaturePage", "txAsOfDate")),
    }
    return data

# --- institutional managers ------------------------------------------------

def parse_institutional_managers(root):
    """Extract the "other manager" entries below *root*.

    Manager nodes are ``investmentManagers`` elements under ``otherManagers2``;
    if there are none, every ``otherManager`` element is used instead.

    Returns a list of dicts with keys ``serial_no`` (int or None), ``name``,
    ``form13f_number``, ``crd_number``, ``sec_file_number`` and
    ``lei_number`` (strings, "" when absent).  The serial number is stripped
    before the digit check so that whitespace-padded values such as
    "\\n 3 \\n" are still recognised.
    """

    def first_text(node, tag):
        # Stripped first text node of any descendant <tag>, or None when the
        # tag has no text node at all (distinct from an empty string).
        hits = node.xpath(f".//*[local-name()='{tag}']/text()")
        return hits[0].strip() if hits else None

    mgr_nodes = root.xpath(
        ".//*[local-name()='otherManagers2']//*[local-name()='investmentManagers']"
    )
    if not mgr_nodes:
        mgr_nodes = root.xpath(".//*[local-name()='otherManager']")

    out = []
    for mn in mgr_nodes:
        serial = first_text(mn, "serialNo")

        # 13F / ICA number: the combined tag wins; the legacy tag is only
        # consulted when the combined one is absent (length limits differ).
        form13f = first_text(mn, "icaOr13FFileNumber")
        if form13f is not None:
            form13f = form13f[:17]
        else:
            form13f = (first_text(mn, "form13FFileNumber") or "")[:20]

        # otherFileNumber takes precedence over secFileNumber in the same way.
        sec_file = first_text(mn, "otherFileNumber")
        if sec_file is not None:
            sec_file = sec_file[:17]
        else:
            sec_file = (first_text(mn, "secFileNumber") or "")[:20]

        lei = first_text(mn, "leiNumberOM")
        if lei is None:
            lei = first_text(mn, "leiNumber")

        out.append(
            {
                "serial_no": int(serial) if serial and serial.isdigit() else None,
                "name": (
                    get(mn, ".//*[local-name()='managerName']", 150)
                    or get(mn, ".//*[local-name()='name']", 250)
                ),
                "form13f_number": form13f,
                "crd_number": (first_text(mn, "crdNumber") or "")[:20],
                "sec_file_number": sec_file,
                "lei_number": (lei or "")[:40],
            }
        )
    return out

# --- series ---------------------------------------------------------------

def parse_series_info(root):
    """Return one dict per ``seriesReports`` element found below *root*.

    Each dict carries ``series_code``, ``series_name`` and ``series_lei``,
    read with ``get`` and truncated to 25, 250 and 40 characters.
    """
    # (column, XPath relative to the series node, max length)
    columns = (
        ("series_code", ".//*[local-name()='idOfSeries']", 25),
        ("series_name", ".//*[local-name()='nameOfSeries']", 250),
        ("series_lei", ".//*[local-name()='leiOfSeries']", 40),
    )
    series_nodes = root.xpath(".//*[local-name()='seriesReports']")
    return [
        {column: get(node, xp, limit) for column, xp, limit in columns}
        for node in series_nodes
    ]

# --- proxy vote tables ----------------------------------------------------

def parse_proxy_tables(pvt_node):
    """Generate one ``(vote_row, cats, mgr_serials, series_code)`` tuple per
    ``proxyTable`` element below *pvt_node*.

    vote_row:    dict of ``proxy_voting_record`` columns (without form_id).
    cats:        list of category strings (stripped, max 100 characters).
    mgr_serials: list of ints from ``voteManager//otherManager``; entries
                 that are not integers are dropped.
    series_code: text of ``voteSeries`` (max 25 characters, "" if absent).

    Only the first ``voteRecord`` supplies the cast/recommendation columns;
    when there are several, their count is recorded in ``other_notes``.
    """
    tables = pvt_node.xpath(".//*[local-name()='proxyTable']")
    for table in tables:
        record = dict(
            issuer_name=get(table, ".//*[local-name()='issuerName']", 250),
            cusip=get(table, ".//*[local-name()='cusip']", 30),
            isin=get(table, ".//*[local-name()='isin']", 30),
            figi=get(table, ".//*[local-name()='figi']", 30),
            meeting_date=parse_date(get(table, ".//*[local-name()='meetingDate']")),
            vote_description=get(table, ".//*[local-name()='voteDescription']"),
            proposed_by=get(table, ".//*[local-name()='voteSource']", 20),
            shares_voted=clean_decimal(get(table, ".//*[local-name()='sharesVoted'][1]")),
            shares_on_loan=clean_decimal(get(table, ".//*[local-name()='sharesOnLoan'][1]")),
            vote_cast=None,
            vote_cast_shares=None,
            management_rec=None,
            other_notes=None,
        )

        vote_records = table.xpath(".//*[local-name()='voteRecord']")
        if vote_records:
            primary = vote_records[0]
            record["vote_cast"] = get(primary, ".//*[local-name()='howVoted']", 50)
            record["vote_cast_shares"] = clean_decimal(
                get(primary, ".//*[local-name()='sharesVoted']")
            )
            record["management_rec"] = get(
                primary, ".//*[local-name()='managementRecommendation']", 50
            )
            if len(vote_records) > 1:
                record["other_notes"] = f"{len(vote_records)} voteRecord tags found"

        categories = []
        for raw_cat in table.xpath(
            ".//*[local-name()='voteCategories']//*[local-name()='categoryType']/text()"
        ):
            categories.append(raw_cat.strip()[:100])

        serials = []
        for raw_serial in table.xpath(
            ".//*[local-name()='voteManager']//*[local-name()='otherManager']/text()"
        ):
            try:
                serial = int(raw_serial.strip())
            except ValueError:
                # Non-integer manager references are ignored.
                continue
            serials.append(serial)

        yield record, categories, serials, get(table, ".//*[local-name()='voteSeries']", 25)

# ---------------------------------------------------------------
# 5 · Database helpers
# ---------------------------------------------------------------

def pg_conn():
    """Open a new psycopg2 connection whose cursors default to ``DictCursor``.

    No connection parameters are passed explicitly.
    """
    connection = psycopg2.connect(cursor_factory=DictCursor)
    return connection


def upsert_category(cur, cache, cat_type):
    """Return the ``category_id`` for *cat_type*, inserting it if necessary.

    cur:      database cursor used for the upsert.
    cache:    dict mapping category text to id; consulted first and updated
              after every database round-trip.
    cat_type: category text.

    The ON CONFLICT ... DO UPDATE clause rewrites the row with the same value
    so that RETURNING yields the id for pre-existing rows as well.
    """
    if cat_type not in cache:
        cur.execute(
            """
        INSERT INTO matter_category (category_type)
        VALUES (%s)
        ON CONFLICT (category_type) DO UPDATE SET category_type = EXCLUDED.category_type
        RETURNING category_id
        """,
            (cat_type,),
        )
        cache[cat_type] = cur.fetchone()[0]
    return cache[cat_type]

# ---------------------------------------------------------------
# 6 · File loader  (patched version)
# ---------------------------------------------------------------

def load_filing(path: Path, conn, cat_cache):
    """Parse one filing and store it in a transaction of its own.

    path:      filing text file.
    conn:      open psycopg2 connection; it is committed when the whole file
               loads and rolled back when anything fails.
    cat_cache: dict mapping category text to ``category_id``, shared across
               files.  Entries added while loading a file that is rolled back
               are removed again, because their rows no longer exist.

    Files without any ``<XML>`` section are skipped (nothing is written).
    Any exception is re-raised after the rollback so the caller can log it.
    """
    hdr = extract_sec_header(path)
    frags = xml_blocks(path)
    if not frags:
        logging.info("%s – legacy or non‑XML filing skipped", path.name)
        return

    cached_before = set(cat_cache)
    try:
        with conn.cursor() as cur:             # one transaction per file
            _load_fragments(cur, path.name, hdr, frags, cat_cache)
        conn.commit()
    except Exception:
        conn.rollback()
        # Categories upserted inside the failed transaction are gone from the
        # database; forget their ids so later files re-create them.
        for cat in set(cat_cache) - cached_before:
            del cat_cache[cat]
        raise

    logging.info("%s – committed", path.name)


# Leading text that marks an <XML> section as worth parsing.
_XML_FRAGMENT_PREFIXES = ("<edga", "<?xml", "<DOC", "<EDGA")


def _parse_fragment(frag, name):
    """Return the parsed root of one <XML> section, or None if it is skipped."""
    if not frag.strip():
        return None
    # startswith() so that prefixes of any length (e.g. the 4-character
    # "<DOC") are honoured.
    if not frag.lstrip().startswith(_XML_FRAGMENT_PREFIXES):
        logging.info("%s – non‑XML <XML> fragment skipped", name)
        return None
    try:
        root = ET.fromstring(frag.encode(), parser=XML_PARSER)
    except ET.XMLSyntaxError as e:
        logging.warning("%s – bad XML fragment skipped: %s", name, e)
        return None
    if root is None:
        # A recovering parser can come back without a root element.
        logging.warning("%s – bad XML fragment skipped: %s", name, "no root element")
    return root


def _insert_row(cur, table, row, id_col):
    """Insert *row* (column name -> value) into *table*; return its *id_col*."""
    cols = tuple(row)
    cur.execute(
        f"INSERT INTO {table} ({', '.join(cols)}) "
        f"VALUES ({', '.join(['%s'] * len(cols))}) RETURNING {id_col}",
        tuple(row[c] for c in cols),
    )
    return cur.fetchone()[id_col]


def _store_submission(cur, es, hdr, mgr_map, ser_map):
    """Insert one edgarSubmission with its managers and series; return form_id."""
    form_id = _insert_row(cur, "form_npx", parse_edgar_submission(es, hdr), "form_id")

    # --- managers ---------------------------------------------------------
    mgrs = parse_institutional_managers(es)
    if mgrs:
        rows = [(
            form_id, m["serial_no"], m["name"], m["form13f_number"],
            m["crd_number"], m["sec_file_number"], m["lei_number"]
        ) for m in mgrs]
        mgr_ret = execute_values(cur, """
            INSERT INTO institutional_manager
              (form_id, serial_no, name, form13f_number,
               crd_number, sec_file_number, lei_number)
            VALUES %s RETURNING manager_id, serial_no""", rows, fetch=True)
        for mid, sn in mgr_ret:
            mgr_map[form_id][sn] = mid

    # --- series -----------------------------------------------------------
    series = parse_series_info(es)
    if series:
        rows = [(
            form_id, s["series_code"], s["series_name"], s["series_lei"]
        ) for s in series]
        ser_ret = execute_values(cur, """
            INSERT INTO series (form_id, series_code, series_name, series_lei)
            VALUES %s RETURNING series_id, series_code""", rows, fetch=True)
        for sid, scode in ser_ret:
            ser_map[form_id][scode] = sid
        cur.execute(
            "UPDATE form_npx SET series_count = %s WHERE form_id = %s",
            (len(ser_ret), form_id),
        )
    return form_id


def _store_votes(cur, pvt, form_id, mgr_ids, ser_ids, cat_cache):
    """Insert every vote of one proxyVoteTable plus its bridge rows."""
    for vote_row, cats, mgr_serials, ser_code in parse_proxy_tables(pvt):
        vote_row["form_id"] = form_id
        vote_id = _insert_row(cur, "proxy_voting_record", vote_row, "vote_id")

        for c in cats:
            cid = upsert_category(cur, cat_cache, c)
            cur.execute("INSERT INTO proxy_voting_record_category (vote_id, category_id) VALUES (%s, %s) ON CONFLICT DO NOTHING", (vote_id, cid))

        # Serials that do not match a manager of this form are ignored.
        for sn in mgr_serials:
            mid = mgr_ids.get(sn)
            if mid is not None:
                cur.execute("INSERT INTO voting_record_manager (vote_id, manager_id) VALUES (%s, %s) ON CONFLICT DO NOTHING", (vote_id, mid))

        if ser_code:
            sid = ser_ids.get(ser_code)
            if sid is not None:
                cur.execute("INSERT INTO voting_record_series (vote_id, series_id) VALUES (%s, %s) ON CONFLICT DO NOTHING", (vote_id, sid))


def _load_fragments(cur, name, hdr, frags, cat_cache):
    """Insert all forms and votes found in the <XML> sections of one file."""
    current_form_id = None                 # last successful form in this file
    mgr_map, ser_map = defaultdict(dict), defaultdict(dict)

    for frag in frags:
        root = _parse_fragment(frag, name)
        if root is None:
            continue

        # === 1 · edgarSubmission ==========================================
        # The list is kept alive until the fragment is finished so that the
        # element objects (and therefore their id()) stay the same when the
        # parent walk below reaches them again.
        submissions = root.xpath("//*[local-name()='edgarSubmission']")
        form_ids = {}
        for es in submissions:
            form_id = _store_submission(cur, es, hdr, mgr_map, ser_map)
            form_ids[id(es)] = form_id
            current_form_id = form_id      # remember for orphan blocks

        # === 2 · proxyVoteTable ===========================================
        for pvt in root.xpath("//*[local-name()='proxyVoteTable']"):
            owner = pvt
            while owner is not None and owner.tag.split('}')[-1] != "edgarSubmission":
                owner = owner.getparent()

            form_id = form_ids.get(id(owner)) if owner is not None else None
            if form_id is None:
                if current_form_id is None:    # nothing to attach to
                    logging.warning("%s – orphan proxyVoteTable skipped (no prior form)", name)
                    continue
                # orphan – attach to the last form stored for this file
                logging.info("%s – orphan proxyVoteTable attached to form_id %s", name, current_form_id)
                form_id = current_form_id

            _store_votes(cur, pvt, form_id, mgr_map[form_id], ser_map[form_id], cat_cache)

# ---------------------------------------------------------------
# 7 · Run over all filings
# ---------------------------------------------------------------

def run_loader():
    """Load every ``*.txt`` filing in ``FILINGS_DIR`` into the database.

    Files are processed in sorted order over a single connection.  Each file
    is committed on success; on failure the error is logged with its
    traceback, the transaction is rolled back and processing continues with
    the next file.  The connection is always closed at the end.
    """
    files = sorted(FILINGS_DIR.glob("*.txt"))
    if not files:
        print(f"No .txt filings found in {FILINGS_DIR}")
        return

    conn = pg_conn()
    try:
        cat_cache = {}
        for f in tqdm(files, desc="Filings"):
            try:
                load_filing(f, conn, cat_cache)
                conn.commit()
            except Exception:
                logging.exception("‼️  %s failed – rolled back", f.name)
                conn.rollback()
                # Ids cached during the failed transaction may refer to rows
                # that were just rolled back; drop the cache so they are
                # looked up again.
                cat_cache.clear()
    finally:
        # psycopg2's connection context manager only ends the transaction,
        # so the connection is closed explicitly.
        conn.close()

# Start the load as soon as this cell / module is executed.
run_loader()


