In [2]:
import datetime
import duckdb
import enum
import json
import ollama
import os
import pandas as pd
import pydantic
import requests
import time
import tqdm
from typing import Dict, List

In [3]:
# Relative path of the database file opened with duckdb.connect() in this notebook.
DATABASE = "../database.db"

# SQL template that removes a table if present; `{table_name}` is filled in
# with str.format() by create_table().
# NOTE: unlike CREATE_FROM_CSV_QUERY this template is not .strip()'d, so the
# rendered statement keeps its leading/trailing newline.
DROP_IF_EXISTS_QUERY = """
DROP TABLE IF EXISTS {table_name}
"""

# SQL template that creates a table from a CSV file via read_csv_auto;
# `{table_name}` and `{filename}` are filled in with str.format() by
# create_table(). `{filename}` lands inside a single-quoted SQL literal.
CREATE_FROM_CSV_QUERY = """
CREATE TABLE {table_name} AS
SELECT *
FROM read_csv_auto('{filename}')
""".strip()

# Maps a table config's "format" value to the CREATE template used for it.
# The keys double as the list of supported formats in create_table()'s
# validation error message.
CREATE_TEMPLATE_BY_FORMAT = {
    "csv": CREATE_FROM_CSV_QUERY,
}


def create_table(con, config: Dict):
    """Drop and recreate a single table from the file described by `config`.

    All validation happens before any statement is executed, so an invalid
    config never drops an existing table.

    Args:
        con: Object exposing `execute(sql)`; in this notebook, a DuckDB
            connection.
        config: Mapping with non-empty "name" (table name), "path" (source
            file) and "format" (a key of CREATE_TEMPLATE_BY_FORMAT).

    Raises:
        ValueError: If "name", "path" or "format" is missing/empty, or if
            "format" is not a supported format.
    """
    table_name = config.get("name")
    if not table_name:
        raise ValueError("Must specify table name.")

    table_path = config.get("path")
    if not table_path:
        raise ValueError("Must specify table path.")

    table_format = config.get("format")
    if not table_format:
        raise ValueError("Must specify table format.")
    if table_format not in CREATE_TEMPLATE_BY_FORMAT:
        valid_formats = ", ".join(CREATE_TEMPLATE_BY_FORMAT.keys())
        error = f"Invalid format `{table_format}`. Must be one of: {valid_formats}."
        raise ValueError(error)

    # The path is spliced into a single-quoted SQL string literal; double any
    # embedded single quote so a path such as "it's.csv" cannot terminate the
    # literal early and produce broken (or unintended) SQL.
    escaped_path = str(table_path).replace("'", "''")

    # NOTE: table_name is interpolated as-is (it may be schema-qualified), so
    # it must come from a trusted source.
    con.execute(DROP_IF_EXISTS_QUERY.format(table_name=table_name))

    # Membership was verified above, so a plain lookup is safe here.
    template = CREATE_TEMPLATE_BY_FORMAT[table_format]
    con.execute(template.format(table_name=table_name, filename=escaped_path))


def create_tables(database_name: str, table_configs: List[Dict]):
    """(Re)create every configured table in one DuckDB database.

    Args:
        database_name: Value passed to duckdb.connect().
        table_configs: Table config dicts, each forwarded to create_table()
            in order over a single shared connection.
    """
    with duckdb.connect(database_name) as connection:
        for single_config in table_configs:
            create_table(connection, config=single_config)

In [12]:
# Model name passed to ollama.chat() by extract_expected_workload_change().
MODEL = "gemma3:1b"

# Prompt template for the extraction task. `{report_content}` is the only
# placeholder and is filled in with str.format(), so any literal braces added
# to this text later would need to be doubled.
# The value lists below mirror ExpectedWorkloadChangeEnum and
# ReasonCategoryEnum; keep them in sync when editing either side.
EXPECTED_WORKLOAD_CHANGE_PROMPT = """
You are an expert in semantics and NFL football.

Read the following news report about an NFL player and extract structured data that captures what
the report thinks will happen to that player's workload in the next game.

Your output should be JSON with the following fields:
- "target_player"
  - The name of the player the report focuses on
- "expected_workload_change"
  - The expected change in workload (carries, targets) for the target player
  - Value can be: "much_higher", "somewhat_higher", "similar", "somewhat_lower", "much_lower", or "unknown"
- "reason_category"
  - Categorization of the primary reason for the expected workload change
  - Value can be one of the following:
    - "promotion" if the target player is getting higher workload due to good play
    - "demotion" if the target player is getting lower workload due to poor play
    - "teammate_injury" if the target player is getting higher workload due to a teammate's injury
    - "own_injury" if the target player is getting lower workload due to their own injury
    - "strong_opponent" if the opponent is expected to be stronger against the team
    - "weak_opponent" if the opponent is expected to be weaker against the team
    - "rumor" if based on a rumor about players, coaches, or strategy
    - "no_change" if workload is expected to be the same

News Report:
{report_content}

Output:
"""


