In [1]:
import os
import pandas as pd
import json
import sys
import time
import gc
from collections import defaultdict

import joblib

from tqdm import tqdm

ROOT_PATH = os.getcwd()

PROJECTS_DATA_PATH = os.path.join(ROOT_PATH, 'projectsdata')
CORPUS_PATH = os.path.join(ROOT_PATH, 'CORPUS')

## Preprocessing and downloading corpus

In [None]:
# TODO: add the download functions here (from download_ORACC-JSON.ipynb)

In [108]:
""" Listing projects (for Flask app) """
projects_dict = {}

texts_in_project = {}

def find_corpusjson_folders_projects(start_path):
    data = {}

    key = 0
    
    for root, dirs, files in os.walk(start_path):
        if "corpusjson" in dirs:
            relative_path = os.path.relpath(os.path.join(root, "corpusjson"), start_path)

            relative_path_metadata = os.path.relpath(os.path.join(root, "metadata.json"), start_path)

            data[key] = {
                "corpusjson_paths": relative_path.replace("\\", "/"),
                "metadata_path": relative_path_metadata.replace("\\", "/")
            }
            

            dirs.remove("corpusjson")  # Prevent recursion to the corpusjson folder
            key += 1
    
    return data


def list_project_texts(project_name:str):
    texts_with_errors = []
    
    # Find corpusjson folders in the project:
    data = find_corpusjson_folders_projects(os.path.join(PROJECTS_DATA_PATH, project_name))
                
    project_data = {}
    
    for key in data:
        full_path = os.path.join(PROJECTS_DATA_PATH, project_name, data[key]["corpusjson_paths"])
        metadata_path = os.path.join(PROJECTS_DATA_PATH, project_name, data[key]["metadata_path"])

        files_in_folder = os.listdir(full_path)
        text_ids = []

        if len(files_in_folder) > 0:
            print(f'Found {len(files_in_folder)} files in {project_name}/{full_path[:-11]} project')
            
            for json_file_name in files_in_folder:
                text_ids.append(json_file_name[:-5])  # Remove .json extension

        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
            full_project_name = metadata['config']['name']

        pd_key = f'{project_name}/{data[key]["corpusjson_paths"][:-11]}'
        if pd_key.endswith('/'):
            pd_key = pd_key[:-1]

        project_data[pd_key] = {'full_project_name': full_project_name, 'text_id': text_ids}

    return project_data


def extract_info(filename='all_projects_jsons'):
    all_project_data = {}

    for project_name in os.listdir(PROJECTS_DATA_PATH):
        project_data = list_project_texts(project_name)

        for key, value in project_data.items():
            all_project_data[key] = value

    return all_project_data

In [109]:
data_info = extract_info()

Found 89 files in adsd/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\adsd\adart1 project
Found 162 files in adsd/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\adsd\adart2 project
Found 156 files in adsd/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\adsd\adart3 project
Found 105 files in adsd/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\adsd\adart5 project
Found 180 files in adsd/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\adsd\adart6 project
Found 2 files in aemw/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\aemw\alalakh/idrimi project
Found 305 files in aemw/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\aemw\amarna project
Found 32 files in akklove/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\akklove project
Found 175 files in ario/c:\Users\valek\PthonProjectsOutsideOD\ORACCIntertext\projectsdata\ario project
Found 160 file

In [114]:
""" Listing number of projects and number of texts. """

number_of_subprojects = 0
number_of_projects_HEAD = 0
number_of_projects_full = 0
number_of_projects_with_texts = 0
number_of_texts = 0

for project, info in data_info.items():
    number_of_projects_full += 1
    texts_in_project = 0

    print("Project:", project)

    for text_id in info['text_id']:
        number_of_texts += 1
        texts_in_project += 1

    if texts_in_project > 0:
        number_of_projects_with_texts += 1

    if '/' not in project:
        number_of_projects_HEAD += 1
    else:
        number_of_subprojects += 1

print("Number of projects (full):", number_of_projects_full)
print("Number of projects (with texts):", number_of_projects_with_texts)
print("Number of texts:", number_of_texts)
print("Number of subprojects:", number_of_subprojects)
print("Number of projects (HEAD):", number_of_projects_HEAD)


Project: adsd
Project: adsd/adart1
Project: adsd/adart2
Project: adsd/adart3
Project: adsd/adart5
Project: adsd/adart6
Project: aemw/alalakh/idrimi
Project: aemw/amarna
Project: akklove
Project: amgg
Project: ario
Project: armep
Project: arrim
Project: asbp
Project: asbp/ninmed
Project: asbp/rlasb
Project: atae
Project: atae/assur
Project: atae/burmarina
Project: atae/durkatlimmu
Project: atae/durszarrukin
Project: atae/guzana
Project: atae/huzirina
Project: atae/imgurenlil
Project: atae/kalhu
Project: atae/kunalia
Project: atae/mallanate
Project: atae/marqasu
Project: atae/nineveh
Project: atae/samal
Project: atae/szibaniba
Project: atae/tilbarsip
Project: atae/tuszhan
Project: babcity
Project: balt
Project: blms
Project: borsippa
Project: btmao
Project: btto
Project: cams/barutu
Project: cams/gkab
Project: cams/ludlul
Project: cams/selbi
Project: cams/tlab
Project: ckst
Project: cmawro
Project: cmawro/cmawr1
Project: cmawro/cmawr2
Project: cmawro/cmawr3
Project: cmawro/maqlu
Project:

In [None]:
print('{\n"projects": [')

for project, info in data_info.items():
    print('\t{ "value": "'+project+'", "label": "'+info['full_project_name']+'" },')

print('\t],\n"textsByProject": {')
for project, info in data_info.items():
    print(f'\t"{project}": [')
    i=0
    for text_id in info['text_id']:
        if i == len(info['text_id']) - 1:
            print(f'\t\t{{"value": "{project}/{text_id}"}}')
        else:
            print(f'\t\t{{"value": "{project}/{text_id}"}},')
        i += 1
    print('\t],')

print('\t}\n}')

In [49]:
def find_corpusjson_folders(start_path):
    corpusjson_paths = []
    
    for root, dirs, files in os.walk(start_path):
        if "corpusjson" in dirs:
            relative_path = os.path.relpath(os.path.join(root, "corpusjson"), start_path)
            corpusjson_paths.append(relative_path.replace("\\", "/"))
            dirs.remove("corpusjson")  # Prevent recursion to the corpusjson folder
    
    return corpusjson_paths


def extract_jsons_from_project(project_name:str):
    texts_with_errors = []
    
    # Find corpusjson folders in the project:
    corpusjson_folders = find_corpusjson_folders(os.path.join(PROJECTS_DATA_PATH, project_name))
                
    project_jsons = {}
    
    for corpusjson_folder in corpusjson_folders:
        full_path = os.path.join(PROJECTS_DATA_PATH, project_name, corpusjson_folder)
        text_id_prefix = f'{project_name}/{corpusjson_folder[:-11]}'
        files_in_folder = os.listdir(full_path)
        if len(files_in_folder) > 0:
            print(f'Found {len(files_in_folder)} files in {project_name}/{corpusjson_folder[:-11]} project')
            
            for json_file_name in files_in_folder:
                with open(os.path.join(full_path, json_file_name), 'r', encoding='utf-8') as json_file:
                    text_id = f'{text_id_prefix}/{json_file_name[:-5]}'.replace('//', '/') # in case there are no subprojects, there are double slashes --> remove them
                    #print(text_id)
                    try:
                        json_data = json.load(json_file)
                        project_jsons[text_id] = json_data
                    except:
                        texts_with_errors.append(text_id)
                        
    return project_jsons, texts_with_errors

def save_json_corpus(json_corpus:dict, save_name:str, save_path=CORPUS_PATH, compression=None):
    """ Save the ORACC corpus to a joblib file. """
    if compression:
        joblib.dump(json_corpus, os.path.join(save_path, f'{save_name}.joblib'), compress=compression)
    else:
        joblib.dump(json_corpus, os.path.join(save_path, f'{save_name}.joblib'))

def save_all_projects_jsons(filename='all_projects_jsons'):
    all_project_jsons = {}
    projects_texts_with_errors = {}

    for project_name in os.listdir(PROJECTS_DATA_PATH):
        project_jsons, texts_with_errors = extract_jsons_from_project(project_name)
        all_project_jsons[project_name] = project_jsons
        projects_texts_with_errors[project_name] = texts_with_errors

    save_json_corpus(all_project_jsons, filename)
    save_json_corpus(projects_texts_with_errors, f'{filename}_texts_with_errors')

""" Variant: try to create the corpusjson file that is smaller (ignoring many data). """
DROP_KEYS = {
    "type","implicit","id","ref","lang","role","gg","gdl_type","name",
    "ngram","delim","inst","break","ho","hc","value","subtype","label","sig"
}

def prune_obj(x, drop=DROP_KEYS):
    """Rekurzivně odstraní z dictů zadané klíče, projde i seznamy."""
    if isinstance(x, dict):
        return {k: prune_obj(v, drop) for k, v in x.items() if k not in drop}
    if isinstance(x, list):
        return [prune_obj(v, drop) for v in x]
    return x


def save_individual_project_jsons(prefix:str='project_', prune=True, compression=None):
    """ Saving individual joblib files for each project - possibly better for RAM. """

    projects_texts_with_errors = {}

    for project_name in os.listdir(PROJECTS_DATA_PATH):
        project_jsons, texts_with_errors = extract_jsons_from_project(project_name)
        projects_texts_with_errors[project_name] = texts_with_errors

        if prune:
            project_jsons = prune_obj(project_jsons)

        save_json_corpus(project_jsons, f'{prefix}{project_name}', compression=compression)

    save_json_corpus(projects_texts_with_errors, f'{prefix}texts_with_errors_for_individual_files', compression=compression)

In [51]:
save_individual_project_jsons(prefix='lz4_pruned', compression=('lz4', 3))  # Save individual projects to JSON files

# save_all_projects_jsons(mode='json')  # Save all projects to JSON files

Found 89 files in adsd/adart1 project
Found 162 files in adsd/adart2 project
Found 156 files in adsd/adart3 project
Found 105 files in adsd/adart5 project
Found 180 files in adsd/adart6 project


ValueError: LZ4 is not installed. Install it with pip: https://python-lz4.readthedocs.io/

In [None]:
# # NOTE: in case of POS variants, we want to ignore named entities/Proper Nouns. (see https://oracc.museum.upenn.edu/doc/help/languages/propernouns/index.html)
# PN_POSs = ['AN', 'CN', 'DN', 'EN', 'FN', 'GN', 'LN', 'MN', 'ON', 'PN', 'QN', 'RN', 'SN', 'TN', 'WN', 'YN']

# def load_json_corpus(json_corpus_name:str, load_path=CORPUS_PATH) -> dict:
#     return joblib.load(os.path.join(load_path, f'{json_corpus_name}.joblib'))

# def parsejson(text_json:dict):
#     text_forms = []
#     text_lemma = []
#     text_normalised = []
    
#     text_signs = []
#     text_signs_gdl = []

#     text_forms_POS = []
#     text_lemma_POS = []
#     text_normalised_POS = []

#     def extract_from_node(obj):
#         if isinstance(obj, dict):
#             if obj.get("node") == "l" and isinstance(obj.get("f"), dict):
#                 f = obj["f"]

#                 pos  = f.get("pos") or f.get("epos")
#                 form = f.get("form")
#                 lemma = f.get("cf")
#                 norm = f.get("norm") or f.get("norm0")

#                 text_forms.append(form)
#                 text_lemma.append(lemma)
#                 text_normalised.append(norm)

#                 if pos in PN_POSs:
#                     text_forms_POS.append(f"PN_{pos}")
#                     text_lemma_POS.append(f"PN_{pos}")
#                     text_normalised_POS.append(f"PN_{pos}")
#                 else:
#                     text_forms_POS.append(form)
#                     text_lemma_POS.append(lemma)
#                     text_normalised_POS.append(norm)

