In [108]:

import argparse
import datetime
import json
import logging
import os
import re
from typing import Any, List

import numpy as np
import pandas as pd
import yaml

from llm4explore.model import non_trainable_gen, pretrain_map
from llm4explore.model.base import IdeaGenerator, IdeaMapper
from scipy.spatial.distance import euclidean

In [109]:
data_csv = "data/massw_llm4explore.tsv"
data_npz = "data/ada2_key_ideas.npz"
target_col = "key_idea"
time_col = "year"
time_split = 2023

data = pd.read_csv(
            data_csv,
            sep="\t",
            usecols=[
                target_col,
                time_col ,
            ],
        )

data = data.dropna(subset=[target_col])
targets = data[target_col]
times = data[time_col]
times_old = times[times < time_split]


npz = np.load(data_npz)
high_dim_embeddings = npz["high_dim_embeddings"]
low_dim_embeddings = npz["low_dim_embeddings"]
n_dims = low_dim_embeddings.shape[1]

targets_old = targets[times < time_split].tolist()
targets_new = targets[times >= time_split].tolist()
low_dim_embeddings_new = low_dim_embeddings[times >= time_split]
low_dim_embeddings_old = low_dim_embeddings[times < time_split]

In [110]:
print(len(targets_old))
print(len(low_dim_embeddings_new))
print(len(low_dim_embeddings_old))

144776
4635
144776


In [111]:
times_old = times[times < time_split]
print(len(times_old))
print(len(times))

144776
149416


In [112]:
from llm4explore.model.common import KNNSampler
sampler = KNNSampler(low_dim_embeddings_old)
neighbor = {}
for i, low_dim_embedding in enumerate(low_dim_embeddings_new):
    indices, dists = sampler.sample(low_dim_embedding)
    neighbor[i] = indices

Loading KNN index from cache.
KNN index initialized.


In [113]:
from tqdm import tqdm
index_to_year_old = {idx: year for idx, year in enumerate(times[times < time_split])}

nearest_neighbors = {}
rst = {}
for i, indices in tqdm(neighbor.items(), desc="Processing neighbors"):
    rsts = []
    for idx in indices:
        year = index_to_year_old[idx]
        mask = [j for j, y in enumerate(times_old) if y <= year]
        distances = []
        for j, y in enumerate(times_old):
            if y <= year:
                distance = euclidean(low_dim_embeddings_old[idx], low_dim_embeddings_old[j])
            else:
                distance = np.inf
            distances.append((j, distance)) 
        valid_indices = [dist[0] for dist in distances if 0 < dist[1] < 0.1]
        if len(valid_indices) < 2:
            valid_indices = [dist[0] for dist in sorted(distances, key=lambda x: x[1])[1:4]]
        nearest_neighbors[idx] = valid_indices
        rsts.append(valid_indices)
    rst[i] = rsts
print(rst)
print(nearest_neighbors)



Processing neighbors: 100%|██████████| 4635/4635 [4:15:29<00:00,  3.31s/it]  

{0: [[23543, 25176, 26079, 28393, 127846, 138993], [23543, 25176, 99270], [23543, 138993], [25176, 127846, 138993], [938, 99270, 138993, 144065], [16448, 23543, 25176, 127846, 138722, 139992, 141188], [131262, 131728]], 1: [[92985, 105631, 107315, 118109, 129414], [104744, 55710, 64631], [8529, 23388, 92985, 107315, 116159, 129414, 129606, 131467], [68523, 92985, 116159, 118109, 119363, 129414], [4243, 8529, 68523, 92985, 116159, 118109, 119363, 129414], [58773, 60371, 92985], [68523, 78485, 92985, 112180], [12458, 20479, 58773, 60371, 92002, 105631, 118937, 119523, 124213], [91388, 92985], [92985, 107315, 115719, 116159, 118109, 119363], [92985, 55710, 64631], [583, 91388, 92985, 107315, 115719, 116159, 129414, 129606, 131467]], 2: [[55382, 55613], [12174, 55382, 55581, 55613, 60470, 112963, 128348, 137083], [102402, 112963], [28739, 53698, 55613, 112963, 128348, 137083, 139413], [26650, 102402, 112963, 137487], [69108, 91259, 55587], [16202, 119358, 137487], [12174, 28739, 55613, 112




In [114]:
import pickle
with open('predata/pre_compute.pkl', 'wb') as f:
    pickle.dump(rst, f)

# 保存 nearest_neighbors 到文件
with open('predata/nearest_neighbors.pkl', 'wb') as f:
    pickle.dump(nearest_neighbors, f)

In [115]:
import json
import numpy as np

def convert_keys_to_str(data):
    if isinstance(data, dict):
        return {str(key): convert_keys_to_str(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_keys_to_str(item) for item in data]
    else:
        return data

# Assuming rst and nearest_neighbors are your dictionaries

# Save rst to a JSONL file
with open('predata/pre_compute.jsonl', 'w') as f:
    for key, value in rst.items():
        converted_data = convert_keys_to_str({key: value})
        json_line = json.dumps(converted_data)
        f.write(json_line + '\n')

# Save nearest_neighbors to a JSONL file
with open('predata/nearest_neighbors.jsonl', 'w') as f:
    for key, value in nearest_neighbors.items():
        converted_data = convert_keys_to_str({key: value})
        json_line = json.dumps(converted_data)
        f.write(json_line + '\n')
