In [1]:
import os
import sys

# Make the project sources and the Penguin-ML-Library directory beneath them
# importable; both entries go to the end of sys.path, in this order.
PACKAGE_DIR = "/kaggle/src"
sys.path.extend([PACKAGE_DIR, os.path.join(PACKAGE_DIR, "Penguin-ML-Library")])

In [2]:
import os
import random
import warnings

import numpy as np
import yaml
from penguinml.utils.logger import get_logger, init_logger
from penguinml.utils.set_seed import seed_base

# Suppress every warning category for the remainder of the session.
warnings.filterwarnings("ignore")

# One seed shared by the hash-seed environment variable, Python's `random`
# module and NumPy's legacy global RNG.
seed = 46
os.environ["PYTHONHASHSEED"] = f"{seed}"
random.seed(seed)
np.random.seed(seed)

2024-06-24 02:06:10.142294: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-24 02:06:10.211195: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-24 02:06:10.725689: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda/lib:/usr/local/lib/x86_64-linux-gnu:/usr/local/nvidia/lib:/u

In [3]:
import bz2
import gc
import hashlib
import json
import math
import multiprocessing
import random
import time
from glob import glob
from typing import List, Set

import lz4.frame
import matplotlib.pyplot as plt
import numpy as np
import plyvel
import polars as pl
from penguinml.utils.timer import Timer
from tqdm import tqdm

import whoosh_utils
from const import ALL_KEYS, CPC2TOKENS_PATH, INF, KEY2QUERY, NUM_CPU
from solver import SimulatedAnnealing
from utils import (
    calc_bytes,
    compute_ap,
    evaluate,
    load_list_bz2,
    read_bytes_in_range,
    save_list_bz2,
)


def stable_hash(obj, mod: int = 20):
    """Return a deterministic bucket number for ``obj``.

    The object is serialised with ``json.dumps(..., sort_keys=True)``, hashed
    with SHA-256, and the digest (read as a big-endian integer) is reduced
    modulo ``mod``.

    NOTE: ``obj`` must be the object *before* ``json.dumps`` is applied to it,
    not an already-serialised JSON string.
    """
    canonical = json.dumps(obj, sort_keys=True).encode()
    # The hex digest parsed in base 16 is the same integer as the raw digest
    # bytes interpreted big-endian.
    digest_value = int(hashlib.sha256(canonical).hexdigest(), 16)
    return digest_value % mod

Processing /kaggle/input/whoosh-wheel-2-7-4/Whoosh-2.7.4-py2.py3-none-any.whl
Whoosh is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.


[0m

In [4]:
# Number of splits; each split gets its own `complete-db-{i}` output directory.
N_SPLIT = 15

# Remove any existing output directory for each split and create it again,
# so the directories are empty before anything is written into them.
for i in range(N_SPLIT):
    !rm -rf complete-db-{i}
    !mkdir complete-db-{i}

In [5]:
# Collect the path of every .bz2 input file; the trailing expression makes the
# cell display how many were found.
files = glob("/kaggle/input/preprocess-complete/patent-data/*.bz2")
len(files)

12436350

In [6]:
# Partition the input paths into N_SPLIT buckets. The bucket is chosen by
# hashing the file's base name up to its first ".", which is the same key
# process_files() re-derives and asserts on.
batch_files = [[] for _ in range(N_SPLIT)]
for path in tqdm(files):
    stem = os.path.basename(path).split(".")[0]
    batch_files[stable_hash(stem, mod=N_SPLIT)].append(path)


def process_files(args):
    """Build the blob file and offset index for one split.

    Each input file is bz2-decompressed and parsed as one JSON document per
    line. The resulting list is re-serialised as a single JSON string,
    compressed as its own LZ4 frame and appended to
    ``complete-db-{split}/index.lz4``. The ``(start, end)`` byte offsets of
    that frame are stored in the plyvel DB at ``complete-db-{split}/db`` under
    the file's base name (the part before the first ``"."``).

    Args:
        args: A ``(files, split)`` tuple, packed into one argument so the
            function can be used with ``Pool.map``. ``files`` is the list of
            input paths belonging to this split; ``split`` is the split number.

    Returns:
        dict: base name -> JSON text for every 10000th processed file, kept
        so the written data can be spot-checked afterwards.

    Raises:
        AssertionError: If a file's base name does not hash to ``split``.
    """
    files, split = args

    debug_center2txt = {}
    cursor = 0  # byte offset in index.lz4 at which the next frame starts

    db = plyvel.DB(f"complete-db-{split}/db", create_if_missing=True)
    try:
        # `with` closes (and thereby flushes) the blob file even when a file
        # fails to parse; the original only flushed and never closed it.
        with open(f"complete-db-{split}/index.lz4", "wb") as out_txt:
            progress = tqdm(files, desc=f"split={split:02d}")
            for debug_count, file in enumerate(progress, start=1):
                center = os.path.basename(file).split(".")[0]
                assert stable_hash(center, mod=N_SPLIT) == split

                with open(file, "rb") as f:
                    raw = f.read()
                lines = bz2.decompress(raw).decode("utf-8").split("\n")
                # Skip blank lines instead of unconditionally dropping the
                # last element, so a file without a trailing newline does not
                # lose its final record.
                records = [json.loads(line) for line in lines if line.strip()]

                # list of records -> one JSON string
                txt = json.dumps(records)

                compressed_chunk = lz4.frame.compress(txt.encode("utf-8"))
                start = cursor
                out_txt.write(compressed_chunk)
                cursor += len(compressed_chunk)

                db.put(center.encode(), json.dumps((start, cursor)).encode())

                if debug_count % 10000 == 0:
                    debug_center2txt[center] = txt
    finally:
        # Always release the DB, including on error.
        db.close()

    return debug_center2txt


# One worker process per split; each job is the (files, split) tuple that
# process_files() unpacks.
split_jobs = [(split_files, split_id) for split_id, split_files in enumerate(batch_files)]
with multiprocessing.Pool(N_SPLIT) as pool:
    rets = pool.map(process_files, split_jobs)

# Merge the per-split debug samples into a single dict.
debug_center2txt = {}
for sampled in rets:
    debug_center2txt.update(sampled)

  0%|          | 0/12436350 [00:00<?, ?it/s]

100%|██████████| 12436350/12436350 [00:29<00:00, 426811.94it/s]
split=10: 100%|██████████| 828930/828930 [3:23:03<00:00, 68.04it/s]]   
split=11: 100%|██████████| 827276/827276 [3:23:03<00:00, 67.90it/s] 
split=01: 100%|██████████| 829326/829326 [3:23:13<00:00, 68.01it/s]]
split=05: 100%|██████████| 829089/829089 [3:23:16<00:00, 67.98it/s] 
split=06: 100%|██████████| 827492/827492 [3:23:17<00:00, 67.84it/s]]
split=12: 100%|██████████| 828711/828711 [3:23:15<00:00, 67.95it/s]]
split=08: 100%|██████████| 828032/828032 [3:23:26<00:00, 67.84it/s] 
split=04: 100%|██████████| 828184/828184 [3:23:35<00:00, 67.80it/s] 
split=00: 100%|██████████| 829952/829952 [3:23:38<00:00, 67.92it/s]]
split=09: 100%|██████████| 830275/830275 [3:23:42<00:00, 67.93it/s] 
split=14: 100%|██████████| 830443/830443 [3:23:41<00:00, 67.95it/s]]
split=02: 100%|██████████| 828853/828853 [3:23:47<00:00, 67.79it/s] 
split=07: 100%|██████████| 828646/828646 [3:23:48<00:00, 67.76it/s] 
split=13: 100%|██████████| 830201/83