#                 for g in f.get("gdl", []):
#                     if isinstance(g, dict):
#                         if "v" in g:
#                             text_signs.append(g["v"])
#                         if "gdl_sign" in g:
#                             text_signs_gdl.append(g["gdl_sign"])
#                         for sub in g.get("seq", []):
#                             if "v" in sub:
#                                 text_signs.append(sub["v"])
#                             if "gdl_sign" in sub:
#                                 text_signs_gdl.append(sub["gdl_sign"])
#             for value in obj.values():
#                 extract_from_node(value)
#         elif isinstance(obj, list):
#             for item in obj:
#                 extract_from_node(item)

#     def change_unknowns(input_list:list):
#         unknowns = [None, 'x', 'X']
#         return ["■" if item in unknowns else item for item in input_list]

#     extract_from_node(text_json)

#     text_forms = change_unknowns(text_forms)
#     text_lemma = change_unknowns(text_lemma)
#     text_normalised = change_unknowns(text_normalised)
#     text_signs = change_unknowns(text_signs)
#     text_signs_gdl = change_unknowns(text_signs_gdl)
#     text_forms_POS = change_unknowns(text_forms_POS)
#     text_lemma_POS = change_unknowns(text_lemma_POS)
#     text_normalised_POS = change_unknowns(text_normalised_POS)

#     return {'text_forms': text_forms, 'text_lemma': text_lemma, 'text_normalised': text_normalised, 'text_signs': text_signs, 'text_signs_gdl': text_signs_gdl, 'text_forms_POS': text_forms_POS, 'text_lemma_POS': text_lemma_POS, 'text_normalised_POS': text_normalised_POS}


# def normalize_signs(input_: str) -> str:
#     """ Normalises signs in the text (e.g., ša = ša₂ = ša₃)"""
#     for num in '₁₂₃₄₅₆₇₈₉₀':
#         while num in input_:
#             input_ = input_.replace(num, '')

#     return input_


# def normalize_signs_list(input_: list) -> list:
#     """ Normalises signs in the text (e.g., ša = ša₂ = ša₃)"""
#     return [normalize_signs(item) for item in input_]

# class OraccProjectCorpus:
#     def __init__(self, json_corpus):
#         self.corpus = json_corpus
#         self.texts =  [text_id for text_id in json_corpus]
#         self.texts_data = [json_corpus[text_id] for text_id in json_corpus]
#         self.size = len(json_corpus)
        
#         analysed_corpus, texts_with_errors, empty_texts = self.AnalyseCorpus()
        
#         self.Lemma = analysed_corpus['lemma']
#         self.Forms = analysed_corpus['forms']
#         self.Normalised = analysed_corpus['normalised']
#         self.Signs = analysed_corpus['signs']
#         self.SignsGDL = analysed_corpus['signs_gdl']
        
#         self.LemmaPOS = analysed_corpus.get('lemma_POS')
#         self.FormsPOS = analysed_corpus.get('forms_POS')
#         self.NormalisedPOS = analysed_corpus.get('normalised_POS')
        
#         self.SignsNormalised = analysed_corpus.get('signs_normalised')
#         self.FormsNormalised = analysed_corpus.get('forms_normalised')
#         self.FormsPOSNormalised = analysed_corpus.get('forms_POS_normalised')

#         self.texts_with_errors = texts_with_errors
#         self.empty_texts = empty_texts
    
#     def AnalyseCorpus(self) -> dict: 
#         texts_with_errors = []
#         empty_texts = []
        
#         corpus_data = {}
        
#         full_corpus_forms = []
#         full_corpus_lemma = []
#         full_corpus_normalised = []
#         full_corpus_signs = []
#         full_corpus_signs_gdl = []
#         full_corpus_forms_POS = []
#         full_corpus_lemma_POS = []
#         full_corpus_normalised_POS = []
#         full_corpus_signs_normalised = []
#         full_corpus_forms_normalised = []
#         full_corpus_forms_POS_normalised = []

#         # print('\tAnalyzing texts in the corpus.', self.size, 'texts to be processed.')
        
#         for text_id in self.texts:
            
#             try:
#                 text_analysed = parsejson(self.corpus[text_id])
#             except:
#                 # TODO: find out the problems with these texts!
#                 #print('ERROR with a text:', text_id)
#                 texts_with_errors.append(text_id)
#                 text_analysed = {'text_forms': [], 'text_lemma': [], 'text_normalised': [], 'text_signs': [], 'text_signs_gdl': [], 'text_forms_POS': [], 'text_lemma_POS': [], 'text_normalised_POS': [], 'signs_normalised': [], 'forms_normalised': [], 'forms_POS_normalised': []}

#             corpus_data[text_id] = text_analysed
            
#             full_corpus_forms.append(text_analysed['text_forms'])
#             full_corpus_lemma.append(text_analysed['text_lemma'])
#             full_corpus_normalised.append(text_analysed['text_normalised'])
#             full_corpus_signs.append(text_analysed['text_signs'])
#             full_corpus_signs_gdl.append(text_analysed['text_signs_gdl'])
#             full_corpus_forms_POS.append(text_analysed['text_forms_POS'])
#             full_corpus_lemma_POS.append(text_analysed['text_lemma_POS'])
#             full_corpus_normalised_POS.append(text_analysed['text_normalised_POS'])

#             full_corpus_signs_normalised.append(normalize_signs_list(text_analysed['text_signs']))
#             full_corpus_forms_normalised.append(normalize_signs_list(text_analysed['text_forms']))
#             full_corpus_forms_POS_normalised.append(normalize_signs_list(text_analysed['text_forms_POS']))

#             if text_analysed == {'text_forms': [], 'text_lemma': [], 'text_normalised': [], 'text_signs': [], 'text_signs_gdl': [], 'text_forms_POS': [], 'text_lemma_POS': [], 'text_normalised_POS': [], 'signs_normalised': [], 'forms_normalised': [], 'forms_POS_normalised': []}:
#                 empty_texts.append(text_id)

#         return {'corpus_data': corpus_data, 'forms': full_corpus_forms, 'lemma': full_corpus_lemma, 'normalised': full_corpus_normalised, 'signs': full_corpus_signs, 'signs_gdl': full_corpus_signs_gdl, 'forms_POS': full_corpus_forms_POS, 'lemma_POS': full_corpus_lemma_POS, 'normalised_POS': full_corpus_normalised_POS, 'signs_normalised': full_corpus_signs_normalised, 'forms_normalised': full_corpus_forms_normalised, 'forms_POS_normalised': full_corpus_forms_POS_normalised}, texts_with_errors, empty_texts


# class OraccCorpus():
#     def __init__(self, input_projects:dict) -> None:
#         self.projects = input_projects

#         all_texts = []
        
#         lemma_corpus = []
#         forms_corpus = []
#         normalised_corpus = []
#         signs_corpus = []
#         signs_gdl_corpus = []
#         forms_POS_corpus = []
#         lemma_POS_corpus = []
#         normalised_POS_corpus = []
#         signs_normalised_corpus = []
#         forms_normalised_corpus = []
#         forms_POS_normalised_corpus = []

#         texts_with_errors = []
#         empty_texts = []
    
#         for project_name, project_data in tqdm(input_projects.items(), desc='Processing projects'):
#             #print(project_name, 'is being processed for dictionary.')
#             OPC_project = OraccProjectCorpus(json_corpus=project_data)
#             for text in OPC_project.Lemma:
#                 lemma_corpus.append(text)
            
#             for text in OPC_project.Forms:
#                 forms_corpus.append(text)
                
#             for text in OPC_project.Normalised:
#                 normalised_corpus.append(text)

#             for text in OPC_project.Signs:
#                 signs_corpus.append(text)
            
#             for text in OPC_project.SignsGDL:
#                 signs_gdl_corpus.append(text)

#             for text_id in OPC_project.texts_with_errors:
#                 texts_with_errors.append(text_id)

#             for text_id in OPC_project.empty_texts:
#                 empty_texts.append(text_id)

#             for text in OPC_project.FormsPOS:
#                 forms_POS_corpus.append(text)

#             for text in OPC_project.LemmaPOS:
#                 lemma_POS_corpus.append(text)

#             for text in OPC_project.NormalisedPOS:
#                 normalised_POS_corpus.append(text)

#             for text_id in OPC_project.texts:
#                 all_texts.append(text_id)

#             for text in OPC_project.SignsNormalised:
#                 signs_normalised_corpus.append(text)

#             for text in OPC_project.FormsNormalised:
#                 forms_normalised_corpus.append(text)

#             for text in OPC_project.FormsPOSNormalised:
#                 forms_POS_normalised_corpus.append(text)

#         print('Corpus size:', len(lemma_corpus), 'texts.')
#         print('Texts with errors:', len(texts_with_errors), 'texts.')
#         for text_id in texts_with_errors:
#             print('\t', text_id)

#         print('Empty texts:', len(empty_texts), 'texts.')

#         self.texts = all_texts
#         self.lemma_corpus = lemma_corpus
#         self.forms_corpus = forms_corpus
#         self.normalised_corpus = normalised_corpus
        
#         self.signs_corpus = signs_corpus
#         self.signs_gdl_corpus = signs_gdl_corpus
        
#         self.forms_POS_corpus = forms_POS_corpus
#         self.lemma_POS_corpus = lemma_POS_corpus
#         self.normalised_POS_corpus = normalised_POS_corpus

#         self.signs_normalised_corpus = signs_normalised_corpus
#         self.forms_normalised_corpus = forms_normalised_corpus
#         self.forms_POS_normalised_corpus = forms_POS_normalised_corpus

#         self.texts_with_errors = texts_with_errors
#         self.empty_texts = empty_texts


#     def get_data_by_id(self, text_id, mode='forms', print_=False) -> list:
#         """ Print text data for debugging purposes. """
#         try:
#             txt_idx = self.texts.index(text_id)
#         except ValueError:
#             print(f'Text ID {text_id} not found in the corpus.')
#             return []

#         if mode == 'forms':
#             if print_:
#                 print(f'Forms: {self.forms_corpus[txt_idx]}')
#             return self.forms_corpus[txt_idx]
#         elif mode == 'lemma':
#             if print_:
#                 print(f'Lemmas: {self.lemma_corpus[txt_idx]}')
#             return self.lemma_corpus[txt_idx]
#         elif mode == 'normalised':
#             if print_:
#                 print(f'Normalised: {self.normalised_corpus[txt_idx]}')
#             return self.normalised_corpus[txt_idx]
#         elif mode == 'signs':
#             if print_:
#                 print(f'Signs: {self.signs_corpus[txt_idx]}')
#             return self.signs_corpus[txt_idx]
#         elif mode == 'signs_gdl':
#             if print_:
#                 print(f'Signs GDL: {self.signs_gdl_corpus[txt_idx]}')
#             return self.signs_gdl_corpus[txt_idx]
#         elif mode == 'forms_pos':
#             if print_:
#                 print(f'Forms POS: {self.forms_POS_corpus[txt_idx]}')
#             return self.forms_POS_corpus[txt_idx]
#         elif mode == 'lemma_pos':
#             if print_:
#                 print(f'Lemmas POS: {self.lemma_POS_corpus[txt_idx]}')
#             return self.lemma_POS_corpus[txt_idx]
#         elif mode == 'normalised_pos':
#             if print_:
#                 print(f'Normalised POS: {self.normalised_POS_corpus[txt_idx]}')
#             return self.normalised_POS_corpus[txt_idx]
#         elif mode == 'signs_normalised':
#             if print_:
#                 print(f'Signs Normalised: {self.signs_normalised_corpus[txt_idx]}')
#             return self.signs_normalised_corpus[txt_idx]
#         elif mode == 'forms_normalised':
#             if print_:
#                 print(f'Forms Normalised: {self.forms_normalised_corpus[txt_idx]}')
#             return self.forms_normalised_corpus[txt_idx]
#         elif mode == 'forms_pos_normalised':
#             if print_:
#                 print(f'Forms POS Normalised: {self.forms_POS_normalised_corpus[txt_idx]}')
#             return self.forms_POS_normalised_corpus[txt_idx]
#         else:
#             if print_:
#                 print(f'Mode print set wrong! Use "forms", "lemma", "normalised", "signs", "signs_gdl", "forms_pos", "lemma_pos", "normalised_pos", "signs_normalised", "forms_normalised", or "forms_pos_normalised".')
#             return []

