# In [None]:
import ir_datasets
import random
import pickle
from tqdm import tqdm

def create_subset_efficient(dataset, sample_percentage, X, seed=None, verbose=False, all_doc_ids=None):
    """
    Build a sampled subset (queries, documents, qrels) of an IR dataset.

    A fraction of the queries is sampled, every qrel belonging to a sampled
    query is kept, and the documents referenced by those qrels are fetched
    together with a random sample of documents that none of the kept qrels
    reference. Note that "relevant" here means "referenced by a kept qrel":
    the relevance grade stored in the qrel is not inspected.

    Args:
        dataset: Object exposing ``queries_iter()``, ``qrels_iter()``,
            ``docs_iter()`` and ``docs_store()``; the script in this file
            passes the result of ``ir_datasets.load``.
        sample_percentage: Fraction of queries to keep. Values >= 1 keep every
            query; smaller values keep at least one query when any exist.
        X: Number of non-referenced documents to sample per selected query.
        seed: Optional seed passed to ``random.seed`` (this reseeds the
            module-level generator). When given, the non-referenced candidates
            are sorted before sampling so the selection does not depend on
            set iteration order, which varies between interpreter runs.
        verbose: When true, print progress counts.
        all_doc_ids: Optional iterable of candidate document ids. When None,
            the ids are collected by iterating ``dataset.docs_iter()`` in full,
            which can be slow on large collections.

    Returns:
        A ``(queries, docs, qrels)`` tuple: a dict mapping query_id to the
        sampled query, a dict merged from the two ``docs_store().get_many``
        results, and a flat list of the kept qrels grouped by query.
    """
    if seed is not None:
        random.seed(seed)

    # --- 1) Sample the queries ---
    query_list = list(dataset.queries_iter())  # Materialized so it can be sampled
    total_queries = len(query_list)

    if sample_percentage >= 1:
        sampled_queries = query_list
    else:
        # Keep at least one query, but never request more than exist:
        # random.sample raises ValueError when asked for 1 item of an empty list.
        num_to_sample = min(total_queries, max(1, int(total_queries * sample_percentage)))
        sampled_queries = random.sample(query_list, num_to_sample)

    sampled_query_ids = {q.query_id for q in sampled_queries}

    if verbose:
        print(f"Total queries: {total_queries}, Queries selecionadas: {len(sampled_queries)}")

    # --- 2) Collect the qrels (and their documents) of the sampled queries ---
    qrels_dict = {}  # query_id -> list of qrels for that query
    relevant_doc_ids = set()

    for qrel in dataset.qrels_iter():
        if qrel.query_id in sampled_query_ids:
            qrels_dict.setdefault(qrel.query_id, []).append(qrel)
            relevant_doc_ids.add(qrel.doc_id)

    # --- 3) Pick the documents that no kept qrel references ---
    if all_doc_ids is None:
        all_doc_ids = {doc.doc_id for doc in dataset.docs_iter()}  # Full pass over the corpus
    elif not isinstance(all_doc_ids, (set, frozenset)):
        # Accept any iterable of ids; sets are used as-is to avoid a copy.
        all_doc_ids = set(all_doc_ids)

    non_relevant_doc_ids = all_doc_ids - relevant_doc_ids
    if seed is not None:
        # Set order changes between runs (hash randomization); sorting makes
        # the seeded sample reproducible.
        non_relevant_doc_ids = sorted(non_relevant_doc_ids)
    else:
        non_relevant_doc_ids = list(non_relevant_doc_ids)

    # Sample X documents per selected query, capped by what is available
    sampled_non_relevant_docs = random.sample(
        non_relevant_doc_ids,
        min(len(non_relevant_doc_ids), X * len(sampled_queries)),
    )

    # --- 4) Fetch only the documents that are needed ---
    doc_store = dataset.docs_store()

    relevant_docs = doc_store.get_many(relevant_doc_ids)
    non_relevant_docs = doc_store.get_many(sampled_non_relevant_docs)

    subset_docs = {}
    subset_docs.update(relevant_docs)  # Documents referenced by the kept qrels
    subset_docs.update(non_relevant_docs)  # Sampled non-referenced documents

    if verbose:
        print(f"Docs relevantes: {len(relevant_docs)}, Docs não relevantes: {len(non_relevant_docs)}, Total Docs: {len(subset_docs)}")

    return {q.query_id: q for q in sampled_queries}, subset_docs, [qrel for qrels in qrels_dict.values() for qrel in qrels]