In [7]:
def read_lz4_in_range(file_path, start_byte, end_byte):
    """Decode one LZ4 frame stored in a byte range of a file.

    Reads ``end_byte - start_byte`` bytes of ``file_path`` starting at offset
    ``start_byte``, decompresses them with ``lz4.frame`` and returns the
    result decoded as UTF-8 text.
    """
    length = end_byte - start_byte
    with open(file_path, "rb") as fh:
        fh.seek(start_byte)
        frame = fh.read(length)
    return lz4.frame.decompress(frame).decode("utf-8")

In [8]:
# Reopen every split's offset DB and check that each sampled document
# round-trips: look up its byte range, read that range back from the split's
# index.lz4 and compare it with the text captured while writing.
# NOTE(review): the handles in `dbs` are left open after this cell — confirm
# nothing later needs them closed first.
dbs = [plyvel.DB(f"complete-db-{idx}/db") for idx in range(N_SPLIT)]

for name, expected_txt in tqdm(debug_center2txt.items()):
    bucket = stable_hash(name, mod=N_SPLIT)
    first, last = json.loads(dbs[bucket].get(name.encode()))
    read_txt = read_lz4_in_range(f"complete-db-{bucket}/index.lz4", first, last)
    assert expected_txt == read_txt
# Display the last verified document as parsed JSON.
json.loads(read_txt)

100%|██████████| 1234/1234 [00:02<00:00, 578.09it/s]


[['H01B17/16',
  'detd:contracting',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-4186902-A', 'US-66215-A', 'US-10062481-B2']],
 ['H01B17/16',
  'detd:yielding',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-701246-A', 'US-7180003-B2', 'US-1652835-A']],
 ['H01B17/16',
  'detd:difficulties',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-1702237-A', 'US-643327-A', 'US-3899630-A', 'US-1917322-A']],
 ['H01B17/16',
  'detd:ments',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-2764626-A', 'US-2144537-A']],
 ['H01B17/16',
  'detd:fracturing',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-5796048-A']],
 ['H01B17/16',
  'detd:bind',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-3026368-A', 'US-929132-A', 'US-1583515-A', 'US-960827-A']],
 ['H01B17/16',
  'detd:useless',
  ['US-1663007-A', 'US-1239902-A'],
  ['US-11189394-B2']],
 ['H01B17/16',
  'detd:thru',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-1702237-A', 'US-1685833-A']],
 ['H01B17/16',
  'detd:bead',
  ['US-1663007-A', 'US-1643943-A'],
  ['US-3899630-A']

In [9]:
for i in range(20):
    !zip -r complete-db-{i}/db.zip complete-db-{i}/db > /dev/null