In [1]:
import os

from utils import EXP_ROOT, load_jsonl_iteratively

# Root of the evaluation-QA experiment data; its sub-directories hold the
# per-prompt response files that the next cell scans.
data_dir = os.path.join(
    EXP_ROOT,
    'datasets/kg-datasets/ja-0.5/eval_qa',
)

In [2]:
# Record (or reload) the document IDs used for evaluation.
# These documents are the reference set against which token overlap is measured later.
from utils import dump_jsonl, load_jsonl

eval_ids_fn = os.path.join(data_dir, "docids.jsonl")
EXPECTED_N_EVAL_DOCS = 10000

if os.path.exists(eval_ids_fn) and len(load_jsonl(eval_ids_fn)) == EXPECTED_N_EVAL_DOCS:
    print("Already generated docids.jsonl")
else:
    # prompt-variant directory name -> set of docids found in its responses
    base_docids = {}
    for path in os.listdir(data_dir):
        if os.path.isfile(os.path.join(data_dir, path)):
            continue

        file_path = os.path.join(data_dir, path, "response.s0e10000.jsonl")
        base_docids[path] = set()
        if os.path.exists(file_path):
            for item in load_jsonl_iteratively(file_path):
                # Drop the first 9 characters of the response id to recover the docid.
                # NOTE(review): assumes every id carries the same 9-char prefix — confirm.
                base_docids[path].add(item['id'][9:])
        else:
            print(f"File not found: {file_path}")

    # All prompt variants must have been evaluated on exactly the same documents.
    assert base_docids["trisent-prompt"] == base_docids["monosent-prompt"]
    assert base_docids["trisent-prompt"] == base_docids["bisent-prompt"]
    assert len(base_docids["trisent-prompt"]) == EXPECTED_N_EVAL_DOCS

    # Persist the shared docid set. Previously it was never written, so the load
    # below failed on a fresh run. Sorting makes the file order (and every
    # downstream iteration order) deterministic despite str-hash randomization.
    # Assumes dump_jsonl writes one JSON value per line, so load_jsonl returns a
    # list of docid strings — TODO confirm against utils.
    dump_jsonl(sorted(base_docids["trisent-prompt"]), eval_ids_fn)

base_docids = load_jsonl(eval_ids_fn)

Already generated docids.jsonl


In [3]:
# Tokenize every document abstract with the llm-jp-3 tokenizer and cache the
# result in tokens.jsonl as one {docid: token_ids} object per line.
from transformers import AutoTokenizer
from tqdm import tqdm
from utils import DATA_ROOT

token_fn = os.path.join(DATA_ROOT, "datasets/medical/ja/tokens.jsonl")
raw_fn = os.path.join(DATA_ROOT, "datasets/medical/ja/data.jsonl")

id2tokens = {}
if os.path.exists(token_fn):
    print("Already generated tokens.jsonl")
    for item in tqdm(load_jsonl_iteratively(token_fn)):
        id2tokens.update(item)
else:
    tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-7.2b")
    for item in tqdm(load_jsonl_iteratively(raw_fn)):
        # A single string needs no padding and no tensor round-trip:
        # "input_ids" is already a flat list of ints.
        id2tokens[item["docid"]] = tokenizer(item["abstract"], truncation=True)["input_ids"]

    # The output path was missing here before, so the cache was never written.
    dump_jsonl([{docid: tokens} for docid, tokens in id2tokens.items()], token_fn)

  from .autonotebook import tqdm as notebook_tqdm


Already generated tokens.jsonl


614444it [00:29, 21150.40it/s]


In [4]:
# Measure, for every non-evaluation document, the fraction of its distinct tokens
# that also occur somewhere in the evaluation (base) documents.
# docid2overlap: docid -> (overlap_ratio, number_of_distinct_tokens)
token_overlap_fn = './caches/token_overlap.jsonl'
if os.path.exists(token_overlap_fn):
    print("Already generated token_overlap.jsonl")
    docid2overlap = {}
    for item in tqdm(load_jsonl_iteratively(token_overlap_fn)):
        docid2overlap.update(item)
else:
    # Set copy for O(1) membership; base_docids as loaded may be a list.
    base_docid_set = set(base_docids)

    # Vocabulary (distinct token ids) of the evaluation documents. Reading it from
    # id2tokens gives the same union as re-scanning raw_fn, without the extra pass.
    base_tokens = set()
    for docid in base_docid_set:
        if docid in id2tokens:
            base_tokens.update(id2tokens[docid])

    docid2overlap = {}
    for docid in tqdm(id2tokens):
        if docid in base_docid_set:
            continue
        tokens = set(id2tokens[docid])
        # Guard empty token lists (would otherwise raise ZeroDivisionError).
        overlap = len(tokens & base_tokens) / len(tokens) if tokens else 0.0
        docid2overlap[docid] = (overlap, len(tokens))

    os.makedirs(os.path.dirname(token_overlap_fn), exist_ok=True)
    dump_jsonl(
        [{docid: (overlap, n_tokens)} for docid, (overlap, n_tokens) in docid2overlap.items()],
        token_overlap_fn,
    )

Already generated token_overlap.jsonl


604444it [00:05, 106419.64it/s]


In [5]:
from collections import defaultdict
import json
import pickle

# Index medical_native.jsonl by docid: map each docid to the positions of its
# instruction lines, so later cells can seek straight to them instead of
# re-reading the whole file.
instructions_fn = os.path.join(DATA_ROOT, "instructions/ja/medical_native.jsonl")
assert os.path.exists(instructions_fn), f"File not found: {instructions_fn}"

