In [1]:
!pip install "vllm[torch]" transformers accelerate sentencepiece huggingface_hub \
           langchain-core langchain-openai langgraph pydantic pandas pyarrow




In [2]:
# Cell 1: Imports & basic config
from __future__ import annotations

import json
from pathlib import Path
from typing import List, Optional, Literal, Dict, Any

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain_openai import ChatOpenAI


In [3]:
# Cell 2: Filesystem layout (edit PROJECT_ROOT to relocate everything)
# PROJECT_ROOT defaults to the kernel's working directory; point it at an
# explicit location such as Path("/scratch/ziv_baretto/llmserve") if needed.
PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT.joinpath("data")
OUT_DIR = PROJECT_ROOT.joinpath("summaries")

# Only the output directory is created here; DATA_DIR is merely a path.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations so the reader can confirm them at a glance.
for _label, _location in (("DATA_DIR:", DATA_DIR), ("OUT_DIR :", OUT_DIR)):
    print(_label, _location)


DATA_DIR: /scratch/ziv_baretto/llmserve/data
OUT_DIR : /scratch/ziv_baretto/llmserve/summaries


In [4]:
# Cell 3: Helpers for type inference
from pandas.api.types import (
    is_numeric_dtype,
    is_integer_dtype,
    is_float_dtype,
    is_bool_dtype,
    is_datetime64_any_dtype,
)


def infer_logical_type(series: pd.Series, max_categories: int = 50) -> str:
    """Map a pandas Series to a coarse logical type label.

    The dtype checks run in a fixed order, with the specific checks
    (boolean, datetime, integer, float) ahead of the generic numeric
    fallback. Columns that match none of them are classified by the number
    of distinct non-null values.

    Args:
        series: Column to classify.
        max_categories: Largest distinct-value count still reported as
            "categorical" for non-numeric columns.

    Returns:
        One of "boolean", "datetime", "integer", "float", "numeric",
        "categorical", "text" or "unknown". "unknown" is returned for a
        non-numeric column that has no non-null values at all, since there
        is nothing to base a guess on.
    """
    if is_bool_dtype(series):
        return "boolean"
    if is_datetime64_any_dtype(series):
        return "datetime"
    if is_integer_dtype(series):
        return "integer"
    if is_float_dtype(series):
        return "float"
    if is_numeric_dtype(series):
        return "numeric"

    # Non-numeric column: decide by cardinality of the non-null values.
    distinct = series.nunique(dropna=True)
    if distinct == 0:
        # Empty or all-null column: no evidence for either text or categorical.
        return "unknown"
    if distinct <= max_categories:
        return "categorical"
    return "text"


In [5]:
# Cell 4: Profile a single CSV
def profile_csv(path: Path, sample_rows: int = 5000) -> Dict[str, Any]:
    """
    Build a lightweight profile of one CSV file.

    Only the first `sample_rows` rows are loaded. For every column the
    profile records the pandas dtype, the label produced by
    `infer_logical_type`, the fraction of nulls, the number of distinct
    non-null values (absolute and relative to the non-null count), and up to
    five distinct example values rendered as strings.

    Returns a dict with the file name, its path, the number of rows read and
    the list of per-column entries; this dict is what we feed to the LLM.
    """
    print(f"Profiling {path.name} ...")
    # Pass `low_memory=False` or explicit dtype hints here if parsing needs them.
    sample = pd.read_csv(path, nrows=sample_rows)

    def describe(column_name) -> Dict[str, Any]:
        """Collect the statistics for a single column of the sample."""
        values = sample[column_name]
        row_count = len(values)
        present = values.notna().sum()
        distinct = values.nunique(dropna=True)

        # Guard both ratios against empty denominators.
        if row_count:
            missing_share = float(1.0 - present / row_count)
        else:
            missing_share = 0.0
        if present:
            distinct_share = float(distinct / present)
        else:
            distinct_share = 0.0

        # First five distinct non-null values, stringified.
        as_text = values.dropna().astype(str)
        first_distinct = as_text.drop_duplicates().head(5).tolist()

        return {
            "name": column_name,
            "physical_dtype": str(values.dtype),
            "logical_type_guess": infer_logical_type(values),
            "null_fraction": missing_share,
            "n_unique": int(distinct),
            "unique_fraction": distinct_share,
            "examples": first_distinct,
        }

    return {
        "dataset_name": path.name,
        "path": str(path),
        "n_rows_sampled": int(len(sample)),
        "columns": [describe(column_name) for column_name in sample.columns],
    }