In [123]:
PN_POSs = ['AN', 'CN', 'DN', 'EN', 'FN', 'GN', 'LN', 'MN', 'ON', 'PN', 'QN', 'RN', 'SN', 'TN', 'WN', 'YN']

def load_json_corpus(json_corpus_name:str, load_path=CORPUS_PATH) -> dict:
    return joblib.load(os.path.join(load_path, f'{json_corpus_name}.joblib'))

def parsejson(text_json:dict):
    text_forms = []
    text_lemma = []
    text_normalised = []
    
    text_signs = []
    text_signs_gdl = []

    text_forms_POS = []
    text_lemma_POS = []
    text_normalised_POS = []

    named_entities_in_json = defaultdict(list)

    def extract_from_node(obj):
        if isinstance(obj, dict):
            if obj.get("node") == "l" and isinstance(obj.get("f"), dict):
                f = obj["f"]

                pos  = f.get("pos") or f.get("epos")
                form = f.get("form")
                lemma = f.get("cf")
                norm = f.get("norm") or f.get("norm0")

                text_forms.append(form)
                text_lemma.append(lemma)
                text_normalised.append(norm)

                if pos in PN_POSs:
                    text_forms_POS.append(f"PN_{pos}")
                    text_lemma_POS.append(f"PN_{pos}")
                    text_normalised_POS.append(f"PN_{pos}")

                    named_entities_in_json[f"PN_{pos}"].append(lemma)

                else:
                    text_forms_POS.append(form)
                    text_lemma_POS.append(lemma)
                    text_normalised_POS.append(norm)

                for g in f.get("gdl", []):
                    if isinstance(g, dict):
                        if "v" in g:
                            text_signs.append(g["v"])
                        if "gdl_sign" in g:
                            text_signs_gdl.append(g["gdl_sign"])
                        for sub in g.get("seq", []):
                            if "v" in sub:
                                text_signs.append(sub["v"])
                            if "gdl_sign" in sub:
                                text_signs_gdl.append(sub["gdl_sign"])

                # for g in f.get("gdl", []):
                #     if isinstance(g, dict):
                #         if "v" in g:
                #             text_signs.append(sys.intern(g["v"]))
                #         if "gdl_sign" in g:
                #             text_signs_gdl.append(sys.intern(g["gdl_sign"]))
                #         for sub in g.get("seq", []):
                #             if "v" in sub:
                #                 text_signs.append(sys.intern(sub["v"]))
                #             if "gdl_sign" in sub:
                #                 text_signs_gdl.append(sys.intern(sub["gdl_sign"]))

            for value in obj.values():
                extract_from_node(value)
        elif isinstance(obj, list):
            for item in obj:
                extract_from_node(item)

    def change_unknowns(input_list:list):
        unknowns = [None, 'x', 'X']
        return ["■" if item in unknowns else item for item in input_list]

    extract_from_node(text_json)

    text_forms = change_unknowns(text_forms)
    text_lemma = change_unknowns(text_lemma)
    text_normalised = change_unknowns(text_normalised)
    text_signs = change_unknowns(text_signs)
    text_signs_gdl = change_unknowns(text_signs_gdl)
    text_forms_POS = change_unknowns(text_forms_POS)
    text_lemma_POS = change_unknowns(text_lemma_POS)
    text_normalised_POS = change_unknowns(text_normalised_POS)

    return {'text_forms': text_forms, 'text_lemma': text_lemma, 'text_normalised': text_normalised, 'text_signs': text_signs, 'text_signs_gdl': text_signs_gdl, 'text_forms_POS': text_forms_POS, 'text_lemma_POS': text_lemma_POS, 'text_normalised_POS': text_normalised_POS, 'text_named_entities': named_entities_in_json}


SUB_NUM = str.maketrans('', '', '₀₁₂₃₄₅₆₇₈₉')

def normalize_signs(s: str) -> str:
    return s.translate(SUB_NUM)

def normalize_signs_list(lst: list) -> list:
    return [sys.intern(normalize_signs(x) if x is not None else "■") for x in lst]

class OraccProjectCorpus:
    def __init__(self, json_corpus):
        self.corpus = json_corpus
        self.texts =  [text_id for text_id in json_corpus]
        self.texts_data = [json_corpus[text_id] for text_id in json_corpus]
        self.size = len(json_corpus)
        
        analysed_corpus, texts_with_errors, empty_texts = self.AnalyseCorpus()
        
        self.Lemma = analysed_corpus['lemma']
        self.Forms = analysed_corpus['forms']
        self.Normalised = analysed_corpus['normalised']
        self.Signs = analysed_corpus['signs']
        self.SignsGDL = analysed_corpus['signs_gdl']
        
        self.LemmaPOS = analysed_corpus.get('lemma_POS')
        self.FormsPOS = analysed_corpus.get('forms_POS')
        self.NormalisedPOS = analysed_corpus.get('normalised_POS')
        
        self.SignsNormalised = analysed_corpus.get('signs_normalised')
        self.FormsNormalised = analysed_corpus.get('forms_normalised')
        self.FormsPOSNormalised = analysed_corpus.get('forms_POS_normalised')

        self.texts_with_errors = texts_with_errors
        self.empty_texts = empty_texts
        self.named_entities = analysed_corpus.get('named_entities')

    def AnalyseCorpus(self) -> dict: 
        texts_with_errors = []
        empty_texts = []
        
        corpus_data = {}
        
        full_corpus_forms = []
        full_corpus_lemma = []
        full_corpus_normalised = []
        full_corpus_signs = []
        full_corpus_signs_gdl = []
        full_corpus_forms_POS = []
        full_corpus_lemma_POS = []
        full_corpus_normalised_POS = []
        full_corpus_signs_normalised = []
        full_corpus_forms_normalised = []
        full_corpus_forms_POS_normalised = []
        
        full_corpus_named_entities = []

        # print('\tAnalyzing texts in the corpus.', self.size, 'texts to be processed.')
        
        for text_id in self.texts:
            
            try:
                text_analysed = parsejson(self.corpus[text_id])
            except:
                # TODO: find out the problems with these texts!
                #print('ERROR with a text:', text_id)
                texts_with_errors.append(text_id)
                text_analysed = {'text_forms': [], 'text_lemma': [], 'text_normalised': [], 'text_signs': [], 'text_signs_gdl': [], 'text_forms_POS': [], 'text_lemma_POS': [], 'text_normalised_POS': [], 'signs_normalised': [], 'forms_normalised': [], 'forms_POS_normalised': [], 'named_entities': []}

            corpus_data[text_id] = text_analysed
            
            full_corpus_forms.append(text_analysed['text_forms'])
            full_corpus_lemma.append(text_analysed['text_lemma'])
            full_corpus_normalised.append(text_analysed['text_normalised'])
            full_corpus_signs.append(text_analysed['text_signs'])
            full_corpus_signs_gdl.append(text_analysed['text_signs_gdl'])
            full_corpus_forms_POS.append(text_analysed['text_forms_POS'])
            full_corpus_lemma_POS.append(text_analysed['text_lemma_POS'])
            full_corpus_normalised_POS.append(text_analysed['text_normalised_POS'])

            full_corpus_signs_normalised.append(normalize_signs_list(text_analysed['text_signs']))
            full_corpus_forms_normalised.append(normalize_signs_list(text_analysed['text_forms']))
            full_corpus_forms_POS_normalised.append(normalize_signs_list(text_analysed['text_forms_POS']))

            full_corpus_named_entities.append(text_analysed['text_named_entities'])

            if text_analysed == {'text_forms': [], 'text_lemma': [], 'text_normalised': [], 'text_signs': [], 'text_signs_gdl': [], 'text_forms_POS': [], 'text_lemma_POS': [], 'text_normalised_POS': [], 'signs_normalised': [], 'forms_normalised': [], 'forms_POS_normalised': [], 'named_entities': []}:
                empty_texts.append(text_id)

        return {'corpus_data': corpus_data, 'forms': full_corpus_forms, 'lemma': full_corpus_lemma, 'normalised': full_corpus_normalised, 'signs': full_corpus_signs, 'signs_gdl': full_corpus_signs_gdl, 'forms_POS': full_corpus_forms_POS, 'lemma_POS': full_corpus_lemma_POS, 'normalised_POS': full_corpus_normalised_POS, 'signs_normalised': full_corpus_signs_normalised, 'forms_normalised': full_corpus_forms_normalised, 'forms_POS_normalised': full_corpus_forms_POS_normalised, 'named_entities': full_corpus_named_entities}, texts_with_errors, empty_texts


class OraccCorpus():
    def __init__(self, projects_path:str, files_prefix:str='prnd_') -> None: # def __init__(self, input_projects:dict) -> None:
        # self.projects = input_projects # not needed, RAM saving

        all_texts = []
        
        lemma_corpus = []
        forms_corpus = []
        normalised_corpus = []
        signs_corpus = []
        signs_gdl_corpus = []
        forms_POS_corpus = []
        lemma_POS_corpus = []
        normalised_POS_corpus = []
        signs_normalised_corpus = []
        forms_normalised_corpus = []
        forms_POS_normalised_corpus = []
        named_entities_corpus = []

        texts_with_errors = []
        empty_texts = []
    
        # for project_name, project_data in tqdm(input_projects.items(), desc='Processing projects'):
        # for project_name in tqdm(list(input_projects.keys()), desc='Processing projects...'):
        # i=0
        for project_file in tqdm(os.listdir(projects_path), desc='Processing project files...'):
            # project_data = input_projects.pop(project_name)  # saving RAM (but still working with full dataset --> not good)

            # if i >= 10:
            #     break

            if project_file.startswith(files_prefix) and project_file.endswith('.joblib'):
                project_data = load_json_corpus(project_file[:-7], load_path=projects_path)
                OPC_project = OraccProjectCorpus(json_corpus=project_data)

                for text in OPC_project.Lemma:
                    lemma_corpus.append(text)
                
                for text in OPC_project.Forms:
                    forms_corpus.append(text)
                    
                for text in OPC_project.Normalised:
                    normalised_corpus.append(text)

                for text in OPC_project.Signs:
                    signs_corpus.append(text)
                
                for text in OPC_project.SignsGDL:
                    signs_gdl_corpus.append(text)

                for text_id in OPC_project.texts_with_errors:
                    texts_with_errors.append(text_id)

                for text_id in OPC_project.empty_texts:
                    empty_texts.append(text_id)

                for text in OPC_project.FormsPOS:
                    forms_POS_corpus.append(text)

                for text in OPC_project.LemmaPOS:
                    lemma_POS_corpus.append(text)

                for text in OPC_project.NormalisedPOS:
                    normalised_POS_corpus.append(text)

                for text_id in OPC_project.texts:
                    all_texts.append(text_id)

                for text in OPC_project.SignsNormalised:
                    signs_normalised_corpus.append(text)

                for text in OPC_project.FormsNormalised:
                    forms_normalised_corpus.append(text)

                for text in OPC_project.FormsPOSNormalised:
                    forms_POS_normalised_corpus.append(text)

                for text in OPC_project.named_entities:
                    named_entities_corpus.append(text)

                del project_data, OPC_project # saving RAM
                gc.collect() # saving RAM

                # i+=1
            
            else:
                continue

        print('Corpus size:', len(lemma_corpus), 'texts.')
        print('Texts with errors:', len(texts_with_errors), 'texts.')
        
        for text_id in texts_with_errors:
            print('\t', text_id)

        print('Empty texts:', len(empty_texts), 'texts.')

        self.texts = all_texts
        self.lemma_corpus = lemma_corpus
        self.forms_corpus = forms_corpus
        self.normalised_corpus = normalised_corpus
        
        self.signs_corpus = signs_corpus
        self.signs_gdl_corpus = signs_gdl_corpus
        
        self.forms_POS_corpus = forms_POS_corpus
        self.lemma_POS_corpus = lemma_POS_corpus
        self.normalised_POS_corpus = normalised_POS_corpus

        self.signs_normalised_corpus = signs_normalised_corpus
        self.forms_normalised_corpus = forms_normalised_corpus
        self.forms_POS_normalised_corpus = forms_POS_normalised_corpus

        self.texts_with_errors = texts_with_errors
        self.empty_texts = empty_texts

        self.named_entities = named_entities_corpus


    def get_data_by_id(self, text_id, mode='forms', print_=False) -> list:
        """ Print text data for debugging purposes. """
        try:
            txt_idx = self.texts.index(text_id)
        except ValueError:
            print(f'Text ID {text_id} not found in the corpus.')
            return []

        if mode == 'forms':
            if print_:
                print(f'Forms: {self.forms_corpus[txt_idx]}')
            return self.forms_corpus[txt_idx]
        elif mode == 'lemma':
            if print_:
                print(f'Lemmas: {self.lemma_corpus[txt_idx]}')
            return self.lemma_corpus[txt_idx]
        elif mode == 'normalised':
            if print_:
                print(f'Normalised: {self.normalised_corpus[txt_idx]}')
            return self.normalised_corpus[txt_idx]
        elif mode == 'signs':
            if print_:
                print(f'Signs: {self.signs_corpus[txt_idx]}')
            return self.signs_corpus[txt_idx]
        elif mode == 'signs_gdl':
            if print_:
                print(f'Signs GDL: {self.signs_gdl_corpus[txt_idx]}')
            return self.signs_gdl_corpus[txt_idx]
        elif mode == 'forms_pos':
            if print_:
                print(f'Forms POS: {self.forms_POS_corpus[txt_idx]}')
            return self.forms_POS_corpus[txt_idx]
        elif mode == 'lemma_pos':
            if print_:
                print(f'Lemmas POS: {self.lemma_POS_corpus[txt_idx]}')
            return self.lemma_POS_corpus[txt_idx]
        elif mode == 'normalised_pos':
            if print_:
                print(f'Normalised POS: {self.normalised_POS_corpus[txt_idx]}')
            return self.normalised_POS_corpus[txt_idx]
        elif mode == 'signs_normalised':
            if print_:
                print(f'Signs Normalised: {self.signs_normalised_corpus[txt_idx]}')
            return self.signs_normalised_corpus[txt_idx]
        elif mode == 'forms_normalised':
            if print_:
                print(f'Forms Normalised: {self.forms_normalised_corpus[txt_idx]}')
            return self.forms_normalised_corpus[txt_idx]
        elif mode == 'forms_pos_normalised':
            if print_:
                print(f'Forms POS Normalised: {self.forms_POS_normalised_corpus[txt_idx]}')
            return self.forms_POS_normalised_corpus[txt_idx]
        elif mode == 'named_entities':
            if print_:
                print(f'Named Entities: {self.named_entities[txt_idx]}')
            return self.named_entities[txt_idx]
        else:
            if print_:
                print(f'Mode print set wrong! Use "forms", "lemma", "normalised", "signs", "signs_gdl", "forms_pos", "lemma_pos", "normalised_pos", "signs_normalised", "forms_normalised", or "forms_pos_normalised".')
            return []