# Allowed values for the "expected_workload_change" field; matches the list
# given in EXPECTED_WORKLOAD_CHANGE_PROMPT.
# NOTE(review): documented with comments on purpose; a class docstring may be
# copied into the JSON schema sent to the model. Confirm before adding one.
class ExpectedWorkloadChangeEnum(enum.Enum):
    much_higher = "much_higher"
    somewhat_higher = "somewhat_higher"
    similar = "similar"
    somewhat_lower = "somewhat_lower"
    much_lower = "much_lower"
    # Also used by process_report() as the fallback when the field is absent.
    unknown = "unknown"


# Allowed values for the "reason_category" field; matches the list given in
# EXPECTED_WORKLOAD_CHANGE_PROMPT. There is no "unknown" member here, unlike
# ExpectedWorkloadChangeEnum.
# NOTE(review): documented with comments on purpose; a class docstring may be
# copied into the JSON schema sent to the model. Confirm before adding one.
class ReasonCategoryEnum(enum.Enum):
    promotion = "promotion"
    demotion = "demotion"
    teammate_injury = "teammate_injury"
    own_injury = "own_injury"
    strong_opponent = "strong_opponent"
    weak_opponent = "weak_opponent"
    rumor = "rumor"
    no_change = "no_change"


# Structured-output model for one report. Its model_json_schema() is passed
# as `format=` to ollama.chat() in extract_expected_workload_change(); the
# model is not used there to validate the parsed response.
# NOTE(review): documented with comments on purpose; a class docstring may be
# copied into the generated JSON schema. Confirm before adding one.
class ExpectedWorkloadChange(pydantic.BaseModel):
    # Name of the player the report focuses on (per the prompt text).
    target_player: str
    expected_workload_change: ExpectedWorkloadChangeEnum
    reason_category: ReasonCategoryEnum


def strip_code_block(raw: str) -> str:
    """Remove an optional Markdown code fence wrapped around `raw`.

    Handles "```json ... ```" and bare "``` ... ```" fences, with or without
    surrounding whitespace. Text that is not fenced is returned unchanged
    apart from having leading/trailing whitespace removed.

    The fence is removed as an exact prefix/suffix. `str.lstrip("```json")`
    would instead strip any leading run of the characters `` ` ``, j, s, o, n,
    which corrupts unfenced payloads such as `null`.

    Args:
        raw: Text that may be wrapped in a code fence.

    Returns:
        The inner text with the fence and surrounding whitespace removed.
    """
    fence = "```"
    text = raw.strip()

    if text.startswith(fence):
        text = text[len(fence):]
        # Drop the optional language tag that directly follows the fence.
        if text[:4].lower() == "json":
            text = text[4:]

    if text.endswith(fence):
        text = text[:-len(fence)]

    return text.strip()


def extract_expected_workload_change(report_content: str) -> Dict:
    """Ask the model for the expected workload change described by a report.

    The report text is inserted into EXPECTED_WORKLOAD_CHANGE_PROMPT and sent
    as a single user message to ollama.chat(), with the JSON schema of
    ExpectedWorkloadChange supplied as the requested output format.

    Args:
        report_content: Text of the news report.

    Returns:
        The reply parsed with json.loads() after any code fence is removed.
        The result is not validated against ExpectedWorkloadChange here.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
    """
    user_message = {
        "role": "user",
        "content": EXPECTED_WORKLOAD_CHANGE_PROMPT.format(report_content=report_content),
    }
    response = ollama.chat(
        model=MODEL,
        messages=[user_message],
        format=ExpectedWorkloadChange.model_json_schema(),
    )
    reply_text = strip_code_block(response.message.content)
    return json.loads(reply_text)


