diff --git a/bitnet_tools/cli.py b/bitnet_tools/cli.py index 760f563..7cd9f8a 100644 --- a/bitnet_tools/cli.py +++ b/bitnet_tools/cli.py @@ -7,6 +7,7 @@ from pathlib import Path from .analysis import DataSummary, build_analysis_payload, build_analysis_payload_from_request, build_markdown_report +from .compare import compare_csv_files, result_to_json as compare_result_to_json from .doctor import collect_environment from .document_extract import extract_document_tables, table_to_analysis_request from .multi_csv import analyze_multiple_csv, build_multi_csv_markdown, result_to_json @@ -94,6 +95,11 @@ def _build_parser() -> argparse.ArgumentParser: multi_parser.add_argument("--no-cache", action="store_true", help="Disable file profile cache") multi_parser.add_argument("--workers", type=int, default=None, help="Optional worker count for parallel file profiling") + compare_parser = subparsers.add_parser("compare", help="Compare before/after CSV distributions") + compare_parser.add_argument("--before", required=True, type=Path, help="Before CSV path") + compare_parser.add_argument("--after", required=True, type=Path, help="After CSV path") + compare_parser.add_argument("--out", type=Path, default=Path("compare_result.json"), help="Where to store compare result JSON") + report_parser = subparsers.add_parser("report", help="Build markdown summary report from CSV") report_parser.add_argument("csv", type=Path, help="Input CSV path") report_parser.add_argument("--question", required=True, help="Analysis question") @@ -109,7 +115,7 @@ def _build_parser() -> argparse.ArgumentParser: def main(argv: list[str] | None = None) -> int: raw_args = list(sys.argv[1:] if argv is None else argv) - if raw_args and raw_args[0] not in {"analyze", "ui", "desktop", "doctor", "report", "multi-analyze", "-h", "--help"}: + if raw_args and raw_args[0] not in {"analyze", "ui", "desktop", "doctor", "report", "multi-analyze", "compare", "-h", "--help"}: raw_args.insert(0, "analyze") parser = _build_parser() @@ -153,6 +159,12 @@ def main(argv: list[str] | None = None) -> int: print(f"multi analysis report saved: {args.out_report}") return 0 + if args.command == "compare": + result = compare_csv_files(args.before, args.after) + args.out.write_text(compare_result_to_json(result), encoding="utf-8") + print(f"compare result saved: {args.out}") + return 0 + if args.command == "report": payload = build_analysis_payload(args.csv, args.question) summary = DataSummary(**payload["summary"]) diff --git a/bitnet_tools/compare.py b/bitnet_tools/compare.py new file mode 100644 index 0000000..353c577 --- /dev/null +++ b/bitnet_tools/compare.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import csv +import io +import json +import math +from collections import Counter +from pathlib import Path +from typing import Any + +from .versioning import build_dataset_fingerprint, save_lineage_link + +EPS = 1e-9 + + +def _read_csv_text(csv_text: str) -> tuple[list[str], list[dict[str, str]]]: + reader = csv.DictReader(io.StringIO(csv_text)) + rows = [{k: (v if v is not None else '') for k, v in row.items()} for row in reader] + return list(reader.fieldnames or []), rows + + +def _safe_float(value: str) -> float | None: + try: + v = float(str(value).strip()) + except (TypeError, ValueError): + return None + if math.isnan(v) or math.isinf(v): + return None + return v + + +def _is_numeric_column(before_rows: list[dict[str, str]], after_rows: list[dict[str, str]], col: str) -> bool: + seen = False + for row in before_rows + after_rows: + raw = str(row.get(col, '')).strip() + if not raw: + continue + seen = True + if _safe_float(raw) is None: + return False + return seen + + +def _normalize_probs(values: list[float]) -> list[float]: + total = sum(values) + if total <= 0: + return [1.0 / len(values)] * len(values) + return [max(v / total, EPS) for v in values] + + +def _psi(before_prob: list[float], after_prob: list[float]) -> float: + return sum((a - b) * math.log(a / b) for b, a in zip(before_prob, after_prob)) + + +def _js_divergence(before_prob: list[float], after_prob: list[float]) -> float: + m = [(b + a) / 2 for b, a in zip(before_prob, after_prob)] + + def _kl(p: list[float], q: list[float]) -> float: + return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q)) + + return 0.5 * _kl(before_prob, m) + 0.5 * _kl(after_prob, m) + + +def _chi_square(before_counts: list[int], after_counts: list[int]) -> float: + before_total = sum(before_counts) + after_total = sum(after_counts) + if before_total == 0 or after_total == 0: + return 0.0 + score = 0.0 + for expected_raw, observed in zip(before_counts, after_counts): + expected = max((expected_raw / before_total) * after_total, EPS) + score += ((observed - expected) ** 2) / expected + return score + + +def _categorical_distribution(rows: list[dict[str, str]], col: str, categories: list[str]) -> list[int]: + counter = Counter(str(row.get(col, '')).strip() for row in rows) + return [counter.get(cat, 0) for cat in categories] + + +def _numeric_distribution(rows: list[dict[str, str]], col: str, bins: list[float]) -> list[int]: + counts = [0] * (len(bins) - 1) + for row in rows: + val = _safe_float(row.get(col, '')) + if val is None: + continue + for i in range(len(bins) - 1): + lower, upper = bins[i], bins[i + 1] + if (i < len(bins) - 2 and lower <= val < upper) or (i == len(bins) - 2 and lower <= val <= upper): + counts[i] += 1 + break + return counts + + +def _make_bins(values: list[float], num_bins: int = 10) -> list[float]: + v_min = min(values) + v_max = max(values) + if math.isclose(v_min, v_max): + return [v_min - 0.5, v_max + 0.5] + step = (v_max - v_min) / num_bins + return [v_min + (step * i) for i in range(num_bins)] + [v_max] + + +def compare_csv_texts(before_csv_text: str, after_csv_text: str, *, before_source: str = 'before.csv', after_source: str = 'after.csv') -> dict[str, Any]: + before_cols, before_rows = _read_csv_text(before_csv_text) + after_cols, after_rows = _read_csv_text(after_csv_text) + common_cols = sorted(set(before_cols) & set(after_cols)) + + metrics: dict[str, Any] = {} + for col in common_cols: + if _is_numeric_column(before_rows, after_rows, col): + before_values = [_safe_float(r.get(col, '')) for r in before_rows] + after_values = [_safe_float(r.get(col, '')) for r in after_rows] + all_values = [v for v in before_values + after_values if v is not None] + if not all_values: + continue + bins = _make_bins(all_values) + before_counts = _numeric_distribution(before_rows, col, bins) + after_counts = _numeric_distribution(after_rows, col, bins) + bucket_labels = [f'[{bins[i]:.4g}, {bins[i + 1]:.4g})' for i in range(len(bins) - 1)] + bucket_labels[-1] = bucket_labels[-1].replace(')', ']') + dist_type = 'numeric' + else: + categories = sorted({str(r.get(col, '')).strip() for r in before_rows + after_rows}) + if not categories: + continue + before_counts = _categorical_distribution(before_rows, col, categories) + after_counts = _categorical_distribution(after_rows, col, categories) + bucket_labels = categories + dist_type = 'categorical' + + before_prob = _normalize_probs(before_counts) + after_prob = _normalize_probs(after_counts) + metrics[col] = { + 'type': dist_type, + 'buckets': bucket_labels, + 'before_counts': before_counts, + 'after_counts': after_counts, + 'psi': _psi(before_prob, after_prob), + 'js_divergence': _js_divergence(before_prob, after_prob), + 'chi_square': _chi_square(before_counts, after_counts), + } + + before_version = build_dataset_fingerprint(before_csv_text, source_name=before_source) + after_version = build_dataset_fingerprint(after_csv_text, source_name=after_source) + lineage_path = save_lineage_link( + before_version, + after_version, + before_source=before_source, + after_source=after_source, + context={'common_columns': common_cols}, + ) + + return { + 'before': { + 'source_name': before_source, + 'fingerprint': before_version.fingerprint, + 'row_count': before_version.row_count, + 'column_count': before_version.column_count, + }, + 'after': { + 'source_name': after_source, + 'fingerprint': after_version.fingerprint, + 'row_count': after_version.row_count, + 'column_count': after_version.column_count, + }, + 'common_columns': common_cols, + 'column_metrics': metrics, + 'lineage_path': str(lineage_path), + } + + +def compare_csv_files(before_path: Path, after_path: Path) -> dict[str, Any]: + return compare_csv_texts( + before_path.read_text(encoding='utf-8'), + after_path.read_text(encoding='utf-8'), + before_source=before_path.name, + after_source=after_path.name, + ) + + +def result_to_json(result: dict[str, Any]) -> str: + return json.dumps(result, ensure_ascii=False, indent=2) diff --git a/bitnet_tools/versioning.py b/bitnet_tools/versioning.py new file mode 100644 index 0000000..5f19790 --- /dev/null +++ b/bitnet_tools/versioning.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +import json +from pathlib import Path +from datetime import datetime, timezone +from typing import Any + +LINEAGE_DIR = Path('.bitnet_cache') / 'lineage' + + +@dataclass(frozen=True) +class DatasetVersion: + fingerprint: str + row_count: int + column_count: int + columns: list[str] + + +def build_dataset_fingerprint(csv_text: str, *, source_name: str = '', meta: dict[str, Any] | None = None) -> DatasetVersion: + lines = [line.rstrip() for line in csv_text.strip().splitlines() if line.strip()] + header = lines[0].split(',') if lines else [] + row_count = max(len(lines) - 1, 0) + payload = { + 'source_name': source_name, + 'columns': header, + 'row_count': row_count, + 'csv_text': '\n'.join(lines), + 'meta': meta or {}, + } + digest = hashlib.sha256(json.dumps(payload, ensure_ascii=False, sort_keys=True).encode('utf-8')).hexdigest() + return DatasetVersion( + fingerprint=digest, + row_count=row_count, + column_count=len(header), + columns=header, + ) + + +def save_lineage_link( + before: DatasetVersion, + after: DatasetVersion, + *, + before_source: str, + after_source: str, + context: dict[str, Any] | None = None, +) -> Path: + LINEAGE_DIR.mkdir(parents=True, exist_ok=True) + now = datetime.now(timezone.utc).isoformat() + record = { + 'created_at': now, + 'before': { + 'source_name': before_source, + 'fingerprint': before.fingerprint, + 'row_count': before.row_count, + 'column_count': before.column_count, + 'columns': before.columns, + }, + 'after': { + 'source_name': after_source, + 'fingerprint': after.fingerprint, + 'row_count': after.row_count, + 'column_count': after.column_count, + 'columns': after.columns, + }, + 'context': context or {}, + } + out_path = LINEAGE_DIR / f"{before.fingerprint[:12]}__{after.fingerprint[:12]}.json" + out_path.write_text(json.dumps(record, ensure_ascii=False, indent=2), encoding='utf-8') + return out_path diff --git a/bitnet_tools/web.py b/bitnet_tools/web.py index c831904..0e564bf 100644 --- a/bitnet_tools/web.py +++ b/bitnet_tools/web.py @@ -21,6 +21,7 @@ from urllib.parse import urlparse from .analysis import build_analysis_payload_from_request +from .compare import compare_csv_texts from .document_extract import extract_document_tables_from_base64, table_to_analysis_request from .multi_csv import analyze_multiple_csv from .planner import build_plan, execute_plan_from_csv_text, parse_question_to_intent @@ -488,6 +489,22 @@ def do_POST(self) -> None: result = extract_document_tables_from_base64(file_base64, source_name) return self._send_json(result.to_dict()) + if route == '/api/compare': + before_payload = payload.get('before', {}) + after_payload = payload.get('after', {}) + if not isinstance(before_payload, dict) or not isinstance(after_payload, dict): + return self._send_json(self._error_payload('before and after payloads are required'), HTTPStatus.BAD_REQUEST) + + before_name, before_text, _ = _coerce_csv_text_from_file_payload(before_payload) + after_name, after_text, _ = _coerce_csv_text_from_file_payload(after_payload) + result = compare_csv_texts( + before_text, + after_text, + before_source=before_name, + after_source=after_name, + ) + return self._send_json(result) + if route == "/api/analyze": question = str(payload.get("question", "")).strip() if not question: diff --git a/tests/test_compare.py b/tests/test_compare.py new file mode 100644 index 0000000..c252a94 --- /dev/null +++ b/tests/test_compare.py @@ -0,0 +1,54 @@ +import json +from bitnet_tools import cli +from bitnet_tools.compare import compare_csv_texts +from tests.test_web import _post_json, _run_server + + +def test_compare_same_data_has_near_zero_drift(): + csv_text = 'city,sales\nseoul,100\nbusan,200\n' + result = compare_csv_texts(csv_text, csv_text, before_source='before.csv', after_source='after.csv') + + assert result['column_metrics']['city']['psi'] == 0 + assert result['column_metrics']['sales']['js_divergence'] == 0 + assert result['lineage_path'].endswith('.json') + + +def test_compare_changed_data_has_positive_drift(): + before = 'city,sales\nseoul,100\nbusan,200\n' + after = 'city,sales\nseoul,100\nseoul,100\n' + + result = compare_csv_texts(before, after, before_source='before.csv', after_source='after.csv') + + assert result['column_metrics']['city']['psi'] > 0 + assert result['column_metrics']['city']['chi_square'] > 0 + + +def test_cli_compare_command(tmp_path): + before = tmp_path / 'before.csv' + after = tmp_path / 'after.csv' + out = tmp_path / 'compare.json' + + before.write_text('city,sales\nseoul,100\nbusan,200\n', encoding='utf-8') + after.write_text('city,sales\nseoul,100\nseoul,100\n', encoding='utf-8') + + code = cli.main(['compare', '--before', str(before), '--after', str(after), '--out', str(out)]) + + assert code == 0 + body = json.loads(out.read_text(encoding='utf-8')) + assert body['column_metrics']['city']['psi'] > 0 + + +def test_compare_api_returns_result_payload(): + server, thread = _run_server() + base = f'http://127.0.0.1:{server.server_port}' + try: + code, body = _post_json(base + '/api/compare', { + 'before': {'name': 'before.csv', 'normalized_csv_text': 'city,sales\nseoul,100\nbusan,200\n'}, + 'after': {'name': 'after.csv', 'normalized_csv_text': 'city,sales\nseoul,100\nseoul,100\n'}, + }) + assert code == 200 + assert body['column_metrics']['city']['psi'] > 0 + assert body['before']['source_name'] == 'before.csv' + finally: + server.shutdown() + thread.join(timeout=1)