In [6]:
# Cell 5: Pydantic models for LLM output

# Closed set of logical type labels; used as the annotation of
# ColumnSummary.logical_type. The same labels are listed in the system prompt.
LogicalType = Literal[
    "numeric", "integer", "float", "categorical",
    "text", "datetime", "boolean", "unknown"
]

# Per-column part of the LLM's structured answer.
# NOTE(review): documented with comments rather than a class docstring on
# purpose — presumably a docstring would surface in the JSON schema that the
# output parser sends to the LLM; confirm before adding one.
class ColumnSummary(BaseModel):
    # Column name; verify_summary() compares it (stripped, lower-cased)
    # against the profiled column names.
    name: str
    # Storage dtype as a string — presumably echoed from the profile's
    # "physical_dtype"; the LLM fills it in, so this is not enforced.
    physical_dtype: str
    # Refined logical type, restricted to the LogicalType labels.
    logical_type: LogicalType
    description: str = Field(
        description="Short natural-language description of what the column represents."
    )
    # Whether the LLM judges that the column may contain nulls.
    nullable: bool
    # Expected to mirror the same-named profile statistics (not cross-checked here).
    null_fraction: float
    unique_fraction: float
    # Example values as strings; defaults to an empty list when omitted.
    examples: List[str] = Field(default_factory=list)
    # True when the LLM considers the column a plausible (part of a) key.
    is_potential_key: bool = False

# Top-level object the LLM must return for one dataset; this is the model
# handed to PydanticOutputParser and validated in summarize_profile().
# NOTE(review): kept docstring-free on purpose — presumably a class docstring
# would surface in the JSON schema sent to the LLM; confirm before adding one.
class DatasetSummary(BaseModel):
    # Identification of the dataset; the profile carries same-named keys.
    dataset_name: str
    path: str
    # Optional row-count estimate; None when the LLM does not provide one.
    approx_n_rows: Optional[int] = None
    # One entry per column of the dataset.
    columns: List[ColumnSummary]
    candidate_primary_keys: List[List[str]] = Field(
        default_factory=list,
        description="Each entry is a list of column names that could form a primary key."
    )
    # Free-text remarks about quirks; optional at the schema level.
    notes: Optional[str] = None


In [7]:
# Cell 6: LLM client (OpenAI-compatible, backed by vLLM)

# Settings for the local OpenAI-compatible endpoint, collected in one place.
_llm_settings = dict(
    # Must match the model name that was passed to vLLM.
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    base_url="http://127.0.0.1:8000/v1",
    # vLLM ignores the key, but ChatOpenAI requires a string.
    api_key="not-used",
    temperature=0.2,
    # NOTE(review): this is a very large completion budget for a JSON summary;
    # presumably the server's context length accommodates prompt + 84000
    # tokens — confirm against the vLLM launch settings.
    max_tokens=84000,
)

llm = ChatOpenAI(**_llm_settings)



In [8]:
# Cell 7: Prompt + chain

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser

# Parser bound to the DatasetSummary model; used below to produce the
# format instructions that are appended to the system message.
parser = PydanticOutputParser(pydantic_object=DatasetSummary)