## Preparing the Corpus

In [None]:
# all_project_jsons = load_json_corpus('all_projects_jsons')
# all_project_jsons_filtered = load_json_corpus('all_projects_jsons_filtered')
example_corpus = load_json_corpus('prnd_no_compakklove')  # Load corpus from individual project files
for key in example_corpus.keys():
    print(key, example_corpus[key])  # Print number of texts in each project

In [124]:
# full_ORACC_corpus = OraccCorpus(all_project_jsons)
# full_ORACC_corpus = OraccCorpus(all_project_jsons_filtered)
full_ORACC_corpus = OraccCorpus(CORPUS_PATH, files_prefix='prnd_no_comp')  # loading from individual pruned files

Processing project files...: 100%|██████████| 156/156 [04:34<00:00,  1.76s/it]

Corpus size: 28943 texts.
Texts with errors: 0 texts.
Empty texts: 0 texts.





In [133]:
def analyse_realtions_of_NE(NE_for_analysis:str, category:str, oracc_corpus:OraccCorpus):
    # Function to analyze relations of named entities

    related_entities = defaultdict(int)
    
    for text_id in full_ORACC_corpus.texts:
        named_entities_in_text = full_ORACC_corpus.get_data_by_id(text_id, mode='named_entities')

        if NE_for_analysis in named_entities_in_text[category]:
            for ne in set(named_entities_in_text[category]):
                if ne != NE_for_analysis:
                    related_entities[ne] += 1

    return related_entities


In [None]:
q_en = 'Adad'
entity = 'PN_DN'

relations_instance = analyse_realtions_of_NE(q_en, entity, oracc_corpus=full_ORACC_corpus)

benchmark = 30

out_dict = {}
out_dict_points = {}
i = 0
for rel, val in relations_instance.items():
    if val >= benchmark:
        out_dict[i]={'query_ne': q_en, 'rel_ne': rel, 'hits': val}
        out_dict_points[i] = {'point': rel, 'size': val}  # Example point calculation
        i += 1

out_dict_points[i+1] = {'point': q_en, 'size': len(relations_instance)}  # Example point calculation

df_Adad = pd.DataFrame.from_dict(out_dict, orient='index')
df_Adad.to_csv(f'flourishdata/{q_en}_relations.csv', sep=',')

df_Adad_points = pd.DataFrame.from_dict(out_dict_points, orient='index')
df_Adad_points.to_csv(f'flourishdata/{q_en}_relations_points.csv', sep=',')

In [120]:
full_ORACC_corpus.get_data_by_id(full_ORACC_corpus.texts[20], mode='named_entities', print_=True)
# full_ORACC_corpus.get_data_by_id(full_ORACC_corpus.texts[20], mode='lemma', print_=True)

