In [None]:
import json
import os
import logging
import pandas as pd
from sqlalchemy import create_engine
from datetime import datetime

# from superset_client import SupersetClient
from rest_api_client import RestApiClient

In [None]:
class SupersetClient(RestApiClient):
    """Client for the Apache Superset REST API.

    Authentication happens eagerly in ``__init__``: an access token is obtained
    from ``api/v1/security/login`` and a CSRF token from
    ``api/v1/security/csrf_token/``; both are installed on the underlying
    session so that subsequent requests are authorised.

    NOTE(review): ``get``/``post``, ``set_auth_token`` and ``session`` come from
    ``RestApiClient``. The ``is None`` checks in this class assume that
    ``get``/``post`` return ``None`` on a failed request — confirm against that
    class.
    """

    def __init__(self, base_url: str, username: str, password: str, provider: str = "db"):
        """Store the credentials and log in immediately.

        Args:
            base_url (str): Base URL of the Superset server.
            username (str): Login user name.
            password (str): Login password.
            provider (str, optional): Authentication provider, normally ``"db"``.

        Raises:
            RuntimeError: If logging in, or fetching the access/CSRF token, fails.
        """
        super().__init__(base_url)
        self.username = username
        self.password = password
        self.provider = provider
        self._authenticate()

    def _authenticate(self):
        """Log in and install the access token and the CSRF header on the session.

        Raises:
            RuntimeError: If the login request fails, the login response carries
                no access token, or the CSRF token cannot be fetched.
        """
        payload = {"username": self.username, "password": self.password, "provider": self.provider, "refresh": True}
        res = self.post("api/v1/security/login", json=payload)
        if res is None:
            raise RuntimeError("❌ Supersetへのログインに失敗しました")
        token = res.json().get("access_token")
        if not token:
            raise RuntimeError("❌ アクセストークンが取得できませんでした")
        self.set_auth_token(token)

        # Fetch the CSRF token (important!). It is requested only after the
        # bearer token is set, and is sent as the X-CSRFToken header on every
        # later request.
        csrf_resp = self.get("api/v1/security/csrf_token/")
        # Same failure convention as `post` above: None means the request
        # failed. Without this check the failure surfaced as an AttributeError.
        if csrf_resp is None:
            raise RuntimeError("❌ CSRFトークンの取得に失敗しました")
        csrf_resp.raise_for_status()
        csrf_token = csrf_resp.json().get("result")
        if not csrf_token:
            raise RuntimeError("❌ CSRFトークンの取得に失敗しました")
        self.session.headers.update({"X-CSRFToken": csrf_token})

    def create_dataset(self, database_id: int, table_name: str, schema: str = "public") -> int | None:
        """Create a dataset over an existing table.

        Args:
            database_id (int): Superset id of the database that holds the table.
            table_name (str): Name of the table to expose as a dataset.
            schema (str, optional): Schema of the table. Defaults to ``"public"``.

        Returns:
            int | None: The new dataset id, or ``None`` (after printing an error)
            if the request failed.
        """
        payload = {
            "database": database_id,
            "table_name": table_name,
            "schema": schema,
        }
        res = self.post("api/v1/dataset/", json=payload)
        if res is None:
            print("❌ データセットの作成に失敗しました")
            return None
        return res.json()["id"]

    def create_chart(self, dataset_id: int, chart_name: str, viz_type: str, params: dict) -> int | None:
        """Create a chart on a table dataset.

        Args:
            dataset_id (int): Id of the dataset the chart queries.
            chart_name (str): Display name of the chart (``slice_name``).
            viz_type (str): Superset visualisation type, e.g. ``"table"``, ``"pie"``.
            params (dict): Chart form data; serialised to a JSON string because
                the API's ``params`` field is sent as a string.

        Returns:
            int | None: The new chart id, or ``None`` (after printing an error)
            if the request failed.
        """
        payload = {
            "slice_name": chart_name,
            "viz_type": viz_type,
            "params": json.dumps(params),
            "datasource_id": dataset_id,
            "datasource_type": "table",
        }
        res = self.post("api/v1/chart/", json=payload)
        if res is None:
            print("❌ チャートの作成に失敗しました")
            return None
        return res.json()["id"]

    def create_dashboard(self, dashboard_title: str, chart_ids: list[int]) -> int:
        """Create a dashboard and return its id.

        Args:
            dashboard_title (str): Title of the new dashboard.
            chart_ids (list[int]): Ids of the charts intended for the dashboard.

        Returns:
            int: The new dashboard id.

        Raises:
            RuntimeError: If the request failed.
        """
        # NOTE(review): Superset's dashboard POST schema appears to expect
        # `position_json` (a JSON string) and to have no `charts` field; this
        # payload may be rejected, or the charts may not be attached. Verify
        # against the target Superset version.
        payload = {
            "dashboard_title": dashboard_title,
            "positions": {},
            "charts": chart_ids,
        }
        res = self.post("api/v1/dashboard/", json=payload)
        if res is None:
            raise RuntimeError("❌ ダッシュボードの作成に失敗しました")
        return res.json()["id"]

    def get_database_id_by_name(self, db_name: str) -> int:
        """Look up a database's id by its Superset display name.

        Args:
            db_name (str): Value of the database's ``database_name`` field.

        Returns:
            int: The id of the first database whose name matches exactly.

        Raises:
            RuntimeError: If the database list cannot be fetched.
            ValueError: If no database with that name is found.
        """
        # NOTE(review): no paging parameters are sent, so only the server's
        # default first page is searched. With many databases the target may be
        # missed; confirm the default page size.
        res = self.get("api/v1/database/")
        if res is None:
            raise RuntimeError("❌ データベース一覧の取得に失敗しました")

        databases = res.json().get("result", [])
        for db in databases:
            if db["database_name"] == db_name:
                return db["id"]

        raise ValueError(f"❌ Database '{db_name}' が見つかりませんでした")


