# TRAIN #

In [None]:
import os
import json
import torch
import datasets
# datasets.disable_caching()
import pandas as pd
from collections import defaultdict
from flair.models import SequenceTagger
from flair.data import Sentence
from tqdm import tqdm
import numpy as np
import pickle
from sklearn.metrics.pairwise import cosine_distances
from pathlib import Path


def mask_text(sentences: list, verbalize: bool = False):
    """Rewrite each sentence with its tagged entities made explicit.

    Args:
        sentences: objects exposing ``to_original_text()`` and ``to_dict()``;
            ``to_dict()['entities']`` must list entities carrying
            ``start_pos``, ``end_pos`` and ``labels[0]['value']``.
        verbalize: when True an entity is rendered as ``[<span> is <type>]``;
            otherwise it is wrapped as ``<TYPE> <span> </TYPE>``.

    Returns:
        One rewritten string per input sentence, in input order.
    """
    # Phrases used in verbalized mode. ORG and DAT are intentionally absent
    # (they were disabled in the original implementation).
    verbalized_labels = {
        "LOC": "is location",
        "TIM": "is time",
        "PER": "is person",
        "EVT": "is event",
    }

    masked_texts = []
    for sentence in sentences:
        orig_text = sentence.to_original_text()
        pieces = []
        prev_e_idx = 0
        for ent in sentence.to_dict()['entities']:
            s_idx, e_idx = ent['start_pos'], ent['end_pos']
            cur_entity = ent['labels'][0]['value']
            surface = orig_text[s_idx:e_idx]

            # Text between the previous entity and this one is copied verbatim.
            pieces.append(orig_text[prev_e_idx:s_idx])
            if not verbalize:
                pieces.append(f"<{cur_entity}> {surface} </{cur_entity}>")
            elif cur_entity in verbalized_labels:
                pieces.append(f"[{surface} {verbalized_labels[cur_entity]}]")
            else:
                # A label without a verbalization used to be rendered as
                # "[<span> None]"; keep the span untouched instead of
                # injecting the literal word "None" into the text.
                pieces.append(surface)
            prev_e_idx = e_idx

        # Tail of the sentence after the last entity.
        pieces.append(orig_text[prev_e_idx:])
        masked_texts.append("".join(pieces))
    return masked_texts

def read_id_text_from_file(file_name, is_json = False, has_tp = False):
    """Read parallel lists of ids and texts from a JSON-lines or TSV file.

    Args:
        file_name: path of the file to read.
        is_json: when True the file is loaded through ``datasets`` and must
            provide ``text_id`` and ``text`` columns; otherwise it is read as
            ``id<TAB>text`` lines.
        has_tp: JSON only; also return the ``is_tp`` column.

    Returns:
        ``(ids, texts, tps)``; ``tps`` is empty unless ``is_json`` and
        ``has_tp`` are both set.

    Raises:
        ValueError: a non-empty TSV line contains no tab.
    """
    ids_, texts_, tps_ = [], [], []
    if is_json:
        data = datasets.load_dataset(
            'json',
            data_files=file_name,
        )['train']
        ids_, texts_ = data['text_id'], data['text']
        if has_tp:
            tps_ = data['is_tp']
    else:  # for .tsv
        with open(file_name, "r", encoding="utf-8") as f:
            for line in f:
                # Drop the line terminator so it does not end up in the text.
                line = line.rstrip("\n")
                if not line:
                    continue  # tolerate blank / trailing empty lines
                # Split on the first tab only: the text itself may contain tabs.
                id_, text_ = line.split('\t', 1)
                ids_.append(id_)
                texts_.append(text_)
    return ids_, texts_, tps_

def passage_dict_to_text(psg_dict: dict, simplify: bool = False):
    """Flatten a passage dict into a single ``' | '``-separated string.

    Field handling:
      * ``ct`` is always skipped.
      * ``date`` is skipped when ``simplify`` is set; otherwise the constant
        placeholder ``"date: date"`` is emitted (the actual value is not used).
      * ``attendees`` is joined with ``' , '``.
      * every other field is emitted as its value (``simplify``) or as
        ``"key: value"``.
    """
    rendered = []
    for field, value in psg_dict.items():
        if field == 'ct':
            continue
        if field == 'date':
            if not simplify:
                rendered.append(f"{field}: date")
            continue
        if field == 'attendees':
            value = ' , '.join(value)
        rendered.append(f"{value}" if simplify else f"{field}: {value}")
    return ' | '.join(rendered)

def parse_json(json_):
    """Convert raw query/passage records into column-oriented lists.

    Records whose passage or query has no ``ct`` key are dropped. For every
    kept record the returned ``defaultdict(list)`` gains one entry in each of:
    ``docid`` (running index over kept records), ``text`` / ``orig_text``
    (simplified passage text), ``d2q_text`` (passage text with keys),
    ``text_id`` (the query text), ``q_ct``, ``p_ct``, ``tagged_query`` and
    ``tagged_passage`` (both ``None`` unless the record carries both tags).
    """
    all_dict = defaultdict(list)
    kept = 0
    for record in json_:
        if not ('ct' in record['passage'] and 'ct' in record['query']):
            continue
        passage, query = record['passage'], record['query']

        simple_text = passage_dict_to_text(passage, simplify=True)
        all_dict['docid'].append(kept)
        all_dict['text'].append(simple_text)
        all_dict['orig_text'].append(simple_text)
        all_dict['d2q_text'].append(passage_dict_to_text(passage, simplify=False))
        all_dict['text_id'].append(query['txt'])  # the query string doubles as the text id
        all_dict['q_ct'].append(query['ct'])

        # Passages with a title are calendar data: their filter time is the
        # event date rather than the creation time.
        time_key = 'date' if 'title' in passage else 'ct'
        all_dict['p_ct'].append(passage[time_key])

        has_tags = 'tagged_query' in record and 'tagged_passage' in record
        all_dict['tagged_query'].append(record['tagged_query'] if has_tags else None)
        all_dict['tagged_passage'].append(record['tagged_passage'] if has_tags else None)

        kept += 1

    return all_dict

def load_cluster_data(vector_tsv_path, mapping_pkl_path, v_dim=768):
    """Load document embeddings and the pickled cluster-ID mapping.

    Args:
        vector_tsv_path: headerless 8-column TSV; only ``docid`` (column 1)
            and ``vector`` (column 8, ``|``-separated floats) are used.
        mapping_pkl_path: pickle file holding the ID mapping.
        v_dim: accepted for API compatibility; not used by this function.

    Returns:
        ``(doc_ids, X, id_mapping)`` where ``doc_ids`` is the list of unique
        docids (first occurrence kept), ``X`` is the array of their vectors in
        the same order, and ``id_mapping`` is the unpickled object.
    """
    column_names = ['docid', 'url', 'title', 'body', 'anchor', 'click', 'language', 'vector']
    table = pd.read_csv(vector_tsv_path, header=None, sep='\t', names=column_names)
    table = table[['docid', 'vector']].drop_duplicates('docid')

    doc_ids = table['docid'].tolist()
    # Each vector cell looks like "0.1|0.2|..." -> list of floats.
    X = np.array([[float(x) for x in cell.split('|')] for cell in table['vector']])

    # NOTE: pickle executes arbitrary code on load; only use trusted files here.
    with open(mapping_pkl_path, 'rb') as handle:
        id_mapping = pickle.load(handle)

    return doc_ids, X, id_mapping

def assign_cluster_id(new_embedding, clustered_embeddings, id_mapping, doc_ids, used_ids):
    """Return the cluster ID of the nearest embedding whose ID is still free.

    Candidates are visited in order of increasing cosine distance to
    ``new_embedding``. A candidate is skipped when the concatenation of its
    cluster-ID parts is already in ``used_ids``.

    Returns:
        A pair holding the candidate's ``id_mapping`` entry twice (e.g.
        ``['1', '7', '4', '12']``), or ``None`` implicitly when every
        candidate is already used.
    """
    distances = cosine_distances([new_embedding], clustered_embeddings)[0]
    nearest_first = np.argsort(distances)
    for neighbor_idx in nearest_first:
        # The lookup below relies on docids being the row positions themselves.
        assert doc_ids[neighbor_idx] == neighbor_idx, "doc_ids[close_idx] should be equivalent to close_idx"
        cluster_path = id_mapping[doc_ids[neighbor_idx]]
        flat_id = ''.join([str(part) for part in cluster_path])
        if flat_id not in used_ids:
            return id_mapping[doc_ids[neighbor_idx]], cluster_path