Named Entities: defaultdict(<class 'list'>, {'PN_CN': ['Arcturus', 'Ṣalbaṭānu', 'Ṣalbaṭānu', 'Ṣalbaṭānu', 'MÚL-KUR-šá-DUR-nu-nu', 'Dilbat', 'DELE-šá-IGI-ABSIN', 'MÚL-ár-šá-SAG-HUN', 'MÚL.MÚL', 'is-le₁₀', 'Ṣalbaṭānu', 'MÚL-KUR-šá-DUR-nu-nu', 'Šihṭu', 'MÚL-IGI-šá-še-pít-MAŠ.MAŠ', 'MAŠ.MAŠ-IGI', 'MÚL-ár-šá-ALLA-šá-ULÙ', 'SAG-A', 'LUGAL', 'GIŠ.KUN-A', 'GÌR-ár-šá-A', 'DELE-šá-IGI-ABSIN', 'Kayyamānu', 'Kakkabu-peṣû', 'Kayyamānu', 'SA₄-šá-ABSIN', 'Dilbat', 'Ṣalbaṭānu', 'Šihṭu', 'RÍN-šá-SI', 'Kakkabu-peṣû', 'Kakkabu-peṣû', 'RÍN-šá-ULÙ', 'Dilbat', 'SA₄-šá-ABSIN', 'Kakkabu-peṣû', 'Zibānītu', 'Dilbat', 'Šerʾu', 'Kayyamānu', 'Zibānītu', 'Ṣalbaṭānu', 'Zibbātu', 'Šihṭu', 'SI-MÁŠ', 'MÚL-IGI-šá-SUHUR-MÁŠ', 'MÚL-KUR-šá-DUR-nu-nu', 'MÚL-IGI-šá-SAG-HUN', 'ŠUR-GIGIR-šá-SI', 'Šihṭu', 'Zuqiqīpu', 'SI₄', 'MAŠ.MAŠ-ár', 'Šihṭu', 'SI₄', 'DELE-šá-IGI-ABSIN', 'Dilbat', 'Kakkabu-peṣû', 'RÍN-šá-ULÙ', 'MÚL-KUR-šá-KIR₄-šil-PA', 'Kakkabu-peṣû', 'Zibānītu', 'Dilbat', 'Zibānītu', 'Zuqiqīpu', 'Šihṭu', 'is-le₁₀', 'ŠUR-GIG

defaultdict(list,
            {'PN_CN': ['Arcturus',
              'Ṣalbaṭānu',
              'Ṣalbaṭānu',
              'Ṣalbaṭānu',
              'MÚL-KUR-šá-DUR-nu-nu',
              'Dilbat',
              'DELE-šá-IGI-ABSIN',
              'MÚL-ár-šá-SAG-HUN',
              'MÚL.MÚL',
              'is-le₁₀',
              'Ṣalbaṭānu',
              'MÚL-KUR-šá-DUR-nu-nu',
              'Šihṭu',
              'MÚL-IGI-šá-še-pít-MAŠ.MAŠ',
              'MAŠ.MAŠ-IGI',
              'MÚL-ár-šá-ALLA-šá-ULÙ',
              'SAG-A',
              'LUGAL',
              'GIŠ.KUN-A',
              'GÌR-ár-šá-A',
              'DELE-šá-IGI-ABSIN',
              'Kayyamānu',
              'Kakkabu-peṣû',
              'Kayyamānu',
              'SA₄-šá-ABSIN',
              'Dilbat',
              'Ṣalbaṭānu',
              'Šihṭu',
              'RÍN-šá-SI',
              'Kakkabu-peṣû',
              'Kakkabu-peṣû',
              'RÍN-šá-ULÙ',
              'Dilbat',
              'SA₄-šá

In [None]:
# del all_project_jsons
# del all_project_jsons_filtered

In [None]:
""" Validity check """

# for mode in ['forms', 'lemma', 'normalised', 'signs', 'signs_gdl', 'forms_POS', 'lemma_POS', 'normalised_POS']:
#     for t_id in full_ORACC_corpus.texts:
#         if full_ORACC_corpus.get_data_by_id(t_id, mode=mode) == full_ORACC_corpus_filtered.get_data_by_id(t_id, mode=mode):
#             continue
#         else:
#             print("Mismatch found in text ID:", t_id)

## Intertextuality based on vectorization

In [5]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm
import os
import faiss
from typing import List, Tuple, Dict, Any
import csv

ROOT_PATH = os.getcwd()
CHUNKS_PATH = os.path.join(ROOT_PATH, "chunks")

ORACC_NORM_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_norm_embeddings.csv")
ORACC_NORM_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_norm_meta.csv")

ORACC_LEMMA_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_lemma_embeddings.csv")
ORACC_LEMMA_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_lemma_meta.csv")

ORACC_FORMS_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_embeddings.csv")
ORACC_FORMS_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_meta.csv")

ORACC_NORM_POS_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_norm_pos_embeddings.csv")
ORACC_NORM_POS_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_norm_pos_meta.csv")

ORACC_LEMMA_POS_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_lemma_pos_embeddings.csv")
ORACC_LEMMA_POS_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_lemma_pos_meta.csv")

ORACC_FORMS_POS_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_pos_embeddings.csv")
ORACC_FORMS_POS_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_pos_meta.csv")

ORACC_FORMS_NORMALISED_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_normalised_embeddings.csv")
ORACC_FORMS_NORMALISED_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_normalised_meta.csv")

ORACC_FORMS_POS_NORMALISED_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_pos_normalised_embeddings.csv")
ORACC_FORMS_POS_NORMALISED_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_forms_pos_normalised_meta.csv")

ORACC_SIGNS_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_embeddings.csv")
ORACC_SIGNS_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_meta.csv")

ORACC_SIGNS_NORMALISED_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_normalised_embeddings.csv")
ORACC_SIGNS_NORMALISED_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_normalised_meta.csv")

ORACC_SIGNSGDL_embed_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_gdl_embeddings.csv")
ORACC_SIGNSGDL_meta_csv_PATH = os.path.join(CHUNKS_PATH, "oracc_signs_gdl_meta.csv")

EMBS_NORM_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_normalised_e5.npy")
EMBS_NORM_POS_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_normalised_pos_e5.npy")
EMBS_LEMMA_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_lemma_e5.npy")
EMBS_LEMMA_POS_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_lemma_pos_e5.npy")
EMBS_FORMS_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_forms_e5.npy")
EMBS_FORMS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_forms_normalised_e5.npy")
EMBS_FORMS_POS_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_forms_pos_e5.npy")
EMBS_FORMS_POS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_forms_pos_normalised_e5.npy")
EMBS_SIGNS_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_signs_e5.npy")
EMBS_SIGNS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_signs_normalised_e5.npy")
EMBS_SIGNSGDL_PATH_E5 = os.path.join(CHUNKS_PATH,"embeddings_signs_gdl_e5.npy")

EMBS_NORM_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_normalised_miniLM.npy")
EMBS_NORM_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_normalised_pos_miniLM.npy")
EMBS_LEMMA_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_lemma_miniLM.npy")
EMBS_LEMMA_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_lemma_pos_miniLM.npy")
EMBS_FORMS_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_forms_miniLM.npy")
EMBS_FORMS_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_forms_pos_miniLM.npy")
EMBS_FORMS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_forms_normalised_miniLM.npy")
EMBS_FORMS_POS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_forms_pos_normalised_miniLM.npy")
EMBS_SIGNS_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_signs_miniLM.npy")
EMBS_SIGNS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_signs_normalised_miniLM.npy")
EMBS_SIGNSGDL_PATH_MiniLM = os.path.join(CHUNKS_PATH,"embeddings_signs_gdl_miniLM.npy")

IDS_NORM_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_norm.csv")
IDS_NORM_POS_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_norm_pos.csv")
IDS_LEMMA_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_lemma.csv")
IDS_LEMMA_POS_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_lemma_pos.csv")
IDS_FORMS_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_forms.csv")
IDS_FORMS_POS_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_forms_pos.csv")
IDS_FORMS_NORMALISED_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_forms_normalised.csv")
IDS_FORMS_POS_NORMALISED_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_forms_pos_normalised.csv")
IDS_SIGNS_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_signs.csv")
IDS_SIGNS_NORMALISED_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_signs_normalised.csv")
IDS_SIGNSGDL_PATH = os.path.join(CHUNKS_PATH,"chunk_ids_signs_gdl.csv")

FAISS_NORM_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_norm_e5.faiss")
FAISS_NORM_POS_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_norm_pos_e5.faiss")
FAISS_LEMMA_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_lemma_e5.faiss")
FAISS_LEMMA_POS_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_lemma_pos_e5.faiss")
FAISS_FORMS_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_forms_e5.faiss")
FAISS_FORMS_POS_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_forms_pos_e5.faiss")
FAISS_FORMS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_forms_normalised_e5.faiss")
FAISS_FORMS_POS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_forms_pos_normalised_e5.faiss")
FAISS_SIGNS_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_signs_e5.faiss")
FAISS_SIGNS_NORMALISED_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_signs_normalised_e5.faiss")
FAISS_SIGNSGDL_PATH_E5 = os.path.join(CHUNKS_PATH, "oracc_signs_gdl_e5.faiss")

FAISS_NORM_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_norm_miniLM.faiss")
FAISS_NORM_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_norm_pos_miniLM.faiss")
FAISS_LEMMA_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_lemma_miniLM.faiss")
FAISS_LEMMA_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_lemma_pos_miniLM.faiss")
FAISS_FORMS_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_forms_miniLM.faiss")
FAISS_FORMS_POS_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_forms_pos_miniLM.faiss")
FAISS_FORMS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_forms_normalised_miniLM.faiss")
FAISS_FORMS_POS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_forms_pos_normalised_miniLM.faiss")
FAISS_SIGNS_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_signs_miniLM.faiss")
FAISS_SIGNS_NORMALISED_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_signs_normalised_miniLM.faiss")
FAISS_SIGNSGDL_PATH_MiniLM = os.path.join(CHUNKS_PATH, "oracc_signs_gdl_miniLM.faiss")

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
str_ = 'lala sdsd sdsd sfdsd sd sd  sd sd s sd sd sd sd sd a'

In [16]:
parse_query_text(str_.split())

[['lala', 'sdsd', 'sdsd', 'sfdsd', 'sd'],
 ['sfdsd', 'sd', 'sd', 'sd', 'sd'],
 ['sd', 'sd', 's', 'sd', 'sd'],
 ['sd', 'sd', 'sd', 'sd', 'sd']]

### Chunking ORACC

In [None]:
def make_windows(seq: List[str], window: int, stride: int, drop_last: bool=False) -> List[Tuple[int,int,List[str]]]:
    """
    Creates a list of strideping windows from the input sequence.

    :param seq: Input sequence (list of tokens/characters)
    :param window: Size of the window
    :param stride: Stride (step size) for moving the window
    :param drop_last: Whether to drop the last window if it's smaller than the specified size
    :return: List of tuples (start_idx, end_idx, subseq)
    """
    n = len(seq)
    out = []
    if n == 0 or window <= 0 or stride <= 0:
        return out

    i = 0
    while i < n:
        j = i + window
        if j > n:
            if drop_last:
                break
            j = n
        out.append((i, j, seq[i:j]))
        if j == n:
            break
        i += stride
    return out


def chunkORACCtext(input_orrac_corpus: OraccCorpus, oracc_text_ID:str, mode: str='normalised', chunk_size: int=10, stride: int=5, drop_last: bool=False, unknown_policy: str='compress', skip_all_unknown: bool=True) -> List[Dict[str, Any]]:
    """ Parsing ORACC to chunks. 

    :param input_orrac_corpus: The ORACC corpus object
    :param oracc_text_ID: The ID of the ORACC text to process
    :param mode: The mode for text retrieval (e.g., 'normalised')
    :param chunk_size: The size of each chunk
    :param stride: The stride between chunks
    :param drop_last: Whether to drop the last chunk if it's smaller than chunk_size
    :param unknown_policy: Whether to drop unknown words (select 'compress'|'keep')
    :param skip_all_unknown: Whether to skip chunks that are entirely unknown
    :return: A tuple containing a list of chunks as tuples (start_idx, end_idx, subseq)
    """

    oraccText = input_orrac_corpus.get_data_by_id(oracc_text_ID, mode=mode)

    windows = make_windows(oraccText, window=chunk_size, stride=stride, drop_last=drop_last)

    out = []
    for s, e, subseq_raw in windows:
        # DISPLAY data
        text_display = ' '.join(list(subseq_raw))

        # EMBEDDING data (UKNOWN handling)
        if unknown_policy == 'compress':
            emb_tokens = [t for t in subseq_raw if t != '∎']
        elif unknown_policy == 'keep':
            emb_tokens = list(subseq_raw)
        else:
            raise ValueError("unknown_policy must be 'compress'|'keep'")

        text_embed = ' '.join(emb_tokens).strip()
        if skip_all_unknown and text_embed == '' or text_embed == '∎':
            continue

        out.append({
            'start': s,
            'end': e,
            'text_display': text_display,
            'text_embed': text_embed,
        })
    return out


def export_corpus_to_csv(corpus: OraccCorpus, out_embed_csv: str, out_meta_csv: str, mode: str, chunk_size: int = 10, stride: int = 5, drop_last: bool = False, unknown_policy: str='compress', skip_all_unknown: bool=True):

    if mode=='normalised':
        unit_tag='n'
    elif mode=='forms':
        unit_tag='f'
    elif mode=='forms_normalised':
        unit_tag='fn'
    elif mode=='lemma':
        unit_tag='l'
    elif mode=='forms_pos':
        unit_tag='fp'
    elif mode=='forms_pos_normalised':
        unit_tag='fpn'
    elif mode=='lemma_pos':
        unit_tag='lp'
    elif mode=='normalised_pos':
        unit_tag='np'
    elif mode=='signs':
        unit_tag='s'
    elif mode=='signs_normalised':
        unit_tag='sn'
    elif mode=='signs_gdl':
        unit_tag='sg'
    else:
        raise ValueError(f"Unknown mode: {mode}. Use 'normalised', 'forms', 'lemma', 'normalised_pos', 'forms_pos', 'lemma_pos', 'signs', 'signs_gdl', 'forms_normalised', 'forms_POS_normalised', 'signs_normalised'.")

    with open(out_embed_csv, "w", newline="", encoding="utf-8") as fe, \
         open(out_meta_csv,  "w", newline="", encoding="utf-8") as fm:
        we = csv.writer(fe)
        wm = csv.writer(fm)
        we.writerow(["chunk_id", "text"])
        wm.writerow(["chunk_id", "text_id", "start", "end", "text_display"])

    total_embed = 0
    total_meta = 0

    # Streaming text chunks
    for textID in tqdm(corpus.texts, desc='Processing ORACC texts'):
        recs = chunkORACCtext(
            input_orrac_corpus=corpus,
            oracc_text_ID=textID,
            mode=mode,
            chunk_size=chunk_size,
            stride=stride,
            drop_last=drop_last,
            unknown_policy=unknown_policy, 
            skip_all_unknown=skip_all_unknown
        )

        with open(out_embed_csv, 'a', newline='', encoding='utf-8') as fe, \
             open(out_meta_csv,  'a', newline='', encoding='utf-8') as fm:
            we = csv.writer(fe)
            wm = csv.writer(fm)

            for r in recs:
                chunk_id = f"{textID}:{unit_tag}:{r['start']}-{r['end']}"

                # meta zapisujeme vždy (aby šlo v UI projít vše, i když embed není)
                wm.writerow([chunk_id, textID, r['start'], r['end'], r['text_display']])
                total_meta += 1

                # do embedding CSV jen neprázdné texty (tvoje logika už ∎ vyhodila)
                if r['text_embed']:
                    we.writerow([chunk_id, r['text_embed']])
                    total_embed += 1

    print(f'Saved → {out_embed_csv}: {total_embed} rows')
    print(f'Saved → {out_meta_csv}:  {total_meta} rows')


In [None]:
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_NORM_embed_csv_PATH, out_meta_csv=ORACC_NORM_meta_csv_PATH, mode='normalised')
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_LEMMA_embed_csv_PATH, out_meta_csv=ORACC_LEMMA_meta_csv_PATH, mode='lemma')
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_FORMS_embed_csv_PATH, out_meta_csv=ORACC_FORMS_meta_csv_PATH, mode='forms')
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_FORMS_NORMALISED_embed_csv_PATH, out_meta_csv=ORACC_FORMS_NORMALISED_meta_csv_PATH, mode='forms_normalised')
export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_NORM_POS_embed_csv_PATH, out_meta_csv=ORACC_NORM_POS_meta_csv_PATH, mode='normalised_pos')
export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_LEMMA_POS_embed_csv_PATH, out_meta_csv=ORACC_LEMMA_POS_meta_csv_PATH, mode='lemma_pos')
export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_FORMS_POS_embed_csv_PATH, out_meta_csv=ORACC_FORMS_POS_meta_csv_PATH, mode='forms_pos')
export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_FORMS_POS_NORMALISED_embed_csv_PATH, out_meta_csv=ORACC_FORMS_POS_NORMALISED_meta_csv_PATH, mode='forms_pos_normalised')
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_SIGNS_embed_csv_PATH, out_meta_csv=ORACC_SIGNS_meta_csv_PATH, mode='signs', chunk_size=25, stride=10)
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_SIGNS_NORMALISED_embed_csv_PATH, out_meta_csv=ORACC_SIGNS_NORMALISED_meta_csv_PATH, mode='signs_normalised', chunk_size=25, stride=10)
# export_corpus_to_csv(corpus=full_ORACC_corpus, out_embed_csv=ORACC_SIGNSGDL_embed_csv_PATH, out_meta_csv=ORACC_SIGNSGDL_meta_csv_PATH, mode='signs_gdl', chunk_size=25, stride=10)

### Embedding

In [None]:
import torch

def batched(iterable, n):
    for i in range(0, len(iterable), n):
        yield iterable[i:i+n]


def process_chunks(input_csv_path:str, output_embeddings_path:str, output_ids_path:str, model_name:str, device:str='cuda', batch_size:int=2048):
    df = pd.read_csv(input_csv_path)
    ids = df['chunk_id'].astype(str).tolist()
    texts = df['text'].astype(str).tolist()

    model = SentenceTransformer(model_name, device=device)
    all_vecs = []

    total = (len(texts) + batch_size - 1) // batch_size

    for i in tqdm(range(0, len(texts), batch_size), total=total):
        batch = texts[i:i+batch_size]
        vecs = model.encode(batch, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=False).astype('float32')
        
        all_vecs.append(vecs)

    E = np.vstack(all_vecs)
    np.save(output_embeddings_path, E)
    pd.Series(ids, name='chunk_id').to_csv(output_ids_path, index=False)

    print("Saved to:", output_embeddings_path, E.shape, " / ", output_ids_path, len(ids))