def create_and_save_dataset_efficient(dataset_name, sample_percentage, X, output_file, seed=None, verbose=False, all_doc_ids=None):
    """
    Build a dataset subset with ``create_subset_efficient`` and pickle it.

    Args:
        dataset_name: Either a dataset identifier string, which is loaded with
            ``ir_datasets.load``, or an already-loaded dataset object, which is
            used as-is (this is what the script in this file passes).
        sample_percentage: Fraction of queries to keep (forwarded unchanged).
        X: Non-referenced documents to sample per query (forwarded unchanged).
        output_file: Path of the pickle file to write; it is overwritten.
        seed: Optional random seed (forwarded unchanged).
        verbose: When true, print progress and the saved file name.
        all_doc_ids: Optional precomputed document ids (forwarded unchanged).

    The pickle holds a dict with the keys ``'queries'``, ``'docs'`` and
    ``'qrels'``. Nothing is returned.
    """
    # Despite its name, callers hand this parameter a loaded dataset object;
    # resolve real name strings so both usages work.
    if isinstance(dataset_name, str):
        dataset = ir_datasets.load(dataset_name)
    else:
        dataset = dataset_name

    subset_queries_dict, subset_docs, subset_qrels = create_subset_efficient(
        dataset, sample_percentage, X, seed, verbose, all_doc_ids
    )

    data_to_save = {
        'queries': subset_queries_dict,
        'docs': subset_docs,
        'qrels': subset_qrels,
    }

    with open(output_file, 'wb') as f:
        pickle.dump(data_to_save, f)

    if verbose:
        print(f"Dataset salvo em '{output_file}'.")

def load_dataset(input_file):
    """
    Read back a pickled dataset subset from ``input_file``.

    Returns whatever object was pickled into the file. Unpickling can execute
    arbitrary code, so only pass files that come from a trusted source.
    """
    with open(input_file, mode='rb') as pickle_stream:
        restored = pickle.load(pickle_stream)
    return restored

# Example usage
if __name__ == '__main__':

    dataset_name = "msmarco-passage-v2/train"
    output_file = "subset_msmarco_train"
    seed = 42
    verbose = True
    dataset = ir_datasets.load(dataset_name)

    # Candidate pool for the sampled non-relevant documents. Only the first
    # 10M documents are scanned here, not the whole collection.
    all_doc_ids = {doc.doc_id for doc in tqdm(dataset.docs_iter()[:10_000_000], desc="Collecting doc IDs")}
    print(len(all_doc_ids))

    # One output path per (query fraction, non-relevant docs per query) pair.
    # Built once so the generation and verification loops cannot disagree on
    # the file names.
    subset_files = {
        (pct, X): f"{output_file}_{pct}_{X}.pkl"
        for pct in [0.01, 0.05, 0.1]
        for X in [9, 99, 999]
    }

    print("Generating subsets...")
    for (pct, X), subset_path in subset_files.items():
        create_and_save_dataset_efficient(dataset, pct, X, subset_path, seed, verbose, all_doc_ids)

    # Reload every generated file and report its size as a sanity check
    for subset_path in subset_files.values():
        dataset_loaded = load_dataset(subset_path)
        print("\nDataset carregado:")
        print("Queries:", len(dataset_loaded['queries']))
        print("Docs:", len(dataset_loaded['docs']))
        print("Qrels:", len(dataset_loaded['qrels']))