def process_report(r: Dict) -> Dict:
    """Build the prediction row for a single report record.

    Args:
        r: Record with "report_id" and "description" keys. The description
            may be missing, None, NaN or blank (for example when the source
            query's LEFT JOIN finds no matching report row).

    Returns:
        Dict with "report_id", "expected_workload_change" (falls back to
        "unknown") and "reason_category" (falls back to None).
    """
    report_content = r.get("description")

    # Only query the model when there is real text to analyse. Formatting a
    # None/NaN description into the prompt would send the literal "None"/"nan"
    # and record whatever the model invents as a genuine prediction.
    if isinstance(report_content, str) and report_content.strip():
        expected = extract_expected_workload_change(report_content=report_content)
    else:
        expected = {}

    return {
        "report_id": r.get("report_id"),
        "expected_workload_change": expected.get("expected_workload_change", "unknown"),
        "reason_category": expected.get("reason_category"),
    }


def get_report_predictions(reports: List[Dict]) -> List[Dict]:
    """Run process_report() over every report, showing a tqdm progress bar.

    Args:
        reports: Report records; must support len() for the progress total.

    Returns:
        One prediction dict per report, in input order.
    """
    lazy_predictions = (process_report(report) for report in reports)
    progress = tqdm.tqdm(lazy_predictions, total=len(reports))
    return [prediction for prediction in progress]

In [13]:
# Load up to 60 comparable reports (id + description text) from DuckDB as a
# list of dict records.
comparable_reports = []
with duckdb.connect(DATABASE) as con:
    # LEFT JOIN keeps every comparable_report row, so `description` is NULL
    # for ids that have no matching row in `report`.
    # NOTE(review): LIMIT is applied without ORDER BY, so which 60 rows come
    # back is presumably not guaranteed to be stable across runs — confirm.
    cur = con.sql("""
    SELECT
        c.report_id,
        r.description,
    FROM comparable_report c
    LEFT JOIN report r
        ON c.report_id = r.report_id
    LIMIT 60
    ;
    """)
    df_rows = cur.df()
    records = df_rows.to_dict(orient="records")
    comparable_reports.extend(records)

# Run the model over every report (one ollama call per report), then persist
# the predictions as CSV; a later cell loads this file into DuckDB.
report_predictions = get_report_predictions(comparable_reports)
df_report_predictions = pd.DataFrame(report_predictions)
df_report_predictions.to_csv("../data/processed/report_prediction.csv", index=False)

100%|███████████████████████████████████████████| 60/60 [01:22<00:00,  1.37s/it]


In [14]:
# Load the predictions CSV into DuckDB as the `report_prediction` table,
# replacing any existing table of that name.
create_tables(
    table_configs=[
        dict(
            name="report_prediction",
            path="../data/processed/report_prediction.csv",
            format="csv",
        ),
    ],
    database_name=DATABASE,
)

In [15]:
# Sanity-check the freshly loaded table: print the total row count, then a
# 5-row sample of the prediction columns.
with duckdb.connect(DATABASE) as con:
    # Row count; should equal the number of reports processed above.
    con.sql("""
    SELECT
        count(1)
    FROM report_prediction
    ;
    """).show()

    # Small sample for eyeballing the predicted values.
    con.sql("""
    SELECT
        report_id,
        expected_workload_change,
        reason_category,
    FROM report_prediction
    LIMIT 5
    ;
    """).show()

┌──────────┐
│ count(1) │
│  int64   │
├──────────┤
│       60 │
└──────────┘

┌───────────────────┬──────────────────────────┬─────────────────┐
│     report_id     │ expected_workload_change │ reason_category │
│      varchar      │         varchar          │     varchar     │
├───────────────────┼──────────────────────────┼─────────────────┤
│ rotoballer_211851 │ much_higher              │ strong_opponent │
│ rotoballer_211733 │ somewhat_higher          │ rumor           │
│ rotoballer_211642 │ somewhat_higher          │ own_injury      │
│ rotoballer_211619 │ much_higher              │ rumor           │
│ rotoballer_211616 │ somewhat_higher          │ strong_opponent │
└───────────────────┴──────────────────────────┴─────────────────┘