def process_chunks_POS(input_csv_path:str, output_embeddings_path:str, output_ids_path:str,
                   model_name='intfloat/e5-base-v2', device:str='cuda', batch_size:int=2048,
                   text_prefix:str='passage: '):
    df = pd.read_csv(input_csv_path)
    ids = df["chunk_id"].astype(str).tolist()
    texts = df["text"].astype(str).tolist()

    if text_prefix:
        texts = [f"{text_prefix}{t}" for t in texts]

    model = SentenceTransformer(model_name, device=device)

    model.max_seq_length = 96       # (zkus 96; když chceš víc přesnosti, dej 128)
    if device == "cuda":
        model = model.to(torch.float16)
        torch.set_float32_matmul_precision("high")

    all_vecs = []
    total = (len(texts) + batch_size - 1) // batch_size
    for i in tqdm(range(0, len(texts), batch_size), total=total):
        batch = texts[i:i+batch_size]
        vecs = model.encode(batch, batch_size=batch_size,
                            normalize_embeddings=True, show_progress_bar=False
                           ).astype("float32")
        all_vecs.append(vecs)
    E = np.vstack(all_vecs)
    np.save(output_embeddings_path, E)
    pd.Series(ids, name="chunk_id").to_csv(output_ids_path, index=False)

    print("Uloženo:", output_embeddings_path, E.shape, " / ", output_ids_path, len(ids))

def select_paths(mode='normalised', model='e5'):
    if mode == 'normalised':
        if model == 'e5':
            return (ORACC_NORM_embed_csv_PATH, EMBS_NORM_PATH_E5, IDS_NORM_PATH, FAISS_NORM_PATH_E5, 'intfloat/e5-base-v2', ORACC_NORM_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_NORM_embed_csv_PATH, EMBS_NORM_PATH_MiniLM, IDS_NORM_PATH, FAISS_NORM_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_NORM_meta_csv_PATH)
    elif mode == 'normalised_POS':
        if model == 'e5':
            return (ORACC_NORM_POS_embed_csv_PATH, EMBS_NORM_POS_PATH_E5, IDS_NORM_POS_PATH, FAISS_NORM_POS_PATH_E5, 'intfloat/e5-base-v2', ORACC_NORM_POS_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_NORM_POS_embed_csv_PATH, EMBS_NORM_POS_PATH_MiniLM, IDS_NORM_POS_PATH, FAISS_NORM_POS_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_NORM_POS_meta_csv_PATH)
    elif mode == 'lemma':
        if model == 'e5':
            return (ORACC_LEMMA_embed_csv_PATH, EMBS_LEMMA_PATH_E5, IDS_LEMMA_PATH, FAISS_LEMMA_PATH_E5, 'intfloat/e5-base-v2', ORACC_LEMMA_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_LEMMA_embed_csv_PATH, EMBS_LEMMA_PATH_MiniLM, IDS_LEMMA_PATH, FAISS_LEMMA_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_LEMMA_meta_csv_PATH)
    elif mode == 'lemma_POS':
        if model == 'e5':
            return (ORACC_LEMMA_POS_embed_csv_PATH, EMBS_LEMMA_POS_PATH_E5, IDS_LEMMA_POS_PATH, FAISS_LEMMA_POS_PATH_E5, 'intfloat/e5-base-v2', ORACC_LEMMA_POS_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_LEMMA_POS_embed_csv_PATH, EMBS_LEMMA_POS_PATH_MiniLM, IDS_LEMMA_POS_PATH, FAISS_LEMMA_POS_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_LEMMA_POS_meta_csv_PATH)
    elif mode == 'forms':
        if model == 'e5':
            return (ORACC_FORMS_embed_csv_PATH, EMBS_FORMS_PATH_E5, IDS_FORMS_PATH, FAISS_FORMS_PATH_E5, 'intfloat/e5-base-v2', ORACC_FORMS_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_FORMS_embed_csv_PATH, EMBS_FORMS_PATH_MiniLM, IDS_FORMS_PATH, FAISS_FORMS_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_FORMS_meta_csv_PATH)
    elif mode == 'forms_POS':
        if model == 'e5':
            return (ORACC_FORMS_POS_embed_csv_PATH, EMBS_FORMS_POS_PATH_E5, IDS_FORMS_POS_PATH, FAISS_FORMS_POS_PATH_E5, 'intfloat/e5-base-v2', ORACC_FORMS_POS_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_FORMS_POS_embed_csv_PATH, EMBS_FORMS_POS_PATH_MiniLM, IDS_FORMS_POS_PATH, FAISS_FORMS_POS_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_FORMS_POS_meta_csv_PATH)
    elif mode == 'signs':
        if model == 'e5':
            return (ORACC_SIGNS_embed_csv_PATH, EMBS_SIGNS_PATH_E5, IDS_SIGNS_PATH, FAISS_SIGNS_PATH_E5, 'intfloat/e5-base-v2', ORACC_SIGNS_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_SIGNS_embed_csv_PATH, EMBS_SIGNS_PATH_MiniLM, IDS_SIGNS_PATH, FAISS_SIGNS_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_SIGNS_meta_csv_PATH)
    elif mode == 'signs_gdl':
        if model == 'e5':
            return (ORACC_SIGNSGDL_embed_csv_PATH, EMBS_SIGNSGDL_PATH_E5, IDS_SIGNSGDL_PATH, FAISS_SIGNSGDL_PATH_E5, 'intfloat/e5-base-v2', ORACC_SIGNSGDL_meta_csv_PATH)
        elif model == 'MiniLM':
            return (ORACC_SIGNSGDL_embed_csv_PATH, EMBS_SIGNSGDL_PATH_MiniLM, IDS_SIGNSGDL_PATH, FAISS_SIGNSGDL_PATH_MiniLM, 'all-MiniLM-L6-v2', ORACC_SIGNSGDL_meta_csv_PATH)
    else:
        raise ValueError("Unknown mode: " + mode)

In [None]:
# for mode in ['forms', 'normalised', 'lemma', 'signs', 'signs_gdl']:
#     paths = select_paths(mode=mode)

#     process_chunks(
#         input_csv_path=paths[0],
#         output_embeddings_path=paths[1],
#         output_ids_path=paths[2],
#         model_name=MODEL_NAME,
#         device="cuda",
#         batch_size=BATCH_SIZE
#     )

# NOTE: this has been run in chunk.py
# for mode in ['forms_POS', 'normalised_POS', 'lemma_POS']:
    
#     paths = select_paths(mode=mode)

#     process_chunks_POS(
#         input_csv_path=paths[0],
#         output_embeddings_path=paths[1],
#         output_ids_path=paths[2],
#         model_name='intfloat/e5-base-v2',
#         device="cuda",
#         batch_size=2048
#     )


In [None]:
def make_FAISS(input_embeddings_path:str, output_faiss_path:str, nlist:int=1024):

    MIN_NLIST, MAX_NLIST = 256, 32768
    TRAIN_MULTIPLIER = 128

    print("Načítám embeddingy…")
    E = np.load(input_embeddings_path).astype("float32")
    N, d = E.shape
    print(f"Vektorů: {N}, dim: {d}")

    # odhad nlist
    nlist = int(4 * (N ** 0.5))
    nlist = max(MIN_NLIST, min(nlist, MAX_NLIST))
    print(f"nlist = {nlist}")

    # kvantizér pro IP (embeddings jsou už normalizované → IP = cosine)
    quantizer = faiss.IndexFlatIP(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)

    # vybrat trénovací vzorek
    train_size = min(N, nlist * TRAIN_MULTIPLIER)
    print(f"Tréninkový vzorek: {train_size}")
    rng = np.random.default_rng(42)
    train_idx = rng.choice(N, size=train_size, replace=False)
    train_vecs = E[train_idx]

    print("Trénuji IVF…")
    index.train(train_vecs)
    assert index.is_trained

    print("Přidávám vektory do indexu…")
    index.add(E)   # lze i po dávkách, ale add() si data zkopíruje

    faiss.write_index(index, output_faiss_path)
    print(f"Hotovo. Uložen index: {output_faiss_path}  |  ntotal={index.ntotal}")

In [None]:
for mode in ['normalised', 'normalised_POS', 'lemma', 'lemma_POS', 'forms', 'forms_POS', 'signs', 'signs_gdl']:
    for model in ['e5', 'MiniLM']:
        paths = select_paths(mode=mode, model=model)

        make_FAISS(input_embeddings_path=paths[1], output_faiss_path=paths[3])

In [None]:
# import unicodedata
# from rapidfuzz import fuzz
# from sentence_transformers import CrossEncoder

# INDEX_PATH = FAISS_NORM_POS_PATH
# IDS_PATH   = IDS_NORM_POS_PATH
# META_PATH  = ORACC_NORM_POS_meta_csv_PATH

# MODEL_NAME = "all-MiniLM-L6-v2"


# TOPK = 20
# NPROBE = 64

# # ---- načtení indexu, id mapy a metadat ----
# index = faiss.read_index(INDEX_PATH)
# index.nprobe = NPROBE

# ids = pd.read_csv(IDS_PATH)["chunk_id"].astype(str).tolist()
# meta = pd.read_csv(META_PATH).astype({"chunk_id":"string"})

# device = "cuda"
# model = SentenceTransformer(MODEL_NAME, device=device)

# ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device=device)

In [None]:
import re
from rapidfuzz import fuzz
import unicodedata

# def _strip_diac(s: str) -> str:
#     return "".join(ch for ch in unicodedata.normalize("NFD", s) if unicodedata.category(ch) != "Mn")

def query_to_pos(query: str) -> str:
    out = []
    for tok in query.split():
        # logogramy celé VELKÉ necháme být
        if tok.isupper():
            out.append(tok)
            continue
        # heuristika na jména: aspoň jeden segment začíná velkým písmenem
        segs = tok.split("-")
        if any(seg and seg[0].isalpha() and seg[0].isupper() for seg in segs):
            out.append("PN_RN")
        else:
            out.append(tok)  # zachovat diakritiku i case
    return " ".join(out)


""" E5 ------------------------------------------ """


_POS_BASE = {'an', 'cn', 'dn', 'en', 'fn', 'gn', 'ln', 'mn', 'on', 'pn', 'qn', 'rn', 'sn', 'tn', 'wn', 'yn'}

# def _is_pos_tag_token(tok: str) -> bool:
#     parts = tok.lower().split("_")
#     return all(p in _POS_BASE for p in parts)

# def _tokenize(s: str) -> list[str]:
#     # Unicode slova (vč. diakritiky a underscore), bez čistě číselných tokenů
#     toks = re.findall(r"\w+", str(s), flags=re.UNICODE)
#     return [t.lower() for t in toks if not t.isdigit()]

# def _weighted_overlap(q: str, t: str, tag_w: float = 0.3) -> float:
#     """
#     Vážený překryv tokenů (0..1):
#       - obsahové tokeny (ne-POS) váha 1.0
#       - POS tagy (PN_RN apod.) váha tag_w (např. 0.3)
#     """
#     q_toks = set(_tokenize(q))
#     t_toks = set(_tokenize(t))

#     def w(tok: str) -> float:
#         return tag_w if _is_pos_tag_token(tok) else 1.0

#     # vážený součet přes *dotazové* tokeny
#     denom = sum(w(tok) for tok in q_toks) or 1.0
#     num   = sum(w(tok) for tok in q_toks if tok in t_toks)
#     return num / denom

def _weighted_overlap_POS(q: str, t: str, tag_w: float = 0.3) -> float:
    def is_tag(tok): return all(p in _POS_BASE for p in tok.split("_"))
    def toks(s): 
        ts = re.findall(r"\w+", str(s), flags=re.UNICODE)
        return [t.lower() for t in ts if not t.isdigit()]
    q_toks, t_toks = set(toks(q)), set(toks(t))
    def w(tok): return (tag_w if is_tag(tok) else 1.0)
    denom = sum(w(t) for t in q_toks) or 1.0
    num   = sum(w(t) for t in q_toks if t in t_toks)
    return num / denom

