In [2]:
from poi.dataset.llm import load_prompt_completion_llm_dataset
from poi import settings

ds_dir = settings.DATASETS_DIR / "NYC" / "LLM Dataset" / "rqvae-nyc-div0.25-commit0.5-lr5e-5"

train_ds_path = ds_dir / "train_codebook.json"
val_ds_path = ds_dir / "val_codebook.json"
test_ds_path = ds_dir / "test_codebook.json"

train_ds = load_prompt_completion_llm_dataset(train_ds_path)
test_ds = load_prompt_completion_llm_dataset(test_ds_path)
val_ds = load_prompt_completion_llm_dataset(val_ds_path)



In [10]:
import re

SID_PATTERN = r"(?:<\w_\d+>)+"
USER_PATTERN = r"User_\d+"

def extract_user_sids(record):
    prompt = record["prompt"]
    completion = record["completion"]

    user = re.findall(USER_PATTERN, prompt)[0]
    prompt_sids = re.findall(SID_PATTERN, prompt)
    completion_sids = re.findall(SID_PATTERN, completion)
    
    sids = prompt_sids + completion_sids
    return user, sids


# Use non-capturing group (?:...) so findall returns the full match

prompt = train_ds[0]["prompt"]
completion = train_ds[0]["completion"]

print(prompt)

user, sids = extract_user_sids(train_ds[0])
print(user)
print(sids)


Here is a record of a user's POI accesses, your task is based on the history to predict the POI that the user is likely to access at the specified time.
User_1 visited: <a_11><b_27><c_21> at 2012-06-04 22:51:07, <a_0><b_20><c_12> at 2012-06-05 07:01:39, <a_0><b_7><c_27> at 2012-06-05 07:29:47, <a_0><b_20><c_12> at 2012-06-06 07:24:27, <a_10><b_19><c_1> at 2012-06-06 09:47:16, <a_0><b_7><c_27> at 2012-06-06 11:10:58, <a_11><b_27><c_21> at 2012-06-06 22:48:20, <a_0><b_7><c_27> at 2012-06-07 14:08:17, <a_11><b_27><c_21> at 2012-06-08 22:51:45, <a_0><b_7><c_27> at 2012-06-09 07:22:27, <a_0><b_7><c_27> at 2012-06-10 07:27:59, <a_11><b_27><c_21> at 2012-06-10 22:58:51, <a_0><b_7><c_27> at 2012-06-11 05:39:55, <a_11><b_27><c_21> at 2012-06-11 22:59:39, <a_11><b_27><c_21> at 2012-06-12 22:55:16, <a_0><b_20><c_12> at 2012-06-13 07:05:41, <a_0><b_7><c_27> at 2012-06-13 07:28:57, <a_11><b_27><c_21> at 2012-06-18 00:42:18, <a_0><b_20><c_12> at 2012-06-25 06:08:56, <a_11><b_27><c_21> at 2012-06-26 

In [12]:
def find_all_user_sids(ds):
    users = set()
    sids = set()
    for record in ds:
        user, sids_in_record = extract_user_sids(record)
        users.add(user)
        sids.update(sids_in_record)
    return users, sids

train_users, train_sids = find_all_user_sids(train_ds)
val_users, val_sids = find_all_user_sids(val_ds)
test_users, test_sids = find_all_user_sids(test_ds)



In [16]:
print("Users in test but not in train:", test_users - train_users)
print()
print("Users in val but not in train:", val_users - train_users)
print()

print("SIDs in test but not in train:", test_sids - train_sids)
print()
print("SIDs in val but not in train:", val_sids - train_sids)
print()






Users in test but not in train: {'User_597', 'User_796', 'User_941', 'User_455', 'User_112', 'User_871', 'User_515'}

Users in val but not in train: {'User_597', 'User_796', 'User_941', 'User_455', 'User_112', 'User_871', 'User_515'}

SIDs in test but not in train: {'<a_2><b_5><c_23>', '<a_16><b_28><c_27>', '<a_10><b_28><c_19>', '<a_2><b_31><c_29>', '<a_2><b_2><c_18>', '<a_2><b_31><c_1>', '<a_10><b_21><c_12>', '<a_10><b_16><c_6>', '<a_27><b_16><c_21>', '<a_29><b_8><c_14>', '<a_29><b_9><c_7>', '<a_10><b_14><c_3>', '<a_25><b_21><c_10>', '<a_10><b_10><c_5>', '<a_2><b_27><c_11>', '<a_2><b_3><c_23>', '<a_29><b_21><c_27>', '<a_10><b_31><c_31>', '<a_2><b_12><c_31>', '<a_2><b_26><c_9>'}

SIDs in val but not in train: {'<a_2><b_5><c_23>', '<a_16><b_28><c_27>', '<a_10><b_28><c_19>', '<a_2><b_31><c_29>', '<a_2><b_2><c_18>', '<a_2><b_31><c_1>', '<a_10><b_21><c_12>', '<a_10><b_16><c_6>', '<a_27><b_16><c_21>', '<a_29><b_8><c_14>', '<a_29><b_9><c_7>', '<a_10><b_14><c_3>', '<a_25><b_21><c_10>', '<a_10