# Task description sent as the system message. This text is runtime data:
# editing it changes what the model is asked to do.
system_prompt = """
You are a meticulous data profiling assistant.

You receive a *machine-generated profile* of a tabular dataset:
- Each column has a physical dtype, a guessed logical type, null_fraction, unique_fraction, and example values.

Your job:
1. Refine the logical type for each column (choose from: numeric, integer, float,
   categorical, text, datetime, boolean, unknown).
2. Write a short, precise description for each column based on its name and examples.
3. Decide if the column is nullable.
4. Mark columns that are plausible keys (e.g., id, code, combination of state+year).
5. Propose candidate_primary_keys: each entry is a list of column names that
   could form a primary key (unique identifier for rows).
6. Add a short 'notes' field if there is anything non-obvious or suspicious. Always include a short notes string summarizing any quirks

You MUST output a JSON object that matches the DatasetSummary schema.
Output ONLY JSON. Do NOT add any explanation, heading, or surrounding text.
"""

# Two-message chat template: the system message is the task description plus
# the parser's format instructions (bound once via .partial); the human
# message carries the profile, supplied per call as `profile_json`.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt + "\n\n{format_instructions}"),
        (
            "human",
            "Here is the machine-generated profile for one dataset:\n\n{profile_json}"
        ),
    ]
).partial(format_instructions=parser.get_format_instructions())


# Prompt -> LLM -> parser pipeline.
# NOTE: summarize_profile() does not go through this chain; it calls `prompt`
# and `llm` directly and parses the reply itself.
summarizer_chain = prompt | llm | parser


In [9]:
from langchain_core.utils.json import parse_partial_json