def get_exisiting_cluster_ids(test_masked_texts, use_annotation=True):
    """Assign each test text the cluster ID of its nearest unused training doc.

    The texts are embedded with ``bert-base-uncased`` (first-token vector of
    the last hidden state) on ``cuda:0`` and matched against the stored
    training embeddings through ``assign_cluster_id``.

    Args:
        test_masked_texts: list of texts to embed.
        use_annotation: when True the ``<dataset>_seen_masked`` artifacts are
            used, otherwise the ``<dataset>_seen`` ones.

    Returns:
        ``(test_texts_docids, test_char_cluster_ids)``: the cluster-ID parts of
        each text joined without and with spaces, respectively.
    """
    path_prefix = dataset + "_seen" + "_masked" if use_annotation else dataset + "_seen"
    # doc_ids     : docids of the training samples (0, 1, 2, ...); the "_masked"
    #               prefix denotes embeddings of the masked sentences.
    # id_mapping  : training docid -> clustered docid,
    #               e.g. {0: ['1', '7', '4', '12'], ...}
    doc_ids, clustered_X, id_mapping = load_cluster_data(
        vector_tsv_path=f'data/out/{path_prefix}_doc_content_embedding_bert_512.tsv',
        mapping_pkl_path=f'IDMapping_{path_prefix}_bert_512_k9_c20_seed_7.pkl',
        v_dim=768
    )

    from transformers import BertTokenizer, BertModel
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained("bert-base-uncased").to(f'cuda:0')
    model.eval()

    batch_size = 20
    output = []
    # Pure inference: disabling autograd avoids storing activations for a
    # backward pass that never happens.
    with torch.no_grad():
        for start in range(0, len(test_masked_texts), batch_size):
            batch = test_masked_texts[start:start + batch_size]
            encoded_input = tokenizer(batch, max_length=512, return_tensors='pt', padding=True, truncation=True).to(f'cuda:0')
            first_token_vecs = model(**encoded_input, return_dict=True).last_hidden_state[:, 0, :]
            output.extend(first_token_vecs.cpu().numpy().tolist())

    test_texts_docids = []
    test_char_cluster_ids = []
    for test_text_emb in tqdm(output):
        # IDs already handed out in this call are passed as `used_ids`, so
        # every test text receives a distinct cluster ID.
        cluster_id, char_cluster_ids = assign_cluster_id(test_text_emb, clustered_X, id_mapping, doc_ids, test_texts_docids)
        test_char_cluster_ids.append(' '.join([str(_) for _ in char_cluster_ids]))
        test_texts_docids.append(''.join([str(_) for _ in cluster_id]))
    return test_texts_docids, test_char_cluster_ids

import re
import calendar
import dateparser
from datetime import datetime, timedelta