def import_csv_to_postgres(csv_file, user, password, host, port, database, table, schema, tag):
    """Append the rows of a CSV file to a PostgreSQL table, stamped with import metadata.

    Before writing, two columns are added to every row:

    - ``timestamp``: the local time of this import (``datetime.now()``).
    - ``tag``: the ``tag`` argument.

    Rows are appended with ``if_exists="append"``. Failures are reported with
    ``print`` and swallowed, so a missing file, unreadable CSV, or database
    error never raises.

    Args:
        csv_file: Path to the CSV file to import.
        user: PostgreSQL user name.
        password: PostgreSQL password. URL-reserved characters (``@``, ``:``,
            ``/`` ...) are escaped when the connection URL is built.
        host: Database host.
        port: Database port (an int, or a string of digits).
        database: Database name.
        table: Destination table name.
        schema: Destination schema.
        tag: Value written to the ``tag`` column of every imported row.

    Returns:
        None.
    """
    # Check if the CSV file exists
    if not os.path.exists(csv_file):
        print(f"CSV file '{csv_file}' does not exist. Skipping import.")
        return

    # Read the CSV file and add the metadata columns
    try:
        df = pd.read_csv(csv_file)
        df["timestamp"] = datetime.now()
        df["tag"] = tag
    except Exception as e:
        print(f"Error occurred while reading the CSV file: {e}")
        return

    # Connect to PostgreSQL and import
    engine = None
    try:
        # Build the URL structurally instead of via an f-string: URL.create
        # escapes credentials that would otherwise break URL parsing.
        from sqlalchemy.engine import URL

        url = URL.create(
            "postgresql",
            username=user,
            password=password,
            host=host,
            port=int(port) if port is not None else None,
            database=database,
        )
        engine = create_engine(url)
        df.to_sql(table, con=engine, schema=schema, if_exists="append", index=False)
        print(f"Data successfully imported into table '{schema}.{table}'.")
    except Exception as e:
        print(f"Error occurred during PostgreSQL connection or data import: {e}")
    finally:
        # The engine is local to this call, so release its pooled connections.
        if engine is not None:
            engine.dispose()

In [None]:
# Superset connection settings. Credentials are read from the environment so
# they are not committed with the notebook; the defaults keep the previous
# local values.
base_url = os.environ.get("SUPERSET_URL", "http://localhost:8088")
username = os.environ.get("SUPERSET_USERNAME", "admin")
password = os.environ.get("SUPERSET_PASSWORD", "admin")
db_name = "PostgreSQL"

# Logging configuration
logging.basicConfig(level=logging.INFO)

client = SupersetClient(base_url, username, password)
db_id = client.get_database_id_by_name(db_name)
db_id

In [None]:
# Load the per-file cloc report into PostgreSQL. Every row is tagged so that
# separate imports can be told apart. DB credentials come from the environment;
# the defaults keep the previous local values.
import_csv_to_postgres(
    csv_file="../data/cloc_by_file.csv",
    user=os.environ.get("POSTGRES_USER", "superset"),
    password=os.environ.get("POSTGRES_PASSWORD", "superset"),
    host=os.environ.get("POSTGRES_HOST", "localhost"),
    port=int(os.environ.get("POSTGRES_PORT", "5432")),
    database=os.environ.get("POSTGRES_DB", "superset"),
    table="cloc_by_file_test",
    schema="public",
    tag="cloc_import",
)