def search_query_e5(query: str, faiss_idx_path:str, ids_path:str, meta_csv_path:str, topk: int = 10, nprobe: int = 256, cand: int = 2000) -> pd.DataFrame:
    # Loading E5 model
    e5 = SentenceTransformer('intfloat/e5-base-v2', device="cuda")
    e5.max_seq_length = 96
    e5 = e5.to(torch.float16)
    torch.set_float32_matmul_precision("high")

    # --- načti POS index/ids/meta postavené na E5 embeddingách ---
    index_pos_e5 = faiss.read_index(faiss_idx_path)
    ids_pos = pd.read_csv(ids_path)["chunk_id"].astype(str).tolist()
    meta_pos = pd.read_csv(meta_csv_path).astype({"chunk_id":"string"})
    
    
    # 1) převede dotaz do POS tvaru, zachová diakritiku/case u obsahu
    q_pos = query_to_pos(query)             # např. "PN_RN rīm tuqumtim"
    # 2) E5: dotaz musí mít prefix "query: "
    q_emb = e5.encode([f"query: {q_pos}"], normalize_embeddings=True).astype("float32")
    # 3) FAISS
    index_pos_e5.nprobe = nprobe
    D, I = index_pos_e5.search(q_emb, max(topk, cand))
    hit_ids = [ids_pos[i] for i in I[0]]
    embed_s = [float(s) for s in D[0]]

    df = pd.DataFrame({"chunk_id": hit_ids, "embed_score": embed_s}).merge(meta_pos, on="chunk_id", how="left")
    
    # malý strukturální signál z POS řetězce (slabá váha)
    q_pos_lower = q_pos.lower()
    df["lex"] = df["text_display"].astype(str).str.lower().apply(lambda t: _weighted_overlap_POS(q_pos_lower, t, tag_w=0.3))
    df["score"] = 0.9*df["embed_score"] + 0.1*df["lex"]

    df = df.drop_duplicates("chunk_id").sort_values("score", ascending=False).head(topk).reset_index(drop=True)
    df.insert(0, "rank", range(1, len(df)+1))
    df["text_display"] = df["text_display"].astype(str).str.slice(0, 180)
    return df[["rank","score","embed_score","lex","chunk_id","text_id","start","end","text_display"]]


""" MINI LM ------------------------------------ """
def _lex_sim(query: str, text: str) -> float:
    q = str(query).lower()
    t = str(text).lower().replace("∎", " ")
    return max(fuzz.partial_ratio(q, t), fuzz.token_set_ratio(q, t)) / 100.0

def search_query_MiniLM(query: str, faiss_idx_path:str, ids_path:str, meta_path:str, topk: int = 10, nprobe: int = 128, cand: int = 1000) -> pd.DataFrame:
    
    mini = SentenceTransformer("all-MiniLM-L6-v2", device='cuda')
    torch.set_float32_matmul_precision("high")

    index_norm = faiss.read_index(faiss_idx_path)
    ids_norm   = pd.read_csv(ids_path)["chunk_id"].astype(str).tolist()
    meta_norm  = pd.read_csv(meta_path).astype({"chunk_id":"string"})
    meta_norm["text_display"] = meta_norm["text_display"].astype(str)

    # 1) FAISS kandidáti
    index_norm.nprobe = nprobe
    q_emb = mini.encode([query], normalize_embeddings=True).astype("float32")
    D, I = index_norm.search(q_emb, max(topk, cand))

    # 2) Poskládat výsledky + meta
    hit_ids  = [ids_norm[i] for i in I[0]]
    embed_s  = [float(s) for s in D[0]]
    df = pd.DataFrame({"chunk_id": hit_ids, "embed_score": embed_s}).drop_duplicates("chunk_id")
    df = df.merge(meta_norm, on="chunk_id", how="left")

    # 3) Lehký lexikální rerank
    df["lex"] = df["text_display"].apply(lambda t: _lex_sim(query, t))
    df["score"] = 0.7 * df["embed_score"] + 0.3 * df["lex"]

    # 4) Řazení a výstup
    df = df.sort_values("score", ascending=False).head(topk).reset_index(drop=True)
    df.insert(0, "rank", range(1, len(df)+1))
    df["text_display"] = df["text_display"].str.slice(0, 180)
    return df[["rank","score","embed_score","lex","chunk_id","text_id","start","end","text_display"]]

In [None]:
full_ORACC_corpus.get_data_by_id('nere/Q009326', mode='lemma')[:6] # Example of getting data by text ID

In [None]:
""" Comparing models and their results. """

results = {}

for mode in ['normalised', 'normalised_POS', 'lemma', 'lemma_POS', 'forms', 'forms_POS', 'signs', 'signs_gdl']:
    query_len = 6
    if mode in ['signs', 'signs_gdl']:
        query_len = 20

    query = ' '.join(full_ORACC_corpus.get_data_by_id('nere/Q009326', mode=mode)[:query_len])
    
    for model in ['e5', 'MiniLM']:
        print(f"Searching for query: {query} in mode: {mode} with model: {model}")
        paths = select_paths(mode=mode, model=model)

        model_name = paths[4]  # model name
        faiss_idx_path = paths[3]  # FAISS index path
        ids_path = paths[2]  # IDs path
        meta_csv_path = paths[5]

        if model == 'e5':
            search_results = search_query_e5(query, faiss_idx_path, ids_path, meta_csv_path)

            results[(query, mode, model)] = search_results

        elif model == 'MiniLM':
            search_results = search_query_MiniLM(query, faiss_idx_path, ids_path, meta_csv_path)

            results[(query, mode, model)] = search_results

        else:
            continue


In [None]:
for res in results:
    query, mode, model = res
    print(f"Results for {mode} with {model}:")
    print(results[res].head(10))  # Display top 10 results
    print("\n")  # New line for better readability

## Direct Intertextuality (with edit distance)

In [None]:
import os

ROOT_PATH = os.getcwd()

DIRECT_INTERTEXT_METADATA_PATH = os.path.join(ROOT_PATH, "directintertext")

In [None]:
from typing import List, Optional
from rapidfuzz.distance import Levenshtein

def token_edit_distance_inner(query: List[str], target: List[str], max_total_ed: Optional[int] = None, unknown_token: str = 'UNKNOWN'):
    m, n = len(query), len(target)
    if m == 0 or m > n:
        return None

    best_sum: Optional[int] = None
    best_hits: List[Tuple[int, int, List[str]]] = []

    for i in range(0, n - m + 1):
        s = 0
        for j in range(m):
            if query[j] == unknown_token:
                d = 0
            else:
                d = Levenshtein.distance(query[j], target[i + j])
            
            s += d
            
            # DŮLEŽITÉ: přeruš jen když je PRŮBĚŽNÁ suma HORŠÍ, ne shodná
            if best_sum is not None and s > best_sum:
                break
        else:
            if best_sum is None or s < best_sum:
                best_sum = s
                best_hits = [(s, i, target[i:i+m])]
            elif s == best_sum:
                best_hits.append((s, i, target[i:i+m]))

    if best_sum is None:
        return None
    if max_total_ed is not None and best_sum > max_total_ed:
        return None

    return best_hits

def token_edit_distance(query: List[str], target: List[str], max_total_ed: Optional[int] = None, unknown_token: str = 'UNKNOWN') -> Optional[int]:
    """ This function serves to search for hits, considering edit distance on the level of full tokens. (e.g., 'inuma ilu ibnu awilutam' // 'inuma blabla ibnu awilutam' has edit distance 1) """
    m, n = len(query), len(target)
    if m == 0 or n == 0:
        return None

    # DP a backtrack
    dp = [[0]*(n+1) for _ in range(m+1)]
    bt = [[0]*(n+1) for _ in range(m+1)]  # 0=diag, 1=up(delete), 2=left(insert)

    for j in range(n+1):
        dp[0][j] = 0                 # substring může začít kdekoliv
    for i in range(1, m+1):
        dp[i][0] = i                 # smazání i tokenů z query
        bt[i][0] = 1

    for i in range(1, m+1):
        ai = query[i-1]
        for j in range(1, n+1):
            bj = target[j-1]
            cost = 0 if (ai == bj or ai == unknown_token) else 1
            del_q = dp[i-1][j] + 1        # delete ai
            ins_q = dp[i][j-1] + 1        # insert bj
            sub  = dp[i-1][j-1] + cost    # match/replace (UNKNOWN matchuje cokoliv za 0)

            if sub <= del_q and sub <= ins_q:
                dp[i][j] = sub; bt[i][j] = 0
            elif del_q <= ins_q:
                dp[i][j] = del_q; bt[i][j] = 1
            else:
                dp[i][j] = ins_q; bt[i][j] = 2

    best = min(dp[m][1:])  # nejlepší vzdálenost přes všechna možná zakončení
    if max_total_ed is not None and best > max_total_ed:
        return None

    hits: List[Tuple[int, int, int, List[str]]] = []
    for j in range(1, n+1):
        if dp[m][j] != best:
            continue
        # backtrack z (m, j) na začátek substringu (řádek i==0)
        i, jj = m, j
        while i > 0:
            move = bt[i][jj]
            if move == 0:      # diag
                i -= 1; jj -= 1
            elif move == 1:    # up (delete v query)
                i -= 1
            else:              # left (insert v targetu)
                jj -= 1
        start, end = jj, j
        hits.append((best, start, end, target[start:end]))

    return hits if hits else None

    # NOTE: before adding unknown token filter:
    # m, n = len(query), len(target)
    # if m == 0 or n == 0:
    #     return None

    # # DP matice (m+1) x (n+1); dp[i][j] = min ed mezi query[:i] a target[:j]
    # # "Substring trick": dovolíme začátek substringu kdekoliv nastavením dp[0][j] = 0 pro všechna j
    # dp = [[0]*(n+1) for _ in range(m+1)]
    # bt = [[0]*(n+1) for _ in range(m+1)]  # backtrack: 0=diag, 1=up(delete in query), 2=left(insert)

    # for j in range(n+1):
    #     dp[0][j] = 0  # klíč: začátek substringu může být kdekoliv
    # for i in range(1, m+1):
    #     dp[i][0] = i  # musíme smazat i tokenů z query, když je target prázdný
    #     bt[i][0] = 1

    # for i in range(1, m+1):
    #     ai = query[i-1]
    #     for j in range(1, n+1):
    #         bj = target[j-1]
    #         cost = 0 if ai == bj else 1
    #         # kandidáti
    #         del_q = dp[i-1][j] + 1     # delete (mažu ai)
    #         ins_q = dp[i][j-1] + 1     # insert (vložím bj)
    #         sub  = dp[i-1][j-1] + cost # match/replace

    #         # min + backpointer
    #         if sub <= del_q and sub <= ins_q:
    #             dp[i][j] = sub
    #             bt[i][j] = 0
    #         elif del_q <= ins_q:
    #             dp[i][j] = del_q
    #             bt[i][j] = 1
    #         else:
    #             dp[i][j] = ins_q
    #             bt[i][j] = 2

    # # nejlepší vzdálenost = min přes koncové pozice j (substring může končit kdekoliv)
    # best = min(dp[m][1:])  # ignorujeme j=0, to by dávalo jen mazání všeho z query
    # if max_total_ed is not None and best > max_total_ed:
    #     return None

    # hits: List[Tuple[int, int, int, List[str]]] = []
    # for j in range(1, n+1):
    #     if dp[m][j] != best:
    #         continue
    #     # backtrack z (m, j) až na řádek i==0; sloupec v tom bodě je start
    #     i, jj = m, j
    #     while i > 0:
    #         move = bt[i][jj]
    #         if move == 0:      # diag
    #             i -= 1
    #             jj -= 1
    #         elif move == 1:    # up (delete)
    #             i -= 1
    #         else:              # left (insert)
    #             jj -= 1
    #     start = jj  # pozice po návratu na i==0 je začátek substringu
    #     end = j     # substring končí na j (exclusive v Python slice)
    #     hits.append((best, start, end, target[start:end]))

    # return hits if hits else None

In [None]:
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Iterable, Set, Optional, Any

Doc = List[str]

@dataclass
class SimpleIndex:
    postings: Dict[str, List[int]]     # token -> [doc_id]
    doc_unique: List[Set[str]]         # interní doc_id -> unikátní tokeny
    df: Dict[str, int]                 # token -> DF
    N: int                             # počet dokumentů
    ids2ext: List[Any]                 # interní doc_id -> původní ID
    ext2ids: Dict[Any, int]            # původní ID -> interní doc_id