def parse_relative_time(query, creation_time):
    """Resolve a (mostly relative) time expression to a datetime interval.

    Args:
        query: time expression such as ``"last June"``, ``"two weeks ago"``,
            ``"this weekend"``. Characters other than letters, digits,
            whitespace, ``-`` and ``:`` are stripped before matching.
        creation_time: ``datetime`` that relative expressions are anchored to.

    Returns:
        ``(start_time, end_time)`` as datetimes, or ``None`` when no built-in
        rule matches and ``dateparser`` cannot parse the expression either.
    """
    settings = {'RELATIVE_BASE': creation_time}
    q = re.sub(r'[^A-Za-z0-9\s\-\:]', '', query.strip())
    query_lower = q.lower()

    weekdays = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
    months = ["january", "february", "march", "april", "may", "june", "july", "august", "september", "october", "november", "december"]

    seasons = {
        "spring": (3, 5),
        "summer": (6, 8),
        "fall": (9, 11),
        "autumn": (9, 11),
        "winter": (12, 2)  # spans year-end
    }

    # --- helper funcs -----------------------------------------------------
    def month_range(y: int, m: int):
        """First and last second of month ``m`` in year ``y``."""
        last_day = calendar.monthrange(y, m)[1]
        return datetime(y, m, 1, 0, 0, 0), datetime(y, m, last_day, 23, 59, 59)

    def weekday_last(target_idx: int):
        """Bounds of the most recent ``target_idx`` weekday strictly before ``creation_time``."""
        # `or 7`: if today already is that weekday, go back a full week.
        days_back = (creation_time.weekday() - target_idx + 7) % 7 or 7
        day = creation_time - timedelta(days=days_back)
        return day.replace(hour=0, minute=0, second=0, microsecond=0), day.replace(hour=23, minute=59, second=59)

    # ----------------------------------------------------------------------
    # 1) Explicit month modifiers  ("last June", "next March", "June")
    m_mod_match = re.fullmatch(r'(?:\b(last|next|this)\s+)?(' + '|'.join(months) + r')\b', query_lower)
    if m_mod_match:
        mod, month_word = m_mod_match.groups()
        month_idx = months.index(month_word) + 1
        year = creation_time.year
        if mod == 'last':
            year -= 1
        elif mod == 'next':
            year += 1
        start_time, end_time = month_range(year, month_idx)
        return start_time, end_time

    # 2) "beginning of <month>"  -> first half (day 1-15)
    beg_match = re.fullmatch(r'beginning of (' + '|'.join(months) + r')', query_lower)
    if beg_match:
        month_word = beg_match.group(1)
        month_idx = months.index(month_word) + 1
        year = creation_time.year
        start_time = datetime(year, month_idx, 1, 0, 0, 0)
        end_time = datetime(year, month_idx, 15, 23, 59, 59)
        return start_time, end_time

    # 3) "last <weekday>"  (entire previous week's weekday)
    for idx, wd in enumerate(weekdays):
        if query_lower == f'last {wd}':
            return weekday_last(idx)

    # ---- existing handler chain -----------------------------------------

    # Special vague patterns
    vague_period_match = re.search(r'(couple of|few)\s+(weekends?|weeks?|days?|months?|years?)\s+ago', query_lower)

    numeric_period_match = re.search(r'(\d+)\s+(day|week|month|year)s?\s+ago', query_lower)
    text_period_match = re.search(r'(one|two|three|four|five|six|seven|eight|nine|ten)\s+(day|week|month|year)s?\s+ago', query_lower)
    text2num = {
        "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
        "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
        "couple of": 2, "few": 3
    }

    start_time = end_time = None

    # Seasonal terms (this/last/next spring etc.)
    for season, (sm, em) in seasons.items():
        if f'this {season}' in query_lower:
            year = creation_time.year
        elif f'last {season}' in query_lower:
            year = creation_time.year - 1
        elif f'next {season}' in query_lower:
            year = creation_time.year + 1
        else:
            continue
        if season == 'winter':
            # Winter of `year` runs from December of the previous year to
            # the end of February of `year`.
            st_year = year if sm <= em else year - 1
            start_time = datetime(st_year, sm, 1, 0, 0, 0)
            end_year = st_year if em >= sm else st_year + 1
            end_time = datetime(end_year, em, calendar.monthrange(end_year, em)[1], 23, 59, 59)
        else:
            start_time = datetime(year, sm, 1, 0, 0, 0)
            end_time = datetime(year, em, calendar.monthrange(year, em)[1], 23, 59, 59)
        return start_time, end_time

    # quick explicit cases (substring tests; the first matching branch wins)
    if 'tonight' in query_lower:
        start_time = creation_time.replace(hour=18, minute=0, second=0, microsecond=0)
        end_time = creation_time.replace(hour=23, minute=59, second=59)
    elif 'upcoming' in query_lower:
        start_time = creation_time
        end_time = creation_time + timedelta(days=30)
    elif vague_period_match:
        quant, unit = vague_period_match.groups()
        num = text2num[quant]
        if 'weekend' in unit or 'week' in unit:
            delta_days = 7 * num
        elif 'day' in unit:
            delta_days = num
        elif 'month' in unit:
            delta_days = 30 * num
        else:
            delta_days = 365 * num
        start_time = (creation_time - timedelta(days=delta_days)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = creation_time
    elif any(t in query_lower for t in ['yesterday', 'last night']):
        start_time = (creation_time - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = start_time.replace(hour=23, minute=59, second=59)
    elif 'tomorrow' in query_lower:
        start_time = (creation_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = start_time.replace(hour=23, minute=59, second=59)
    elif 'today' in query_lower or 'this day' in query_lower:
        start_time = creation_time.replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = creation_time.replace(hour=23, minute=59, second=59)
    elif any(kw in query_lower for kw in ['last weekend', 'next weekend', 'this weekend']):
        # Must be tested BEFORE the "... week" branch: "last week" is a
        # substring of "last weekend", so the week branch would otherwise
        # swallow every weekend expression.
        saturday = creation_time - timedelta(days=creation_time.weekday() - 5)
        if 'last weekend' in query_lower:
            saturday -= timedelta(days=7)
        elif 'next weekend' in query_lower:
            saturday += timedelta(days=7)
        start_time = saturday.replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = start_time + timedelta(days=1, hours=23, minutes=59, seconds=59)
    elif any(kw in query_lower for kw in ['last week', 'next week', 'this week']):
        # Weeks run Monday 00:00:00 through Sunday 23:59:59.
        wd = creation_time.weekday()
        if 'last week' in query_lower:
            start_time = (creation_time - timedelta(days=wd + 7)).replace(hour=0, minute=0, second=0, microsecond=0)
        elif 'next week' in query_lower:
            start_time = (creation_time + timedelta(days=7 - wd)).replace(hour=0, minute=0, second=0, microsecond=0)
        else:
            start_time = (creation_time - timedelta(days=wd)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = start_time + timedelta(days=6, hours=23, minutes=59, seconds=59)
    elif any(kw in query_lower for kw in ['last month', 'next month', 'this month']):
        first_this = creation_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        if 'last month' in query_lower:
            last_end = first_this - timedelta(seconds=1)
            start_time = last_end.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            end_time = last_end
        elif 'next month' in query_lower:
            # +32 days from the 1st always lands in the following month.
            next_first = (first_this + timedelta(days=32)).replace(day=1)
            start_time = next_first
            end_time = (next_first + timedelta(days=32)).replace(day=1) - timedelta(seconds=1)
        else:
            start_time = first_this
            end_time = (first_this + timedelta(days=32)).replace(day=1) - timedelta(seconds=1)
    elif any(kw in query_lower for kw in ['last year', 'next year', 'this year']):
        if 'last year' in query_lower:
            yr = creation_time.year - 1
        elif 'next year' in query_lower:
            yr = creation_time.year + 1
        else:
            yr = creation_time.year
        start_time = datetime(yr, 1, 1, 0, 0, 0)
        end_time = datetime(yr, 12, 31, 23, 59, 59)
    elif 'afternoon' in query_lower:
        start_time = creation_time.replace(hour=12, minute=0, second=0, microsecond=0)
        end_time = creation_time.replace(hour=23, minute=59, second=59)
    elif 'morning' in query_lower:
        start_time = creation_time.replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = creation_time.replace(hour=11, minute=59, second=59)
    elif numeric_period_match:
        n, unit = int(numeric_period_match.group(1)), numeric_period_match.group(2)
        delta = {'day':1, 'week':7, 'month':30, 'year':365}[unit] * n
        start_time = (creation_time - timedelta(days=delta)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = creation_time
    elif text_period_match:
        word, unit = text_period_match.group(1), text_period_match.group(2)
        n = text2num[word]
        delta = {'day':1, 'week':7, 'month':30, 'year':365}[unit] * n
        start_time = (creation_time - timedelta(days=delta)).replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = creation_time
    elif 'last' in query_lower and not any(k in query_lower for k in weekdays + months):
        # Generic "last ..." without a weekday/month: fall back to the past year.
        start_time = creation_time - timedelta(days=365)
        end_time = creation_time
    else:
        # Nothing matched: let dateparser try, anchored at creation_time.
        parsed = dateparser.parse(q, settings=settings)
        if not parsed:
            return None
        start_time = parsed.replace(hour=0, minute=0, second=0, microsecond=0)
        end_time = parsed.replace(hour=23, minute=59, second=59)

    return start_time, end_time


def extract_dat_from_sent(sent, ct):
    """Return the time window of the first resolvable ``TIM`` entity in ``sent``.

    Entities are read from ``sent.to_dict()['entities']`` in order; each one
    labelled ``TIM`` is passed to ``parse_relative_time`` with ``ct`` as the
    reference time, and the first one that resolves wins.

    Returns:
        ``(start_iso, end_iso)`` strings, or ``("-", "-")`` when the sentence
        has no ``TIM`` entity that can be resolved.

    TODO:: instead of calling parse_relative_time for every entity, cache the
        entities already known to be valid to cut the repeated parsing cost.
    """
    for entity in sent.to_dict()['entities']:
        entity_text, entity_label = entity['text'], entity['labels'][0]['value']
        if entity_label != 'TIM':
            continue
        window = parse_relative_time(entity_text, ct)
        if window:
            # First valid entity decides the result.
            return window[0].isoformat(), window[1].isoformat()
    return ("-", "-")

# Name of the synthetic dataset; every input/output path below is built from it.
dataset = "syn_50k" # syn_8k, all_gen_2

In [None]:
# Destination of the plain `docid<TAB>text` dump of the parsed passages.
save_file_path = f"data/synthetics/{dataset}.tsv"
# Fine-tuned Flair NER tagger loaded from a local checkpoint; used further
# down to tag the passages before they are masked.
fine_tuned_tagger = SequenceTagger.load('./fine-tuned-model/final-model.pt')

## Load and Parse Raw-data ##

In [None]:
# Read the raw synthetic records and convert them to column-oriented lists.
with open(f"data/synthetics/{dataset}.json") as f:
    json_ = json.load(f)

print(f"num total samples: {len(json_)}")
all_dict = parse_json(json_)
print(f"num total after anomaly discards: {len(all_dict['docid'])}")
# Records carrying NER tags are the ones whose tagged_query is not None.
print(f"num NER samples: {sum(tagged_ is not None for tagged_ in all_dict['tagged_query'])}")

# Show the last parsed sample as a sanity check.
print('--------------------------------')
print(f"id: {all_dict['docid'][-1]}")
for field_ in ('text', 'd2q_text', 'text_id'):
    print(all_dict[field_][-1])

## NER Dataset Generation (optional) ##

In [None]:
def get_word_entity_pairs(text: str, entities: list):
    """Align the space-separated words of ``text`` with BIO entity tags.

    Args:
        text: sentence to label.
        entities: list of single-item dicts ``{ENTITY_TYPE: surface_string}``.
            Only the first item of each dict is read, and only the first
            occurrence of the surface string is tagged.

    Returns:
        List of ``(word, tag)`` pairs, where tag is ``"O"`` or a ``B-``/``I-``
        tag truncated to five characters. Words whose tag type is not one of
        LOC/PER/EVT/TIM are reported on stdout and left out of the result.
    """
    allowed_types = ["LOC", "PER", "EVT", "TIM"]  # ["LOC", "PER", "DAT", "ORG", "TIM"]

    # Work on a copy of the text in which every entity surface is replaced by
    # the same number of space-separated BIO tags, so that word positions in
    # `text` and `tagged_text` stay aligned.
    tagged_text = text
    if len(entities) > 0:
        for entity in entities:
            label, surface = next(iter(entity.items()))
            if surface not in tagged_text:
                continue
            n_words = len(surface.split(" "))
            bio_tags = " ".join([f"B-{label}"] + [f"I-{label}"] * (n_words - 1))
            tagged_text = tagged_text.replace(surface, bio_tags, 1)

    word_entity_pairs = []
    for word, tag in zip(text.split(" "), tagged_text.split(" ")):
        if "B-" not in tag and "I-" not in tag:
            word_entity_pairs.append((word, "O"))
            continue
        short_tag = tag[:5]  # e.g. "B-LOC" (drops trailing punctuation)
        if short_tag[2:] in allowed_types:
            word_entity_pairs.append((word, short_tag))
        else:
            print(f"Bad entity name detected: {short_tag[2:]} for the word: {word}")
    return word_entity_pairs

# Build (word, tag) sequences for every sample that carries NER annotations;
# each such sample contributes its query first and then its passage.
token_entity_pairs_lst = []
for q, p, t_q, t_p in zip(all_dict['text_id'], all_dict['text'],
                          all_dict['tagged_query'], all_dict['tagged_passage']):
    if t_q is None:
        continue
    token_entity_pairs_lst.append(get_word_entity_pairs(text=q, entities=t_q))
    token_entity_pairs_lst.append(get_word_entity_pairs(text=p, entities=t_p))

ner_save_path = "data/synthetics/ner"
split_ratio = "7:1:2"  # train : dev : test, in tenths
tot_pairs = len(token_entity_pairs_lst)
trn_share_, dev_share_ = (int(part_) for part_ in split_ratio.split(":")[:2])
trn_idx = int(tot_pairs * trn_share_ / 10)
dev_idx = int(trn_idx + tot_pairs * dev_share_ / 10)

# One "word<TAB>tag" line per token, with a blank line closing each sentence.
all_w_str = [
    "".join(f"{pair[0]}\t{pair[1]}\n" for pair in pairs) + "\n"
    for pairs in token_entity_pairs_lst
]

ner_splits_ = {
    "train": all_w_str[:trn_idx],
    "dev": all_w_str[trn_idx:dev_idx],
    "test": all_w_str[dev_idx:],
}
for type_, sub_pairs in ner_splits_.items():
    with open(f"{ner_save_path}/ner_{type_}.txt", "w") as f:
        f.writelines(sub_pairs)

## DocTquery Dataset Generation (optional) ##
- To run this section on its own, first execute all of the cells above.

In [None]:
split_ratio = "9:1"  # train : dev, in tenths
tot_pairs = len(all_dict['text_id'])
d2q_trn_idx = int(tot_pairs * int(split_ratio.split(":")[0]) / 10)

# We do not generate docTquery dataset with masked dataset (they are for clustering)
save_file_path_train_d2q = f"data/synthetics/{dataset}_train_d2q.json"
save_file_path_dev_d2q = f"data/synthetics/{dataset}_dev_d2q.json"
with open(save_file_path_train_d2q, 'w') as f1, open(save_file_path_dev_d2q, 'w') as f2:
    ids_, texts_ = all_dict['text_id'], all_dict['d2q_text']
    for idx_, (query_, passage_) in enumerate(zip(ids_, texts_)):
        # The first d2q_trn_idx samples go to the train file, the rest to dev.
        sink_ = f1 if idx_ < d2q_trn_idx else f2
        sink_.write(json.dumps({"text_id": query_, "text": f"Generate a question for the following passage: {passage_}"}) + '\n')

## Mask Texts ## 

In [None]:
# Dump the plain passages under both the base name and the "_seen" name.
plain_df_ = pd.DataFrame.from_dict({'docid': all_dict['docid'], 'text': all_dict['text']})
for out_path_ in (save_file_path, f"data/synthetics/{dataset}_seen.tsv"):
    plain_df_.to_csv(out_path_, sep="\t", index=False, header=False)

# Tag every passage with the fine-tuned NER model, then verbalize the entities.
sentences = [Sentence(text) for text in all_dict['text']]
fine_tuned_tagger.predict(sentences, mini_batch_size=512, verbose=True)
masked_texts = mask_text(sentences, verbalize=True)

# Same two dumps for the masked passages.
masked_df_ = pd.DataFrame.from_dict({'docid': all_dict['docid'], 'text': masked_texts})
for out_path_ in (f"data/synthetics/{dataset}_masked.tsv", f"data/synthetics/{dataset}_seen_masked.tsv"):
    masked_df_.to_csv(out_path_, sep="\t", index=False, header=False)

# Release cached GPU memory held after tagging.
torch.cuda.empty_cache()

## DocID Assignment (K-means Clustering) ##
- old_to_new_docid
- masked_old_to_new_docid

In [None]:
# Run the external bert.sh script for the plain and the masked "seen" splits.
# NOTE(review): presumably this produces the
# data/out/<prefix>_doc_content_embedding_bert_512.tsv files read by
# load_cluster_data() -- confirm against bert.sh.
!./bert.sh 1 {dataset + "_seen"} && wait
!./bert.sh 1 {dataset + "_seen" + "_masked"} && wait

In [None]:
# Run the external kmeans.sh script for the plain and the masked "seen" splits.
# NOTE(review): presumably this produces the
# IDMapping_<prefix>_bert_512_k9_c20_seed_7.pkl files loaded in the next
# cell -- confirm against kmeans.sh.
!./kmeans.sh {dataset + "_seen"} && wait
!./kmeans.sh {dataset + "_seen" + "_masked"} && wait

In [None]:
# Cluster IDs of the plain passages: original docid -> list of ID parts.
with open(f'IDMapping_{dataset + "_seen"}_bert_512_k9_c20_seed_7.pkl', 'rb') as f:
    kmeans_qdoc_dict = pickle.load(f)
# Two string renderings of each ID: parts concatenated, and parts space-separated.
old_to_new_docid: dict = {k: ''.join(map(str, v)) for k, v in kmeans_qdoc_dict.items()}
old_to_new_docid_char: dict = {k: ' '.join(map(str, v)) for k, v in kmeans_qdoc_dict.items()}

# Same for the masked passages,
# e.g. {0: [6, 9, 3, 1], 1: [8, 9, 7, 1], 2: [2, 1, 3, 1], 5: [2, 2, 3, 3, 1], ...}
with open(f'IDMapping_{dataset + "_seen" + "_masked"}_bert_512_k9_c20_seed_7.pkl', 'rb') as f:
    masked_kmeans_qdoc_dict = pickle.load(f)
masked_old_to_new_docid: dict = {k: ''.join(map(str, v)) for k, v in masked_kmeans_qdoc_dict.items()}
masked_old_to_new_docid_char: dict = {k: ' '.join(map(str, v)) for k, v in masked_kmeans_qdoc_dict.items()}  # ['4 6 7 5', '8 9 9 6', ...]

In [None]:
# Use for `Query Generation` task (generation.yaml)
df_ = pd.DataFrame.from_dict({'docid': list(old_to_new_docid_char.values()), 'text': all_dict['d2q_text']})
df_['docid'] = df_['docid'].astype('str')
df_.to_csv(f"data/synthetics/{dataset}_seen_clustered.tsv", sep="\t", index=False, header=False) 


# Use for `Query Generation` task (generation.yaml)
df_ = pd.DataFrame.from_dict({'docid': list(masked_old_to_new_docid_char.values()), 'text': all_dict['d2q_text']})
df_['docid'] = df_['docid'].astype('str')
df_.to_csv(f"data/synthetics/{dataset}_seen_masked_clustered.tsv", sep="\t", index=False, header=False) 


df_ = pd.DataFrame.from_dict({'docid': list(masked_old_to_new_docid_char.values()), 'text': all_dict['text']})
df_['docid'] = df_['docid'].astype('str')
df_.to_csv(f"data/synthetics/{dataset}_seen_masked_clustered_og.tsv", sep="\t", index=False, header=False) 


# df_ = pd.DataFrame.from_dict({'docid': list(old_to_new_docid_char.values()), 'text': all_dict['text_id']})
# df_['docid'] = df_['docid'].astype('str')
# df_.to_csv(f"data/synthetics/{dataset}_seen_clustered_queries.tsv", sep="\t", index=False, header=False) 


df_ = pd.DataFrame.from_dict({'docid': list(masked_old_to_new_docid_char.values()), 'text': all_dict['text_id']})
df_['docid'] = df_['docid'].astype('str')
df_.to_csv(f"data/synthetics/{dataset}_seen_masked_clustered_queries.tsv", sep="\t", index=False, header=False) 

## Query-Passage Augmentation (optional) - deprecated ##
- To run this section on its own, first execute all of the cells above.

In [None]:
# import torch, itertools, numpy as np, pandas as pd
# from transformers import (AutoTokenizer,
#                           AutoModelForSequenceClassification)
# from flair.models import SequenceTagger
# from flair.data   import Sentence
# from tqdm.auto    import tqdm

# DEVICE = "cuda:0"

# # -------------------------------------------------------------------
# # 0.  Models & tokenizers
# # -------------------------------------------------------------------
# rer_tok  = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", use_fast=True)
# rer_model = AutoModelForSequenceClassification.from_pretrained(
#               "BAAI/bge-reranker-v2-m3",
#               output_attentions=True).to(DEVICE).eval()

# # Generate Swapped Entity Query Expansion

# # -------------------------------------------------------------------
# # 1. Helper: run Flair NER and return entity list w/ char & token spans
# # -------------------------------------------------------------------
# def flair_entities(text):
#     sent = Sentence(text) # , use_tokenizer=tokenizer)        # stable offsets
#     fine_tuned_tagger.predict(sent)
#     ents = []
#     for ent in sent.get_spans("ner"):
#         ents.append({"type": ent.tag,
#                      "char_span": (ent.start_position, ent.end_position)})
#     return ents

# # -------------------------------------------------------------------
# # 2. Helper: map char span -> token indices in BGE tokenizer
# # -------------------------------------------------------------------
# def char2tok_span(text, char_span, tok_offsets):
#     """Return (tok_start, tok_end) inclusive given a HuggingFace
#        offsets_mapping list."""
#     c_start, c_end = char_span
#     for i, (s, e) in enumerate(tok_offsets):
#         if s <= c_start < e:
#             tok_start = i
#             break
#     else:
#         return None
#     for j in range(tok_start, len(tok_offsets)):
#         s, e = tok_offsets[j]
#         if e >= c_end:
#             tok_end = j
#             break
#     return tok_start, tok_end

# # -------------------------------------------------------------------
# # 3.  Augment one (query, passage) pair
# # -------------------------------------------------------------------
# def augment_pair(id_, query, passage,
#                  threshold=0.005,   # min avg‑attention to accept swap
#                  top_k=3):         # number of swaps per query entity
#     # --- NER ---
#     q_ents = flair_entities(query)
#     p_ents = flair_entities(passage)

#     # --- reranker   CLS  q  SEP  p  SEP
#     enc = rer_tok(query, passage,
#                   return_offsets_mapping=True,
#                   max_length=512,
#                   truncation=True,
#                   return_tensors="pt").to(DEVICE)
#     offsets = enc.pop("offset_mapping")

#     with torch.no_grad():
#         out  = rer_model(**enc, output_attentions=True, return_dict=True)

#     # last‑layer attn avg over heads  →  [L, L]
#     attn = out.attentions[-1].mean(dim=1).squeeze(0)    # fp16

#     # split indices
#     sep = (enc["input_ids"][0] == rer_tok.sep_token_id).nonzero(as_tuple=True)[0][0]
#     q_shift = 1                # skip CLS
#     p_shift = sep + 1          # skip CLS + query + SEP

#     offsets = offsets[0].tolist()  
#     # print(offsets) 
#     # offsets = enc["offset_mapping"][0].tolist()          # list[(s,e)]
#     swaps = []

#     # --- for every entity in query ---
#     for q_ent in q_ents:
#         if q_ent["type"] == "TIM" or q_ent["type"] == "DAT":
#             continue
#         q_tok_span = char2tok_span(query, q_ent["char_span"], offsets[q_shift:sep])
#         if not q_tok_span: continue
#         q_t0, q_t1 = (q_tok_span[0] + q_shift, q_tok_span[1] + q_shift)

#         # aggregate query‑entity attention vector
#         q_vec = attn[q_t0:q_t1+1].mean(dim=0)            # [L]

#         best = []
#         for p_ent in p_ents:
#             if p_ent["type"] != q_ent["type"]: continue
#             p_tok_span = char2tok_span(passage, p_ent["char_span"],
#                                        offsets[p_shift:-1])  # exclude last SEP
#             if not p_tok_span: continue
#             p_t0, p_t1 = (p_tok_span[0] + p_shift, p_tok_span[1] + p_shift)

#             score = q_vec[p_t0:p_t1+1].mean().item()
#             best.append((score, p_ent))

#         best = [b for b in sorted(best, key=lambda x: -x[0]) if b[0] >= threshold][:top_k]

#         # --- create swapped texts ---
#         for score, p_ent in best:
#             q_new = (query[: q_ent["char_span"][0]] +
#                      passage[p_ent["char_span"][0]: p_ent["char_span"][1]] +
#                      query[q_ent["char_span"][1]:])

#             p_new = (passage[: p_ent["char_span"][0]] +
#                      query[q_ent["char_span"][0]: q_ent["char_span"][1]] +
#                      passage[p_ent["char_span"][1]:])

#             swaps.append({"docid": id_,
#                           "q_swapped": q_new,
#                           "p_swapped": p_new,
#                           "type": q_ent["type"],
#                           "score": score})

#     return swaps         # list of dicts

In [None]:
# # -------------------------------------------------------------------
# # 4.  Run over the whole dataset [THIS VERSION IS FOR ONLY MASKED DOCIDs]
# # -------------------------------------------------------------------
# pids_, ptexts_, _ = read_id_text_from_file(f"data/synthetics/{dataset}_seen_masked_clustered_og.tsv")
# qids_, qtexts_, _ = read_id_text_from_file(f"data/synthetics/{dataset}_seen_masked_clustered_queries.tsv")
# assert pids_ == qids_, f"invalid id collisions"

# all_aug = []
# all_pairs = {'docid': pids_, 'queries': qtexts_, 'passages': ptexts_}
# for id_, q, p in tqdm(zip(all_pairs['docid'], all_pairs["queries"], all_pairs["passages"]),
#                  total=len(all_pairs), desc="Augment"):
#     all_aug.extend(augment_pair(id_, q, p))
# aug_df = pd.DataFrame(all_aug)

# sentences = [Sentence(text) for text in aug_df['q_swapped']]
# fine_tuned_tagger.predict(sentences, mini_batch_size=512, verbose=True)
# aug_df['q_swapped_mask'] = mask_text(sentences, verbalize=False)

# sentences = [Sentence(text) for text in aug_df['p_swapped']]
# fine_tuned_tagger.predict(sentences, mini_batch_size=512, verbose=True)
# aug_df['p_swapped_mask'] = mask_text(sentences, verbalize=False)

# print("generated", len(aug_df), "new swapped pairs")a

# aug_df['docid'] = aug_df['docid'].astype('str')
# with open(f"data/synthetics/{dataset}_seen_masked_swap_augmented.json", "w") as f:
#     for idx_, row_ in aug_df.iterrows():
#         f.write(json.dumps({"text_id": row_["docid"], "text": row_["q_swapped_mask"]}) + '\n')
#         f.write(json.dumps({"text_id": row_["docid"], "text": row_["p_swapped_mask"]}) + '\n')

## Query-Passage Augmentation (optional) ##
- To run this section on its own, first execute all of the cells above.

In [None]:
# ------------------------------------------------------------
# helpers -----------------------------------------------------
# ------------------------------------------------------------
import re, torch
from copy import deepcopy
from typing import List, Tuple, Dict

START_TAG   = {"<PER>", "<LOC>", "<EVT>"}
END_TAG_OF  = {"<PER>": "</PER>", "<LOC>": "</LOC>", "<EVT>": "</EVT>"}

TagSpan = Tuple[str, int, int]   # (tag_type, start_tok_idx, end_tok_idx)


def _extract_entity_spans(tokens: List[str]) -> List[TagSpan]:
    """
    tokens : list of decoded tokens (already split by tokenizer)
    returns: list of (tag_type, start, end)  -- indices inclusive
            where start / end are token-indices **inside the current segment**
    """
    spans = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok in START_TAG:
            tag_type   = tok
            end_tag    = END_TAG_OF[tag_type]
            j = i + 1
            # find matching end tag
            while j < len(tokens) and tokens[j] != end_tag:
                j += 1
            if j < len(tokens):      # found end
                spans.append((tag_type, i + 1, j - 1))  # Special entity token exclusive indices
                i = j + 1
            else:
                i += 1               # malformed span; skip
        else:
            i += 1
    return spans


def _segment_tokens(tokens: List[int]) -> Tuple[List[int], List[int]]:
    """
    Split an encoded (query, passage) pair into two lists of token positions.

    tokens : input ids of the jointly encoded pair.
    returns: (query_positions, passage_positions) -- positions into `tokens`.

    The query covers everything between the first token and the first
    separator id; the passage covers everything after that separator except
    the final token.
    NOTE(review): for tokenizers that join the pair with two consecutive
    separators, the passage list starts with the second separator -- confirm
    against the tokenizer in use.
    """
    first_sep = tokens.index(rer_tok.sep_token_id)
    last_pos = len(tokens) - 1
    query_positions = [pos for pos in range(1, first_sep)]
    passage_positions = [pos for pos in range(first_sep + 1, last_pos)]
    return query_positions, passage_positions


def _attention_mass(att, src_ids, tgt_ids):
    """
    att : (N, N) – single attention matrix (already averaged over heads/layers)
    src_ids, tgt_ids : lists of token indices belonging to src / tgt span
    returns scalar attention mass from src->tgt
    """
    src = torch.tensor(src_ids, device=att.device)
    tgt = torch.tensor(tgt_ids, device=att.device)
    return att[src][:, tgt].sum() #  / src.size(0) this really depends 


# ------------------------------------------------------------
# new augment_pair -----------------------------------------------------------
# ------------------------------------------------------------
def augment_pair(docid: str,
                 query_str: str,
                 passage_str: str,
                 thr: float = 5e-2,     # <- tune empirically; compared against raw (un-normalised) mass
                 top_layer: int = -1) -> List[Dict[str, str]]:
    """
    Produce entity-swapped variants of one (query, passage) pair.

    The pair is encoded jointly with `rer_tok` and passed through `rer_model`;
    the attentions of layer `top_layer` are averaged over heads. For every
    tagged entity span in the query, the passage span of the *same tag type*
    with the largest raw attention mass (query span -> passage span) is
    selected. If that mass is >= `thr`, the tokens of the two spans are
    exchanged and the resulting texts are emitted.

    docid       : identifier copied unchanged into every produced record.
    query_str   : tagged query text.
    passage_str : tagged passage text.
    thr         : minimum raw attention mass for a swap to be kept.
    top_layer   : index into the model's per-layer attentions.

    Returns a list of dicts {'docid', 'q_text', 'p_text'}, at most one per
    query entity span, each built from an untouched copy of the original
    tokens so swaps never interact. Returns [] when either side has no entity
    span or no span reaches `thr`.
    """
    # 1) tokenise the pair and fetch the head-averaged attention matrix (N, N)
    enc = rer_tok(query_str, passage_str,
                  return_tensors="pt",
                  add_special_tokens=True).to(DEVICE)

    with torch.no_grad():
        att = rer_model(**enc,
                        output_attentions=True,
                        return_dict=True
                       ).attentions[top_layer].mean(dim=1)[0]

    tok_ids = enc["input_ids"][0].tolist()
    dec_tok = rer_tok.convert_ids_to_tokens(tok_ids)

    # 2) find entity spans per segment, then map them to absolute positions
    q_ids, p_ids = _segment_tokens(tok_ids)
    q_spans = [(t, q_ids[s], q_ids[e])
               for (t, s, e) in _extract_entity_spans([dec_tok[i] for i in q_ids])]
    p_spans = [(t, p_ids[s], p_ids[e])
               for (t, s, e) in _extract_entity_spans([dec_tok[i] for i in p_ids])]

    if not q_spans or not p_spans:
        return []                       # nothing to swap

    # 3) pick the best same-type passage span for every query span
    swaps = []
    for t_q, s_q, e_q in q_spans:
        best_mass, best_span = 0.0, None
        q_positions = list(range(s_q, e_q + 1))

        for t_p, s_p, e_p in p_spans:
            if t_q != t_p:              # require same entity type
                continue
            mass = _attention_mass(att,
                                   q_positions,
                                   list(range(s_p, e_p + 1))).item()
            if mass > best_mass:
                best_mass, best_span = mass, (t_p, s_p, e_p)

        # `best_span` stays None when no same-type span had positive mass;
        # without this check a non-positive `thr` would store None and crash below.
        if best_span is not None and best_mass >= thr:
            swaps.append(((t_q, s_q, e_q), best_span))

    # 4) materialise one swapped pair per accepted candidate
    aug_pairs = []
    for (_, s_q, e_q), (_, s_p, e_p) in swaps:
        span_q = dec_tok[s_q:e_q + 1]
        span_p = dec_tok[s_p:e_p + 1]

        # The query always precedes the passage, so s_q <= e_q < s_p <= e_p.
        toks = dec_tok[:s_q] + span_p + dec_tok[e_q + 1:s_p] + span_q + dec_tok[e_p + 1:]

        sep = toks.index(rer_tok.sep_token)
        new_query = rer_tok.convert_tokens_to_string(toks[1:sep])
        # `sep + 2` skips two consecutive separator tokens between the segments;
        # `-1` drops the final special token.
        new_pass  = rer_tok.convert_tokens_to_string(toks[sep + 2:-1])

        aug_pairs.append({'docid': docid, 'q_text': new_query, 'p_text': new_pass})

    return aug_pairs


In [None]:
# Load the masked training passages and queries (both read with is_json=True).
# The passages path has no ".json" extension, and both paths hard-code "syn_50k"
# rather than using `dataset`.
pids_, ptexts_, _ = read_id_text_from_file(f"data/synthetics/syn_50k_seen_masked_clustered_passages_masked_train", is_json=True)
qids_, qtexts_, _ = read_id_text_from_file(f"data/synthetics/syn_50k_seen_masked_clustered_queries_train.json", is_json=True)
# Queries and passages are paired by position later on, so the id lists must be identical.
assert pids_ == qids_, f"invalid id collisions"

In [None]:
import torch, itertools, numpy as np, pandas as pd
from transformers import (AutoTokenizer,
                          AutoModelForSequenceClassification)
from flair.models import SequenceTagger
from flair.data   import Sentence
from tqdm.auto    import tqdm

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the cell still runs (slowly) on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# -------------------------------------------------------------------
# 0.  Models & tokenizers
# -------------------------------------------------------------------
rer_tok  = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3", use_fast=True)
rer_model = AutoModelForSequenceClassification.from_pretrained(
              "BAAI/bge-reranker-v2-m3",
              output_attentions=True).to(DEVICE).eval()

# Register the entity markers as whole tokens so the tokenizer never splits them,
# then grow the embedding matrix to cover the enlarged vocabulary.
# <TIM> markers are registered here but are not among the swappable START_TAG types.
entity_tokens = ["<PER>", "</PER>", "<LOC>", "</LOC>", "<TIM>", "</TIM>", "<EVT>", "</EVT>"]
rer_tok.add_tokens(entity_tokens)
rer_model.resize_token_embeddings(len(rer_tok))

In [None]:
# Upper bound on the number of pairs to process; slicing with a value larger
# than the data size simply keeps everything.
num_samples = 20000000
all_pairs = {'docid': pids_[:num_samples], 'queries': qtexts_[:num_samples], 'passages': ptexts_[:num_samples]}
all_augmented = []
for docid, q, p in tqdm(zip(all_pairs["docid"],
                            all_pairs["queries"],
                            all_pairs["passages"]),
                        # len(all_pairs) would be the number of dict keys (3), not the number of pairs
                        total=len(all_pairs["docid"]),
                        desc="Augment"):
    all_augmented.extend(augment_pair(docid, q, p))

# Naming the columns keeps the frame well-formed even when no swap was produced.
aug_df = pd.DataFrame(all_augmented, columns=["docid", "q_text", "p_text"])

print("generated", len(aug_df), "new swapped pairs")

aug_df['docid'] = aug_df['docid'].astype('str')
# Each swapped pair becomes two JSON lines sharing one docid: the query, then the passage.
with open(f"data/synthetics/{dataset}_seen_masked_swap_augmented.json", "w") as f:
    for docid_, q_text_, p_text_ in zip(aug_df["docid"], aug_df["q_text"], aug_df["p_text"]):
        f.write(json.dumps({"text_id": docid_, "text": q_text_, "is_tp": 0}) + '\n')
        f.write(json.dumps({"text_id": docid_, "text": p_text_, "is_tp": 0}) + '\n')

## PostID Mask Texts ##
- after building docids, use them to adjust other files' ids accordingly
- For an independent run, everything up to section 1.5 must be run first.

In [None]:
# docid with masking + masked queries [Train Set]
with open(f"data/synthetics/{dataset}_seen_masked_clustered_queries_train.json", 'w') as f:
    ids_, texts_ = list(masked_old_to_new_docid_char.values()), all_dict['text_id']

    # Tag the raw queries, then wrap every predicted entity in <TAG> ... </TAG> markers.
    texts_ = [Sentence(text) for text in texts_]
    fine_tuned_tagger.predict(texts_, mini_batch_size=512, verbose=True)
    texts_ = mask_text(texts_)

    for idx_, id_ in enumerate(ids_):
        f.write(json.dumps({"text_id": id_, "text": f"{texts_[idx_]}"}) + '\n')


# Same queries without entity markers, keyed by the non-masked docids.
with open(f"data/synthetics/{dataset}_seen_clustered_queries_train.json", 'w') as f:
    ids_, texts_ = list(old_to_new_docid_char.values()), all_dict['text_id']
    for idx_, id_ in enumerate(ids_):
        f.write(json.dumps({"text_id": id_, "text": f"{texts_[idx_]}"}) + '\n')


# Get passage contents for training
with open(f"data/synthetics/{dataset}_seen_masked_clustered_og.tsv", "r") as f:
    ids_, passages_ = [], []
    for data in f:
        # Split on the first tab only so a tab inside the passage does not break
        # the unpacking, and drop the line terminator from the passage text.
        id_, passage_ = data.rstrip("\n").split("\t", 1)
        ids_.append(id_)
        passages_.append(passage_)
sentences = [Sentence(text) for text in passages_]
fine_tuned_tagger.predict(sentences, mini_batch_size=512, verbose=True)
with open(f"data/synthetics/{dataset}_seen_masked_clustered_passages_masked_train", "w") as f:
    for id_, masked_passage_ in zip(ids_, mask_text(sentences)):
        f.write(json.dumps({"text_id": id_, "text": masked_passage_}) + '\n')


# Alternative source for the plain passages (disabled): read them back from the TSV.
# with open(f"data/synthetics/{dataset}_seen_clustered.tsv", "r") as f:
#     ids_, passages_ = [], []
#     for data in f:
#         id_, passage_ = data.split("\t")
#         ids_.append(id_)
#         passages_.append(passage_)
ids_, passages_ = list(old_to_new_docid_char.values()), all_dict['text']
with open(f"data/synthetics/{dataset}_seen_clustered_passages_train", "w") as f:
    for id_, passage_ in zip(ids_, passages_):
        f.write(json.dumps({"text_id": id_, "text": passage_}) + '\n')

## Post-QG Process ##
- You have to run **GENERATION** twice before running this. Files required are:
    - [1] `{dataset}_seen_clustered.tsv`,
    - [2] `{dataset}_seen_masked_clustered.tsv`

In [None]:
# Generated queries for the non-masked docids.
d2q_file_path = f"data/synthetics/{dataset}_seen_clustered.tsv.q15.docTquery"
# `ignore_verifications=False` was dropped: it only restated the default, is
# deprecated in `datasets`, and the second load below never passed it.
train_data = datasets.load_dataset(
    'json',
    data_files=d2q_file_path,
)['train']
print(f"Finished loading dataset")

# Generated queries for the masked docids; only their ids are used below.
masked_d2q_file_path = f"data/synthetics/{dataset}_seen_masked_clustered.tsv.q15.docTquery"
masked_train_data = datasets.load_dataset(
    'json',
    data_files=masked_d2q_file_path,
)['train']
masked_text_ids = masked_train_data['text_id']
print(f"Finished loading dataset")

# Tag the generated queries of the non-masked file.
sentences = [Sentence(text) for text in train_data['text']]
fine_tuned_tagger.predict(sentences, mini_batch_size=768, verbose=True)

from tqdm import tqdm
# After QG, create the queries with annotation version (text_id = docid, text = text)
# Both outputs receive the same annotated text; row i of the masked file supplies
# the id for row i of the non-masked file, so the two inputs must be aligned row by row.
new_docTquery = {'text_id': train_data['text_id'], 'text': mask_text(sentences), 'masked_text_id': masked_text_ids}
with open(d2q_file_path + ".masked", 'w') as f1, open(masked_d2q_file_path + ".masked", 'w') as f2:
    for idx_ in tqdm(range(len(new_docTquery['text_id']))):
        f1.write(json.dumps({"text_id": new_docTquery['text_id'][idx_], "text": new_docTquery['text'][idx_]}) + '\n')
        f2.write(json.dumps({"text_id": new_docTquery['masked_text_id'][idx_], "text": new_docTquery['text'][idx_]}) + '\n')

## Final Train Files Concatenation ##

+ `{dataset}_seen_masked_clustered_queries_train.json` (train query annotated)                     
+ `{dataset}_seen_masked_clustered_passages_masked_train`  (train passages annotated)            
+ `{dataset}_seen_masked_clustered.tsv.q15.docTquery.masked` (generated queries annotated)  


For no Annotation train file:
+ `{dataset}_seen_clustered_queries_train.json` (train query)                                 
+ `{dataset}_seen_clustered_passages_train`  (train passages)                                 
+ `{dataset}_seen_clustered.tsv.q15.docTquery` (generated queries)  

In [None]:
def concat_all_files(files, out):
    """
    Concatenate several id/text files into one JSON-lines training file.

    files : paths read in order with `read_id_text_from_file(..., is_json=True)`.
    out   : path of the file to write.

    Every output line is {"text_id", "text", "is_tp"}. Newlines are removed
    from the text. `is_tp` is 1 for records coming from a file whose path
    contains 'passages' and 0 otherwise. The flag is tracked per record, so
    the function also works with no passages file or with several of them.
    """
    all_ids, all_texts, all_is_tp = [], [], []
    for file in files:
        ids_, texts_, _ = read_id_text_from_file(file, is_json=True)
        is_tp = 1 if 'passages' in file else 0
        all_ids.extend(ids_)
        all_texts.extend(texts_)
        all_is_tp.extend([is_tp] * len(ids_))
    print(len(all_ids))

    with open(out, "w") as f:
        for id_, text_, is_tp in tqdm(zip(all_ids, all_texts, all_is_tp), total=len(all_ids)):
            f.write(json.dumps({"text_id": id_, "text": text_.replace('\n', ''), "is_tp": is_tp}) + '\n')
    print(f"{out} generated successfully")
    
# Inputs for the annotated training file: queries, passages and generated queries,
# all carrying entity markers and keyed by the masked docids.
train_files = [
    f"data/synthetics/{dataset}_seen_masked_clustered_queries_train.json",
    f"data/synthetics/{dataset}_seen_masked_clustered_passages_masked_train",
    f"data/synthetics/{dataset}_seen_masked_clustered.tsv.q15.docTquery.masked"
]

# Inputs for the no-annotation ("noA") training file: the same three sources without markers.
train_files_noA = [
    f"data/synthetics/{dataset}_seen_clustered_queries_train.json",
    f"data/synthetics/{dataset}_seen_clustered_passages_train",
    f"data/synthetics/{dataset}_seen_clustered.tsv.q15.docTquery"
]

# NOTE(review): the annotated output ends in "_train_pearl", while the
# "Apply New Mapping" cell reads "syn_50k_train_perl" -- confirm which spelling is intended.
concat_all_files(files=train_files, out=f'data/synthetics/{dataset}_train_pearl')
concat_all_files(files=train_files_noA, out=f'data/synthetics/{dataset}_train_noA')

## Re-assign DocIDs with New DocIDs (optional) ##

### Previous Mappings

In [None]:
# Previous docid assignment for the non-masked data: every value of the pickled
# dict is rendered as one space-separated string of its elements.
with open(f'IDMapping_{dataset + "_seen"}_bert_512_k9_c20_seed_7.pkl', 'rb') as f:
    kmeans_qdoc_dict = pickle.load(f)
old_to_new_docid_char: dict = {
    doc_key: ' '.join(map(str, id_parts))
    for doc_key, id_parts in kmeans_qdoc_dict.items()
}

# Same for the masked data, e.g. ['4 6 7 5', '8 9 9 6', ...]
with open(f'IDMapping_{dataset + "_seen" + "_masked"}_bert_512_k9_c20_seed_7.pkl', 'rb') as f:
    masked_kmeans_qdoc_dict = pickle.load(f)
masked_old_to_new_docid_char: dict = {
    doc_key: ' '.join(map(str, id_parts))
    for doc_key, id_parts in masked_kmeans_qdoc_dict.items()
}

### New Mappings

In [None]:
# New docid assignment for the non-masked data. This rebinds `kmeans_qdoc_dict`,
# which previously held the old mapping.
# Other mapping files that have been used: _roberta_512_k5_c20_seed_7.pkl, _bert_512_k20_c20_seed_7.pkl
with open(f'IDMapping_{dataset + "_seen"}_roberta_512_k5_c20_seed_7.pkl', 'rb') as f:
    kmeans_qdoc_dict = pickle.load(f)
updated_old_to_new_docid_char: dict = {
    doc_key: ' '.join(map(str, id_parts))
    for doc_key, id_parts in kmeans_qdoc_dict.items()
}

# New docid assignment for the masked data, e.g. ['4 6 7 5', '8 9 9 6', ...]
with open(f'IDMapping_{dataset + "_seen" + "_masked"}_roberta_512_k5_c20_seed_7.pkl', 'rb') as f:
    masked_kmeans_qdoc_dict = pickle.load(f)
updated_masked_old_to_new_docid_char: dict = {
    doc_key: ' '.join(map(str, id_parts))
    for doc_key, id_parts in masked_kmeans_qdoc_dict.items()
}

In [None]:
# Translate every previous docid string into its replacement from the new mapping.
# Both dicts are indexed with 0..n-1, so their keys must be consecutive integers.
old_to_updated_mapping = {
    old_to_new_docid_char[idx_]: updated_old_to_new_docid_char[idx_]
    for idx_ in range(len(updated_old_to_new_docid_char))
}
# The masked mapping is sized by its own dict; it previously borrowed the
# length of the non-masked one.
masked_old_to_updated_mapping = {
    masked_old_to_new_docid_char[idx_]: updated_masked_old_to_new_docid_char[idx_]
    for idx_ in range(len(updated_masked_old_to_new_docid_char))
}

### Apply New Mapping
- Currently only supports `masked` data.

In [None]:
# Rewrite the docids of the training file and of the swap-augmented file with the
# masked old -> updated mapping; each result goes to "<input>_updated_rb_k5_c20".
# NOTE(review): the concatenation cell writes "..._train_pearl", while this path
# ends in "_train_perl" -- confirm which spelling is intended.
train_file = 'data/synthetics/syn_50k_train_perl'
qpa_file = 'data/synthetics/syn_50k_seen_masked_swap_augmented.json'
files = [train_file, qpa_file]
for file in files:
    ids, texts, tps = read_id_text_from_file(file, is_json=True, has_tp=True)
    with open(file + "_updated_rb_k5_c20", 'w') as f:
        # A docid missing from the mapping raises KeyError here.
        f.writelines(
            json.dumps({"text_id": masked_old_to_updated_mapping[ids[pos]], "text": texts[pos], "is_tp": tps[pos]}) + '\n'
            for pos in range(len(ids))
        )

# Test Dataset #
- for an independent run, needs:
    - raw .json files
    - all *results* from running the Train section up to ~1.5

In [None]:
#################################
# Switch: True writes annotated (entity-marked) outputs under "prep/",
# False writes plain outputs under "prep_noA/".
use_annotation = True
#################################


# Every regular file in the raw directory is preprocessed into its own output folder.
unseen_raw_path = f"data/synthetics/{dataset}_unseen/raw"
for unseen_raw_file in os.listdir(unseen_raw_path):
    if os.path.isdir(os.path.join(unseen_raw_path, unseen_raw_file)):
            continue
    print(f"Start preprocessing unseen {unseen_raw_file} ...")

    # Get unseen instances' cluster ids with existing k-means cluster 
    with open(os.path.join(unseen_raw_path, unseen_raw_file), 'r') as f:
        unseen_dict_ = parse_json(json.load(f)) # text (simplified-passage), d2q_text (full-passage), text_id (q), q_ct, p_ct

        # Get masked (verbalized) passages
        # The tagger always runs, because the tagged passages are reused for the time mapping below.
        passages = [Sentence(passage) for passage in unseen_dict_['text']]
        fine_tuned_tagger.predict(passages, mini_batch_size=512, verbose=True)
        unseen_dict_['masked_text'] = mask_text(passages, verbalize=True) if use_annotation else unseen_dict_['text']
        unseen_cluster_ids, unseen_cluster_char_ids = get_exisiting_cluster_ids(test_masked_texts=unseen_dict_['masked_text'], use_annotation=use_annotation)

    # The output folder is the raw file name minus its last five characters
    # (presumably the ".json" extension -- confirm for other file types).
    preprocessed_path = f"data/synthetics/{dataset}_unseen/prep/{unseen_raw_file[:-5]}" if use_annotation else f"data/synthetics/{dataset}_unseen/prep_noA/{unseen_raw_file[:-5]}"
    Path(preprocessed_path).mkdir(parents=True, exist_ok=True)

    # Create a queries dev set (**valid_file**)
    with open(os.path.join(preprocessed_path, "queries_dev.json"), 'w') as f:
        # The query text is stored under the 'text_id' key of the parsed dict.
        queries = [Sentence(_) for _ in unseen_dict_['text_id']]
        fine_tuned_tagger.predict(queries, mini_batch_size=512, verbose=True)
        masked_queries = mask_text(queries, verbalize=False) if use_annotation else unseen_dict_['text_id']

        # Save the masked queries with cluster ids
        for idx_ in range(len(unseen_cluster_char_ids)):
            f.write(json.dumps({"text_id": unseen_cluster_char_ids[idx_], "text": f"Question: {masked_queries[idx_]}"}) + '\n')

    
    # Create a original passage file (**passage_file**)
    df_ = pd.DataFrame.from_dict({'docid': unseen_cluster_char_ids, 'text': unseen_dict_['text']})
    df_['docid'] = df_['docid'].astype('str')
    df_.to_csv(os.path.join(preprocessed_path, "passages_og_dev.tsv"), sep="\t", index=False, header=False) 

    
    # Create a original query file (**query_file**)
    df_ = pd.DataFrame.from_dict({'docid': unseen_cluster_char_ids, 'text': unseen_dict_['text_id']})
    df_['docid'] = df_['docid'].astype('str')
    df_.to_csv(os.path.join(preprocessed_path, "queries_og_dev.tsv"), sep="\t", index=False, header=False)

    
    # Create query and passage time mapping files (**passage_time_mapping**, **query_time_mapping**) 
    # `datetime` and `extract_dat_from_sent` are not defined in this cell; they must already be in scope.
    # The two-element result of extract_dat_from_sent is written as t_s_time / t_e_time.
    with open(os.path.join(preprocessed_path, "queries_time_mapping"), "w") as f1, open(os.path.join(preprocessed_path, "passages_time_mapping"), "w") as f2: 
        for id_, p_sent, q_sent, p_ct, q_ct in tqdm(zip(unseen_cluster_char_ids, passages, queries, unseen_dict_['p_ct'], unseen_dict_['q_ct'])):
            p_dat_res, q_dat_res = extract_dat_from_sent(p_sent, datetime.fromisoformat(p_ct)), extract_dat_from_sent(q_sent, datetime.fromisoformat(q_ct))
            f1.write(json.dumps({"text_id": id_, "ct": q_ct, "t_s_time": q_dat_res[0], "t_e_time": q_dat_res[1]}) + '\n')
            f2.write(json.dumps({"text_id": id_, "ct": p_ct, "t_s_time": p_dat_res[0], "t_e_time": p_dat_res[1]}) + '\n')

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()