def extract_json_block(text: str) -> str:
    """
    Extract the first top-level JSON object from a string.

    Decoding starts at the first '{'. If a complete, valid JSON object begins
    there, exactly that object is returned, so any commentary after it (even
    text that itself contains braces) is dropped. If strict decoding fails,
    for example because the output is truncated or slightly malformed, we
    fall back to the span from the first '{' to the last '}' so that a
    lenient parser downstream can still attempt a repair.

    Raises:
        ValueError: if `text` contains no '{' followed by a later '}'.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise ValueError(f"No JSON object found in LLM output:\n{text[:200]}...")

    candidate = text[start:]
    try:
        # raw_decode parses one JSON document and reports where it ended,
        # ignoring whatever follows it.
        _, consumed = json.JSONDecoder().raw_decode(candidate)
    except ValueError:
        # Not strictly valid JSON (json.JSONDecodeError is a ValueError):
        # keep the original first-'{'-to-last-'}' behaviour.
        return text[start : end + 1]
    return candidate[:consumed]


In [10]:
def summarize_profile(profile: Dict[str, Any]) -> DatasetSummary:
    """
    Turn a machine-generated profile into a validated DatasetSummary.

    The profile is serialized to indented JSON, rendered into the chat
    prompt and sent to the LLM. The JSON object is then cut out of the reply
    (dropping any "Here is..." preface), parsed leniently so that slight
    truncation is tolerated, and validated against the Pydantic schema.
    """
    rendered_messages = prompt.format_messages(
        profile_json=json.dumps(profile, indent=2)
    )

    reply = llm.invoke(rendered_messages)
    if hasattr(reply, "content"):
        reply_text = reply.content
    else:
        reply_text = str(reply)

    # (optional) uncomment to inspect the raw reply for one dataset
    # print(reply_text)

    payload = parse_partial_json(extract_json_block(reply_text))
    return DatasetSummary.model_validate(payload)


In [11]:
# Cell 9: Basic verification

def verify_summary(profile: Dict[str, Any], summary: DatasetSummary) -> Dict[str, Any]:
    """
    Check that the summary covers exactly the columns of the profile.

    Column names are compared after stripping surrounding whitespace and
    lower-casing. The result holds an overall "ok" flag plus the sorted names
    that are missing from, or extra in, the summary.
    """
    def canonical(name: str) -> str:
        return name.strip().lower()

    expected = set(canonical(entry["name"]) for entry in profile["columns"])
    reported = set(canonical(column.name) for column in summary.columns)

    report: Dict[str, Any] = {}
    # The two sets are equal exactly when neither difference has members.
    report["ok"] = expected == reported
    report["missing_in_summary"] = sorted(expected.difference(reported))
    report["extra_in_summary"] = sorted(reported.difference(expected))
    return report


In [12]:
# Cell 10: Main loop over datasets

def run_stage1_summarizer(
    data_dir: Path = DATA_DIR,
    out_dir: Path = OUT_DIR,
    pattern: str = "*.csv",
    sample_rows: int = 5000,
):
    """
    Profile, summarize and verify every matching file in `data_dir`.

    For each file (in sorted order) this profiles up to `sample_rows` rows,
    asks the LLM for a summary, checks that the summary's columns match the
    profile, prints the outcome, and writes the summary as
    `<file stem>.summary.json` into `out_dir`.

    Args:
        data_dir: Directory that is searched (non-recursively) with `pattern`.
        out_dir: Directory receiving the summary files; created if missing.
        pattern: Glob pattern selecting the input files.
        sample_rows: Maximum number of rows read per file for profiling.

    A column mismatch only produces a warning; the summary is still written.
    Errors raised while profiling or summarizing a file are not caught here.
    """
    # A caller-supplied out_dir may not exist yet; only the default OUT_DIR
    # is created at configuration time.
    out_dir.mkdir(parents=True, exist_ok=True)

    paths = sorted(data_dir.glob(pattern))
    print(f"Found {len(paths)} CSVs in {data_dir}")
    for path in paths:
        print("\n" + "=" * 80)
        print(f"Dataset: {path.name}")

        profile = profile_csv(path, sample_rows=sample_rows)
        summary = summarize_profile(profile)
        check   = verify_summary(profile, summary)

        if not check["ok"]:
            print("WARNING: Column mismatch detected!")
            print("Missing in summary:", check["missing_in_summary"])
            print("Extra in summary  :", check["extra_in_summary"])
        else:
            print("Schema check: OK")

        # Save summary JSON to disk. The encoding is explicit so non-ASCII
        # text in descriptions or examples is written identically on every
        # platform instead of depending on the locale default.
        out_path = out_dir / f"{path.stem}.summary.json"
        out_path.write_text(summary.model_dump_json(indent=2), encoding="utf-8")

        print(f"Wrote summary -> {out_path}")

# Run the stage-1 pipeline with its defaults: every "*.csv" in DATA_DIR,
# up to 5000 rows per file, summaries written to OUT_DIR.
run_stage1_summarizer()


Found 3 CSVs in /scratch/ziv_baretto/llmserve/data

Dataset: All-India-Estimates-of-Area,-Production-&-Yield-of-Food-Grains.csv
Profiling All-India-Estimates-of-Area,-Production-&-Yield-of-Food-Grains.csv ...


Schema check: OK
Wrote summary -> /scratch/ziv_baretto/llmserve/summaries/All-India-Estimates-of-Area,-Production-&-Yield-of-Food-Grains.summary.json

Dataset: All-India-profile_-Crop-wise-Area.csv
Profiling All-India-profile_-Crop-wise-Area.csv ...
Schema check: OK
Wrote summary -> /scratch/ziv_baretto/llmserve/summaries/All-India-profile_-Crop-wise-Area.summary.json

Dataset: Export-of-Rice-Varieties-to-Bangladesh,-2018-19-to-2024-25.csv
Profiling Export-of-Rice-Varieties-to-Bangladesh,-2018-19-to-2024-25.csv ...
Schema check: OK
Wrote summary -> /scratch/ziv_baretto/llmserve/summaries/Export-of-Rice-Varieties-to-Bangladesh,-2018-19-to-2024-25.summary.json