pos_pkl = "./caches/id2instructions.pkl"

if os.path.exists(pos_pkl):
    print("Already generated id2instructions.pkl")
    # Self-produced local cache; never point this at an untrusted pickle.
    with open(pos_pkl, "rb") as pkl_f:
        id2instructions = pickle.load(pkl_f)
else:
    id2instructions = defaultdict(list)

    with open(instructions_fn, "r", encoding="utf8") as f, \
         tqdm(desc="Indexing", unit=" lines") as pbar:
        while True:
            # Text-mode tell() returns an opaque cookie that is only valid for
            # seek() on a text-mode handle of the same file. readline() is used
            # instead of `for line in f` because iteration disables tell().
            pos = f.tell()
            line = f.readline()
            if not line:
                break
            entry = json.loads(line)
            id2instructions[entry["docid"]].append(pos)
            pbar.update(1)

    # Close (and thus flush) the pickle explicitly instead of relying on GC.
    os.makedirs(os.path.dirname(pos_pkl), exist_ok=True)
    with open(pos_pkl, "wb") as pkl_f:
        pickle.dump(id2instructions, pkl_f)

Already generated id2instructions.pkl


In [6]:
# Seed both output files with the instructions of the evaluation (base) documents.
# Mode "w" truncates, so this cell must run before any cell that appends to these files.
high_overlap_fn = os.path.join(DATA_ROOT, "instructions/ja/medical_native_high_overlap.jsonl")
low_overlap_fn = os.path.join(DATA_ROOT, "instructions/ja/medical_native_low_overlap.jsonl")

# All handles live in one `with`, so the input file is closed even if an assertion fails.
with open(instructions_fn, "r", encoding="utf8") as in_f, \
     open(high_overlap_fn, "w", encoding="utf8") as out_f1, \
     open(low_overlap_fn, "w", encoding="utf8") as out_f2:
    for docid in tqdm(base_docids):
        assert docid in id2tokens, f"Missing tokens for {docid}"
        # .get() avoids inserting an empty entry into the defaultdict on lookup.
        positions = id2instructions.get(docid, [])
        assert len(positions) != 0, f"Missing instructions for {docid}"
        for pos in positions:
            in_f.seek(pos)
            line = in_f.readline()
            out_f1.write(line)
            out_f2.write(line)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:09<00:00, 1077.63it/s]


In [9]:
# Rank every non-evaluation document by token overlap, highest first, then keep
# only documents with at least 50 distinct tokens as (docid, overlap, n_tokens).
docs = sorted(
    docid2overlap.items(),
    key=lambda entry: entry[1][0],
    reverse=True,
)
sorted_docs = [
    (docid, overlap_ratio, n_tokens)
    for docid, (overlap_ratio, n_tokens) in docs
    if n_tokens >= 50
]

In [14]:
# Ten lowest-overlap documents, lowest first (sorted_docs is in descending order).
sorted_docs[-10:][::-1]

[('lca@@16/3/16_188', 0.5513513513513514, 185),
 ('jasmin@@2009f/0/2009f_0_34', 0.5909090909090909, 88),
 ('jceeek@@2011/0/2011_105', 0.6, 65),
 ('jceeek@@2016/0/2016_268', 0.6, 80),
 ('jsmbe@@Annual56/Proc/Annual56_1', 0.6, 60),
 ('jfsc@@130/0/130_794', 0.6043956043956044, 91),
 ('jceeek@@2010/0/2010_0_83', 0.6119402985074627, 67),
 ('jfsc@@126/0/126_789', 0.6134453781512605, 119),
 ('toxp@@33/0/33_0_54', 0.6136363636363636, 176),
 ('pjsai@@JSAI2018/0/JSAI2018_2K104', 0.6190476190476191, 63)]

In [None]:
# Append the instructions of all ranked (non-evaluation) documents to the
# high-overlap file, in descending overlap order.
# NOTE(review): mode "a" assumes the file was just re-seeded with mode "w";
# re-running this cell on its own appends a second copy of every line.
# NOTE(review): the low-overlap file receives the same ranked documents in
# reverse order; presumably consumers read only a prefix of each file — confirm.
# Both handles live in one `with`, so an interrupt during this long loop does not leak them.
with open(instructions_fn, "r", encoding="utf8") as in_f, \
     open(high_overlap_fn, "a", encoding="utf8") as out_f:
    for docid, _overlap, _n_tokens in tqdm(sorted_docs):
        for pos in id2instructions[docid]:
            in_f.seek(pos)
            out_f.write(in_f.readline())

100%|██████████| 587673/587673 [31:19<00:00, 312.70it/s] 


In [15]:
# Append the instructions of all ranked (non-evaluation) documents to the
# low-overlap file, in ascending overlap order.
# NOTE(review): mode "a" assumes the file was just re-seeded with mode "w";
# re-running this cell on its own appends a second copy of every line.
# Both handles live in one `with`, so an interrupt during this long loop does not leak them.
with open(instructions_fn, "r", encoding="utf8") as in_f, \
     open(low_overlap_fn, "a", encoding="utf8") as out_f:
    for docid, _overlap, _n_tokens in tqdm(sorted_docs[::-1]):
        for pos in id2instructions[docid]:
            in_f.seek(pos)
            out_f.write(in_f.readline())

100%|██████████| 587673/587673 [17:03<00:00, 574.44it/s]  
