参考: https://www.kaggle.com/code/raki21/simple-or-chain-sqlite-magic-free


In [1]:
import polars as pl
import rootutils

# Locate the project root starting from "." via rootutils; the returned path
# supports `/` joining (used below to build DATA_DIR). Per the rootutils docs,
# cwd=True also makes that root the current working directory — confirm
# against the installed rootutils version.
ROOT_DIR = rootutils.setup_root(".", cwd=True)

In [2]:
# How many patents to draw (without replacement) from the metadata table.
SUBSET_SIZE: int = 100
# Prefix for neighbour entries; presumably the column-name prefix used in
# nearest_neighbors.csv (e.g. "neighbor_0") — TODO confirm where it is used.
NEIGHBOR_NAME: str = "neighbor_"
# Fixed seed so the sampled subset is reproducible.
SEED: int = 42

In [3]:
# Directory layout below the project root: <root>/data/inputs/<competition>.
DATA_DIR = ROOT_DIR.joinpath("data")
INPUT_DIR = DATA_DIR.joinpath("inputs")
COMPETITION_DIR = INPUT_DIR.joinpath("uspto-explainable-ai")

# Database URL of the form "sqlite:///" + POSIX path to uspto.db. When the path
# is absolute this gives the four-slash form sqlite:////path/to/database.db
# (presumably consumed by an SQLAlchemy-style URL parser — confirm).
SQLITE_PATH = "sqlite:///{}".format(
    COMPETITION_DIR.joinpath("uspto.db").as_posix()
)

In [4]:
# Competition test split, read eagerly into a polars DataFrame.
test_df = pl.read_csv(COMPETITION_DIR.joinpath("test.csv"))

# Patent metadata published on or after 1975-01-01, reduced to the publication
# number/date plus derived calendar `year` and `month` columns. The query is
# built on a lazy parquet scan and materialised by the final collect().
patent_metadata_df = (
    pl.scan_parquet(COMPETITION_DIR.joinpath("patent_metadata.parquet"))
    .filter(pl.col("publication_date").ge(pl.date(1975, 1, 1)))
    .select("publication_number", "publication_date")
    .with_columns(
        year=pl.col("publication_date").dt.year(),
        month=pl.col("publication_date").dt.month(),
    )
    .collect()
)
# Publication numbers of a reproducible random subset: SUBSET_SIZE rows drawn
# without replacement using the fixed SEED, de-duplicated, as a plain list.
# maintain_order=True keeps the values in sampled order. polars' default
# `unique` does not guarantee output order, so without it the list order could
# differ between runs even with the same seed.
train_samples = (
    patent_metadata_df.sample(n=SUBSET_SIZE, with_replacement=False, seed=SEED)
    .get_column("publication_number")
    .unique(maintain_order=True)
    .to_list()
)

# Nearest-neighbour table restricted to the sampled publication numbers.
# Note: despite the `_df` suffix this is a LazyFrame (scan_csv, no collect()
# here), so callers must collect() it themselves.
nearest_neighbors_df = (
    pl.scan_csv(COMPETITION_DIR.joinpath("nearest_neighbors.csv"))
    .filter(pl.col("publication_number").is_in(train_samples))
)