In [None]:
# Register the imported table as a Superset dataset. The returned id is kept
# for the chart cells below. `create_dataset` prints an error and returns None
# on failure (possibly on a re-run, if the dataset already exists — confirm).
dataset_id = client.create_dataset(
    database_id=db_id,
    table_name="cloc_by_file_test",
    schema="public",
)
dataset_id

In [None]:
# Inspect the datasets Superset knows about. The path has no leading slash,
# matching the `api/v1/...` paths used inside SupersetClient.
datasets_res = client.get("api/v1/dataset/")
if datasets_res is None:
    raise RuntimeError("Failed to fetch the dataset list")
datasets_res.json()

In [None]:
# Table chart over the cloc dataset.
# NOTE(review): the dataset id is hard-coded; it should equal the id returned
# by `client.create_dataset(...)` above — confirm before running.
CLOC_DATASET_ID = 1

table_params = {
    # Same form-data keys as the pie chart, so the chart's params are self-describing.
    "datasource": f"{CLOC_DATASET_ID}__table",
    "viz_type": "table",
    "query_mode": "aggregate",
    "groupby": ["language", "filename", "blank", "comment", "code"],
    "time_grain_sqla": "P1D",
    "row_limit": 1000,
    "server_page_length": 10,
    "order_desc": True,
    "table_timestamp_format": "smart_date",
    "allow_render_html": True,
    "show_cell_bars": True,
    "color_pn": True,
    "comparison_color_scheme": "Green",
    "comparison_type": "values",
}
table_chart_id = client.create_chart(CLOC_DATASET_ID, "test_chart", viz_type="table", params=table_params)
table_chart_id

In [None]:
# Pie chart: SUM(code) per language, i.e. share of code lines by language.
# NOTE(review): the dataset id is hard-coded; it should equal the id returned
# by `client.create_dataset(...)` above — confirm before running.
CLOC_DATASET_ID = 1

pie_params = {
    # Derived from the same constant as the create_chart call, so they cannot diverge.
    "datasource": f"{CLOC_DATASET_ID}__table",
    "viz_type": "pie",
    "groupby": ["language"],
    "metric": {
        "expressionType": "SIMPLE",
        "column": {
            "column_name": "code",
            # NOTE(review): dataset-specific column id of `code`, hard-coded — verify.
            "id": 5,
            # Real booleans (JSON false). The frontend reads these as JS values,
            # where the non-empty string "false" would be truthy.
            "is_certified": False,
            "is_dttm": False,
            "type": "BIGINT",
            "type_generic": 0,
        },
        "aggregate": "SUM",
        "datasourceWarning": False,
        "hasCustomLabel": False,
        "label": "SUM(code)",
        "optionName": "metric_l0sai2vwraf_uh5hnke6dc",
    },
    "row_limit": 100,
    "sort_by_metric": True,
    "color_scheme": "supersetColors",
    "show_labels_threshold": 5,
    "show_legend": True,
    "legendType": "scroll",
    "legendOrientation": "top",
    "label_type": "key_value_percent",
    "number_format": "SMART_NUMBER",
    "date_format": "smart_date",
    "show_labels": True,
    "labels_outside": True,
    # Was the string "false", which is truthy and would turn label lines on.
    "label_line": False,
    "show_total": True,
    "outerRadius": 70,
    "donut": True,
    "innerRadius": 30,
}
pie_chart_id = client.create_chart(CLOC_DATASET_ID, "pie_test", viz_type="pie", params=pie_params)
pie_chart_id

In [None]:
# Inspect the stored form data (`params`) of an existing chart, e.g. to reuse
# the settings of a chart built in the Superset UI in the cells above.
INSPECT_CHART_ID = 9  # NOTE(review): hard-coded chart id — set to the chart you want to inspect

chart_res = client.get(f"api/v1/chart/{INSPECT_CHART_ID}")
if chart_res is None:
    raise RuntimeError(f"Failed to fetch chart {INSPECT_CHART_ID}")
# `params` is stored as a JSON string (create_chart sends json.dumps(params)).
# Parse it so the notebook renders a structured dict.
raw_chart_params = chart_res.json()["result"]["params"]
chart_params = json.loads(raw_chart_params) if raw_chart_params else {}
chart_params