def build_inverted_index(docs: Iterable[Doc], external_ids: Optional[Iterable[Any]] = None, stop: Optional[Set[str]] = ('■')) -> SimpleIndex:
    """
    Vstup: `docs` je iterovatelný přes dokumenty, každý dokument je list tokenů (List[str]).
    Výstup: invertovaný index + základní statistiky.
    """
    stop = stop or set()
    postings_sets: Dict[str, Set[int]] = defaultdict(set)
    doc_unique: List[Set[str]] = []
    ext_ids: List[Any] = []
    ext2int: Dict[Any, int] = {}

    ext_iter = iter(external_ids) if external_ids is not None else None

    for internal_id, tokens in enumerate(docs):
        u = {t for t in tokens if t not in stop}
        doc_unique.append(u)
        for t in u:
            postings_sets[t].add(internal_id)

        if ext_iter is not None:
            ext_id = next(ext_iter)  # vyhodí StopIteration, pokud délky nesedí
        else:
            ext_id = internal_id     # fallback: původní ID = interní index
        ext_ids.append(ext_id)
        ext2int[ext_id] = internal_id

    # finalize
    postings = {t: sorted(ids) for t, ids in postings_sets.items()}
    df = {t: len(ids) for t, ids in postings.items()}
    N = len(doc_unique)

    return SimpleIndex(postings=postings, doc_unique=doc_unique, df=df, N=N,
                       ids2ext=ext_ids, ext2ids=ext2int)

In [None]:
from collections import defaultdict
from math import ceil
from typing import List, Set

def select_documents_for_single_token(index: SimpleIndex, term: str) -> List[int]:
    """Vybere dokumenty obsahující daný termín."""
    return index.postings.get(term, [])

def select_documents_for_tokens(index: SimpleIndex, terms: List[str], benchmark: float = 0.8, stop: Optional[Set[str]] = ('■')) -> Set[int]:
    """
    Vybere dokumenty, které obsahují alespoň benchmark podílu UNIKÁTNÍCH tokenů z dotazu.
    benchmark=0.8 => musí se shodnout aspoň 80 % unikátních dotazových tokenů.
    """
    qset = set(terms)
    if not qset:
        return set()

    if stop:
        qset -= stop

    counts = defaultdict(int)  # doc_id -> kolik dotazových tokenů se našlo
    for t in qset:
        for doc_id in index.postings.get(t, ()):
            counts[doc_id] += 1

    required = ceil(benchmark * len(qset))  # integer práh
    return {doc_id for doc_id, c in counts.items() if c >= required}

In [None]:
def load_data_by_mode(mode:str, oracc_corpus: OraccCorpus):
    if mode == 'normalised':
        normalised_inverted_index = build_inverted_index(oracc_corpus.normalised_corpus, oracc_corpus.texts)
        normalised_stops = set(['■', 'ina', 'ana', 'u', 'ša'])
        return normalised_inverted_index, normalised_stops
    elif mode == 'normalised_POS':
        normalised_pos_inverted_index = build_inverted_index(oracc_corpus.normalised_pos_corpus, oracc_corpus.texts)
        normalised_pos_stops = set(['■', 'ina', 'ana', 'u', 'ša'])
        return normalised_pos_inverted_index, normalised_pos_stops
    elif mode == 'lemma':
        lemma_inverted_index = build_inverted_index(oracc_corpus.lemma_corpus, oracc_corpus.texts)
        lemma_stops = set(['■', 'ina', 'ana', 'u', 'ša', 'i-na', 'a-na'])
        return lemma_inverted_index, lemma_stops
    elif mode == 'lemma_POS':
        lemma_pos_inverted_index = build_inverted_index(oracc_corpus.lemma_pos_corpus, oracc_corpus.texts)
        lemma_pos_stops = set(['■', 'ina', 'ana', 'u', 'ša', 'i-na', 'a-na'])
        return lemma_pos_inverted_index, lemma_pos_stops
    elif mode == 'forms':
        forms_inverted_index = build_inverted_index(oracc_corpus.forms_corpus, oracc_corpus.texts)
        forms_stops = set(['■', 'ina', 'ana', 'u', 'ša', 'i-na', 'a-na'])
        return forms_inverted_index, forms_stops
    elif mode == 'forms_POS':
        forms_pos_inverted_index = build_inverted_index(oracc_corpus.forms_pos_corpus, oracc_corpus.texts)
        forms_pos_stops = set(['■', 'ina', 'ana', 'u', 'ša', 'i-na', 'a-na'])
        return forms_pos_inverted_index, forms_pos_stops
    elif mode == 'signs':
        signs_inverted_index = build_inverted_index(oracc_corpus.signs_corpus, oracc_corpus.texts)
        signs_stops = set(['■'])
        return signs_inverted_index, signs_stops
    elif mode == 'signs_gdl':
        signs_gdl_inverted_index = build_inverted_index(oracc_corpus.signs_gdl_corpus, oracc_corpus.texts)
        signs_gdl_stops = set(['■'])
        return signs_gdl_inverted_index, signs_gdl_stops
    else:
        raise ValueError(f"Unknown mode: {mode}")

def search_for_query_in_target_dataset(mode: str, query: List[str], ORACCtarget_dataset: OraccCorpus, benchmark: float = 0.8) -> Set[int]:
    """
    Prohledá dataset na základě dotazu a v nich hledá intertextualitu. 1) vrátí množinu ID dokumentů, které odpovídají dotazu, 2) aplikuje na ní edit-distance fce.
    """
    target_inverted_idx, stop = load_data_by_mode(mode, ORACCtarget_dataset)
    selected_documents = select_documents_for_tokens(target_inverted_idx, query, stop=stop, benchmark=benchmark)

    hits_inner_ed = {}
    hits_token_ed = {}
    
    for doc_id in selected_documents:
        ORACC_doc_id = target_inverted_idx.ids2ext[doc_id]
        target_data = ORACCtarget_dataset.get_data_by_id(ORACC_doc_id, mode=mode)
        
        hits_inner = token_edit_distance_inner(query, target_data, max_total_ed=10)
        hits_full_tokens = token_edit_distance(query, target_data, max_total_ed=10)

        if hits_inner:
            hits_inner_ed[ORACC_doc_id] = hits_inner
        if hits_full_tokens:
            hits_token_ed[ORACC_doc_id] = hits_full_tokens

    return hits_inner_ed, hits_token_ed


def skip_empty_query(query: List[str], stop: Set[str], min_tokens: int=2) -> bool:
    """
    Check if a query is empty or contains only stop words (or if it is long enough).
    """
    query_tokens = set(query) - stop
    return [(len(query_tokens) < min_tokens), len(query_tokens)]


def search_for_multiple_queries(mode: str, query: List[List[str]], ORACCtarget_dataset: OraccCorpus, benchmark: float=0.8, ignore_texts: List[str] = None, min_tokens: int=2, if_min_tokens_lower_tolerance_to: int=1, remove_empty_hits: bool=False, tolerance_for_inner_ed: int=5, tolerance_for_token_ed: int=2) -> Set[int]:
    """
    ADD description
    """
    print('Creating the index of the target dataset.')
    target_inverted_idx, stop = load_data_by_mode(mode, ORACCtarget_dataset)

    hits_inner_ed_all = {}
    hits_token_ed_all = {}

    for subquery in tqdm(query, desc=f'Searching for intertextualities of {len(query)} queries'):
        selected_documents = select_documents_for_tokens(target_inverted_idx, subquery, stop=stop, benchmark=benchmark)
        # print("Searching for subquery:", subquery, 'in', len(selected_documents), 'documents.')
        hits_inner_ed = {}
        hits_token_ed = {}

        # Skipping empty queries and lowering tolerance with limit queries
        if skip_empty_query(query=subquery, stop=stop, min_tokens=min_tokens)[0]:
            continue
        elif skip_empty_query(query=subquery, stop=stop, min_tokens=min_tokens)[1] == min_tokens:
            tolerance_for_token_ed = if_min_tokens_lower_tolerance_to
        else:
            tolerance_for_token_ed = tolerance_for_token_ed

        for doc_id in selected_documents:
            # Skip texts to be ignored (e.g., the query text itself)
            if ignore_texts and target_inverted_idx.ids2ext[doc_id] in ignore_texts:
                continue

            ORACC_doc_id = target_inverted_idx.ids2ext[doc_id]
            target_data = ORACCtarget_dataset.get_data_by_id(ORACC_doc_id, mode=mode)

            hits_inner = token_edit_distance_inner(subquery, target_data, max_total_ed=tolerance_for_inner_ed)
            hits_full_tokens = token_edit_distance(subquery, target_data, max_total_ed=tolerance_for_token_ed)

            if hits_inner:
                hits_inner_ed[ORACC_doc_id] = hits_inner
            if hits_full_tokens:
                hits_token_ed[ORACC_doc_id] = hits_full_tokens

        hits_inner_ed_all[tuple(subquery)] = hits_inner_ed
        hits_token_ed_all[tuple(subquery)] = hits_token_ed

    if remove_empty_hits:
        hits_inner_ed_all = {k: v for k, v in hits_inner_ed_all.items() if v}
        hits_token_ed_all = {k: v for k, v in hits_token_ed_all.items() if v}

    return hits_inner_ed_all, hits_token_ed_all

In [None]:
def parse_query_text(query:List[str], window_size:int=5, stride:int=3) -> List[str]:
    """
    Parse the input query (list of strings) into strided queries of a specified size.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError('window_size and stride must be positive integers.')
    n = len(query)
    if n == 0:
        return []
    if n <= window_size:
        return [query[:]] # short query --> one window

    # standard starts by stride
    starts = list(range(0, n - window_size + 1, stride))

    # adding the last window if not present (standard window over the limit)
    last_start = n - window_size
    if not starts or starts[-1] != last_start:
        starts.append(last_start)

    return [query[s:s + window_size] for s in starts]


def get_core_project(text_id: str) -> str:
    parts = text_id.split('/')
    return '/'.join(parts[:len(parts)-1])


def find_intertextualities_of_text(oracc_corpus:OraccCorpus, text_id:str, windown_size:int=5, stride:int=3, mode:str='normalised', benchmark:float=0.8, ignore_itself=True, ignore_core_project=False, tolerance_for_inner_ed=5, tolerance_for_token_ed=2, if_min_tokens_lower_tolerance_to=0):
    """
    Add description
    """
    queries = parse_query_text(oracc_corpus.get_data_by_id(text_id, mode=mode), window_size=windown_size, stride=stride)
    print(f'Input text has been parsed to {len(queries)} queries.')

    ignore_texts = []
    if ignore_itself:
        ignore_texts = [text_id]

    if ignore_core_project:
        print(f'Ignoring texts from the same core project as {text_id}.')
        core_project = get_core_project(text_id)
        if core_project:
            for t_id in oracc_corpus.texts:
                if get_core_project(t_id) == core_project:
                    ignore_texts.append(t_id)

    hits_inner_ed_all, hits_token_ed_all = search_for_multiple_queries(mode=mode, query=queries, ORACCtarget_dataset=oracc_corpus, benchmark=benchmark, ignore_texts=ignore_texts, remove_empty_hits=True, tolerance_for_inner_ed=tolerance_for_inner_ed, tolerance_for_token_ed=tolerance_for_token_ed, if_min_tokens_lower_tolerance_to=if_min_tokens_lower_tolerance_to)
    # TODO: make hits into a class object?

    return hits_inner_ed_all, hits_token_ed_all

In [None]:
hits_inner, hits_tokens = find_intertextualities_of_text(full_ORACC_corpus, 'nere/Q009326', windown_size=20, stride=5, mode='signs', benchmark=0.8, ignore_itself=True, ignore_core_project=True, tolerance_for_inner_ed=10, tolerance_for_token_ed=4, if_min_tokens_lower_tolerance_to=2)
# TODO: involve the interchangeability of signs?

print('Hits based on inner edit distance:')
for hit in hits_inner:
    print(hit, hits_inner[hit])

print('Hits based on token edit distance:')
for hit in hits_tokens:
    print(hit, hits_tokens[hit])