## Data

In [None]:
import os



import pandas as pd
from torch.utils.data import DataLoader, Dataset


def row2text_template_scifact(row):
    return f"Title: {row['title']}\nAbstract: {' '.join(row['abstract'])}\nStructured: {row['structured']}\n"


class SciFactDataset(Dataset):
    def __init__(self, queries_path, corpus_path):
        super().__init__()
        queries = pd.read_json(queries_path, lines=True)
        corpus = pd.read_json(corpus_path, lines=True)
        corpus["text"] = corpus.apply(row2text_template_scifact, axis=1)

        self.queries = queries
        self.corpus = corpus
        self.corpus_dict = corpus.set_index("doc_id")["text"].to_dict()

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, i):
        doc_id_list = self.queries["cited_doc_ids"][i]
        query = self.queries["claim"][i]

        docs = [
            self.corpus_dict.get(doc_id)
            for doc_id in doc_id_list
            if doc_id in self.corpus_dict
        ]
        n_docs = len(docs)

        if n_docs == 0:
            return {}
        else:
            queries = [query] * n_docs

        return {
            "doc_id": doc_id_list,
            "query": queries,
            "text": docs,
        }


def scifact_collate_fn(batch):
    batch = [item for item in batch if item]
    if not batch:
        return {}
    doc_ids = sum([item["doc_id"] for item in batch], [])
    queries = sum([item["query"] for item in batch], [])
    texts = sum([item["text"] for item in batch], [])
    return {
        "doc_id": doc_ids,
        "query": queries,
        "text": texts,
    }


def get_scifact_dataloader(
    queries_path,
    corpus_path,
    batch_size=4,
    shuffle=False,
    num_workers=0,
):
    dataset = SciFactDataset(queries_path, corpus_path)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=scifact_collate_fn,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return dataloader

## Model

In [2]:
import itertools

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


def get_special_tokens_dict(tokenizer):
    special_tokens_dict = {}
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    return special_tokens_dict


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    tokenizer,
    model,
):
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def get_model_and_tokenizer(
    model_name_or_path,
    model_max_length=1024,
    torch_dtype=torch.bfloat16,
):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        model_max_length=model_max_length,
        padding_side="left",
        use_fast=False,
    )
    special_tokens_dict = get_special_tokens_dict(tokenizer)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )
    return model, tokenizer


def get_genx_transformer(
    query_model_name_or_path,
    doc_model_name_or_path,
    query_model_max_length=128,
    doc_model_max_length=512,
    num_beams: int = 5,
    num_tokens: int = 5,
    torch_dtype=torch.bfloat16,
):
    query_model, query_tokenizer = get_model_and_tokenizer(
        query_model_name_or_path,
        model_max_length=query_model_max_length,
        torch_dtype=torch_dtype,
    )
    doc_model, doc_tokenizer = get_model_and_tokenizer(
        doc_model_name_or_path,
        model_max_length=doc_model_max_length,
        torch_dtype=torch_dtype,
    )
    return GenXTransformer(
        query_model=query_model,
        doc_model=doc_model,
        query_tokenizer=query_tokenizer,
        doc_tokenizer=doc_tokenizer,
        num_beams=num_beams,
        num_tokens=num_tokens,
    )


class GenXTransformer:
    def __init__(
        self,
        query_model,
        doc_model,
        query_tokenizer,
        doc_tokenizer,
        num_beams: int = 5,
        num_tokens: int = 5,
    ):
        super().__init__()
        self.query_model = query_model
        self.doc_model = doc_model

        self.query_tokenizer = query_tokenizer
        self.doc_tokenizer = doc_tokenizer

        self.num_beams = num_beams
        self.num_tokens = num_tokens

        self.verbose = False

        self.config_genx_gen_kwargs(num_beams=num_beams, num_tokens=num_tokens)

    def set_train_eval_mode(self, query_train: bool = True, doc_train: bool = False):
        if query_train:
            self.query_model.train()
        else:
            self.query_model.eval()
        if doc_train:
            self.doc_model.train()
        else:
            self.doc_model.eval()

    def config_genx_gen_kwargs(self, **kwargs):
        gen_kwargs = {
            "max_new_tokens": kwargs.get("max_new_tokens", 5),
            "do_sample": False,
            "num_beams": kwargs.get("num_beams", 5),
            "num_return_sequences": kwargs.get("num_return_sequences", 5),
            "eos_token_id": kwargs.get("eos_token_id", None),
            "pad_token_id": kwargs.get("pad_token_id", None),
        }
        self.genx_gen_kwargs = gen_kwargs

    def index_prompt(self, prompts, model, tokenizer):
        device = model.device

        if isinstance(prompts, str):
            prompts = [prompts]

        batch = tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
        )
        batch["input_len"] = len(batch["input_ids"][0])

        genx_gen_kwargs = self.genx_gen_kwargs.copy()
        with torch.no_grad():
            genx_gen_kwargs["input_ids"] = batch["input_ids"].to(device)
            genx_gen_kwargs["attention_mask"] = batch["attention_mask"].to(device)
            generated_tokens = model.generate(**genx_gen_kwargs)

        input_len = batch["input_len"]
        pred_next_tokens = generated_tokens[:, input_len:]
        if self.verbose:
            print(
                "Decoded tokens:",
                tokenizer.batch_decode(
                    pred_next_tokens, skip_special_tokens=False
                ),
            )

        batch_size = len(prompts)
        num_return_sequences = genx_gen_kwargs["num_return_sequences"]

        pred_next_tokens = pred_next_tokens.view(batch_size, num_return_sequences, -1)
        pred_next_tokens = pred_next_tokens.cpu().tolist()

        print("Token IDs:", pred_next_tokens) if self.verbose else None
        return pred_next_tokens

    def index_query(self, prompts: list[str]):
        return self.index_prompt(prompts, self.query_model, self.query_tokenizer)

    def index_doc(self, prompts: list[str]):
        return self.index_prompt(prompts, self.doc_model, self.doc_tokenizer)

    def sample_beams_of_next_tokens(
        self,
        model,
        tokenizer,
        prompts: list[str],
    ) -> list[list[str]]:
        if isinstance(prompts, str):
            prompts = [prompts]

        device = model.device

        batch = tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
        )
        batch["input_len"] = len(batch["input_ids"][0])

        gen_kwargs = self.genx_gen_kwargs.copy()

        with torch.no_grad():
            gen_kwargs["input_ids"] = batch["input_ids"].to(device)
            gen_kwargs["attention_mask"] = batch["attention_mask"].to(device)
            generated_tokens = model.generate(**gen_kwargs)

        input_len = batch["input_len"]
        pred_next_tokens = generated_tokens[:, input_len:]
        if self.verbose:
            print(
                "Decoded tokens:",
                tokenizer.batch_decode(pred_next_tokens, skip_special_tokens=False),
            )

        batch_size = len(prompts)
        num_return_sequences = gen_kwargs["num_return_sequences"]

        pred_next_tokens = pred_next_tokens.view(batch_size, num_return_sequences, -1)
        pred_next_tokens = pred_next_tokens.cpu().tolist()

        print("Token IDs:", pred_next_tokens) if self.verbose else None
        return pred_next_tokens

    def get_sft_loss_txt(self, model, tokenizer, prompts: list[str]):
        tokens = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        tokens = {k: v.to(model.device) for k, v in tokens.items()}

        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        labels = input_ids.clone()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        return loss

    def __call__(self, queries: list[str], docs: list[str]):
        # Now this is fine-tuning query model to generate next tokens of document
        assert len(queries) == len(docs)

        beams_for_docs: list[list[str]] = self.sample_beams_of_next_tokens(
            self.doc_model,
            self.doc_tokenizer,
            docs,
        )  # Shape is num_docs x num_beams x num_tokens

        # Shape is num_docs x num_beams x (len(query) + num_tokens)
        prompts_for_all_pairs: list[list[str]] = []
        for doc_idx, beams in enumerate(beams_for_docs):
            beams = self.doc_tokenizer.batch_decode(beams, skip_special_tokens=False)
            prompts = []  # List of the same query and num_beams possible next sentences

            query = queries[doc_idx]
            num_beams = len(beams)
            for beams_idx in range(num_beams):
                prompt = query + beams[beams_idx]
                prompts.append(prompt)

            prompts_for_all_pairs.append(prompts)

        # Have num_docs x num_beams sequences, each of a string of length (len(query) + num_tokens)
        flats: list[str] = list(itertools.chain.from_iterable(prompts_for_all_pairs))

        loss = self.get_sft_loss_txt(self.query_model, self.query_tokenizer, flats)
        return loss

...

Ellipsis

## Store

In [3]:
from abc import ABC, abstractmethod

import numpy as np

class Document:
    def __init__(self, text, metadata):
        self._text = text
        self._metadata = metadata

    def get_text(self):
        return self._text

    def get_metadata(self):
        return self._metadata


class IndexStoreTemplate(ABC):
    def __init__(self, initial_capacity=1000):
        # Capacity and initial capacity
        self._initial_capacity = initial_capacity
        self.capacity = initial_capacity

        # Data and global index of elements
        self.next_id = 0
        self.size = 0
        self._ids = np.zeros(self.capacity, dtype=np.int64)
        self._data_store = {}  # Dictionary to store actual data

    def _resize_if_needed(self, additional_items=16):
        if self.size + additional_items > self.capacity:
            new_capacity = max(self.capacity * 2, self.size + additional_items)

            # Resize ID array
            new_ids = np.zeros(new_capacity, dtype=np.int64)
            new_ids[: self.size] = self._ids[: self.size]
            self._ids = new_ids

            self.capacity = new_capacity

    def _clear_store(self):
        self.capacity = self._initial_capacity
        self.next_id = 0
        self.size = 0
        self._ids = np.zeros(self._initial_capacity, dtype=np.int64)
        self._data_store = {}

    def retrieve(self, doc_id):
        return self._data_store[doc_id]

    @abstractmethod
    def insert(self, text: list[Document]):
        pass

    @abstractmethod
    def query(self, query_text: Document) -> list[list[int]]:
        pass


class PrefixTreeNode:
    def __init__(self):
        self.children = {}
        self.doc_ids = set()


class Prompt:
    def __init__(self, before, after):
        self.before = before
        self.after = after

    def template(self, text):
        return self.before + text + self.after


class SequencePrefixTreeIndexStore(IndexStoreTemplate):
    def __init__(
        self,
        transformer,
        id_len,
        universe,
        doc_prompt_before="Generate identifying phrases that memorize the key concepts in this text. You are not supposed to make sense. Just generate ONLY the identifying phrases without any punctuations or numbers before or after. ",
        doc_prompt_after=" IGNORE ME. Phrases: ",
        query_prompt_before="From this query create identifying phrases that capture the key concepts and align with phrases found in relevant text. Do not aim for meaningful sentences. Only output the identifying phrases with no punctuation or numbers before or after. ",
        query_prompt_after=" IGNORE ME. Phrases: ",
        duplicate_prompt_before="Given the text first remove all the punctuations and stop words. Then shuffle the sentences. Generate some unique related phrases that does not have synonyms. ",
        duplicate_prompt_after=" IGNORE ME. Phrases: ",
        verbose=False,
        initial_capacity=1000,
        insertion_depth=3,
        mode="document_search",
    ):
        super().__init__(initial_capacity)

        assert mode in ["duplicate_detection", "document_search"]
        self.mode = mode
        if mode == "duplicate_detection":
            self.doc_prompt = Prompt(duplicate_prompt_before, duplicate_prompt_after)
            self.query_prompt = Prompt(duplicate_prompt_before, duplicate_prompt_after)
        elif mode == "document_search":
            self.doc_prompt = Prompt(doc_prompt_before, doc_prompt_after)
            self.query_prompt = Prompt(query_prompt_before, query_prompt_after)

        self.doc_prompt_before = doc_prompt_before
        self.doc_prompt_after = doc_prompt_after
        self.query_prompt_before = query_prompt_before
        self.query_prompt_after = query_prompt_after
        self.duplicate_prompt_before = duplicate_prompt_before
        self.duplicate_prompt_after = duplicate_prompt_after

        # Model for generating indices for inserted documens
        self.transformer = transformer
        self.id_len = id_len
        self.universe = set(universe)

        # Verbose
        self.verbose = verbose

        # Prefix tree
        self.root = PrefixTreeNode()
        self.insertion_depth = insertion_depth

    def set_verbose_for_all(self, verbose):
        self.verbose = verbose
        if hasattr(self.transformer, "verbose"):
            self.transformer.verbose = verbose

    def reset_id_len(self, id_len):
        self.id_len = id_len
        self.transformer.update_gen_kwargs(max_new_tokens=id_len)

    def set_mode(self, mode):
        assert mode in ["duplicate_detection", "document_search"]
        self.mode = mode
        if mode == "duplicate_detection":
            self.doc_prompt = Prompt(
                self.duplicate_prompt_before, self.duplicate_prompt_after
            )
            self.query_prompt = Prompt(
                self.duplicate_prompt_before, self.duplicate_prompt_after
            )
        elif mode == "document_search":
            self.doc_prompt = Prompt(self.doc_prompt_before, self.doc_prompt_after)
            self.query_prompt = Prompt(
                self.query_prompt_before, self.query_prompt_after
            )
        print("Remember to call `clear_store` to reset the database!")

    def clear_store(self):
        self.root = PrefixTreeNode()

        super()._clear_store()
        if self.verbose:
            print(f"Store cleared, current capacity: {self.capacity}")

    def _insert_document(self, texts: list[Document], prompt_template):
        if not isinstance(texts, list):
            texts = [texts]

        self._resize_if_needed(len(texts))

        doc_ids = []
        template_texts = []
        for text in texts:
            doc_id = self.next_id
            doc_ids.append(doc_id)
            # Update index in data store
            self.next_id += 1
            self.size += 1

            # Save text in data store
            self._ids[self.size - 1] = doc_id
            self._data_store[doc_id] = text

            template_text = prompt_template(text.get_text())
            template_texts.append(template_text)

        # Generate beams of sequences
        # [batch_size, num_return_sequences, sequence_length]
        lst_of_sequences = self.transformer.index_doc(template_texts)
        print(lst_of_sequences) if self.verbose else None
        self._insert_sequences_into_tree(lst_of_sequences, doc_ids)

    def _insert_sequences_into_tree(self, lst_of_sequences: list[list[list[int]]], doc_ids: list[int]):
        for sequences, doc_id in zip(lst_of_sequences, doc_ids):
            for seq in sequences:
                print(f"Tokens: {seq}") if self.verbose else None
                if len(seq) != self.id_len or not all(x in self.universe for x in seq):
                    continue  # Skip invalid sequences

                self._traverse_and_insert(seq, doc_id)

    def _traverse_and_insert(self, seq, doc_id):
        node = self.root
        depth = 0

        for idx in seq:
            if idx not in node.children:
                node.children[idx] = PrefixTreeNode()
            node = node.children[idx]
            depth += 1
            if depth >= self.insertion_depth:
                node.doc_ids.add(doc_id)
                if self.verbose:
                    print(f"Inserted doc {doc_id} at depth {depth} of prefix tree")

    def insert(self, texts: list[Document]):
        print(f"Inserting '{texts}'") if self.verbose else None
        self._insert_document(texts, self.doc_prompt.template)

    def _query_with_prompt(self, query_texts: list[Document], prompt_template):
        if not query_texts:
            return []

        lst_of_result_ids = []
        template_texts = []
        for query_text in query_texts:
            template_text = prompt_template(query_text.get_text())
            print(template_text) if self.verbose else None
            template_texts.append(template_text)

        # [batch_size, num_return_sequences, sequence_length]
        lst_of_sequences = self.transformer.index_query(template_texts)
        for sequences in lst_of_sequences:
            result_ids = []
            for seq in sequences:
                print(f"Tokens: {seq}") if self.verbose else None
                if len(seq) != self.id_len or not all(x in self.universe for x in seq):
                    continue

                result = self._traverse_tree_for_query(seq)
                if result:
                    result["index_ids"] = seq
                    result["index_txt"] = self.transformer.doc_tokenizer.batch_decode(seq, skip_special_tokens=False)
                    result_ids.append(result)
            print("Found results: ", result_ids) if self.verbose else None
            lst_of_result_ids.append(result_ids)

        return lst_of_result_ids

    def _traverse_tree_for_query(self, seq):
        node: PrefixTreeNode = self.root
        found = True
        depth = 0

        for idx in seq:
            if (idx not in node.children) and (depth < self.insertion_depth):
                found = False
                break
            if (idx not in node.children) and (depth >= self.insertion_depth):
                break
            node = node.children[idx]
            depth += 1

        if found:
            return {"depth": depth, "doc_ids": node.doc_ids}
        return None

    def query(self, query_texts: list[Document]):
        print(f"Querying for '{query_texts}'") if self.verbose else None
        return self._query_with_prompt(query_texts, self.query_prompt.template)

...

Ellipsis

## Test

In [4]:
query_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
query_model_max_length = 128
doc_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
doc_model_max_length = 2048
torch_dtype = torch.bfloat16

query_model, query_tokenizer = get_model_and_tokenizer(
    query_model_name_or_path,
    model_max_length=query_model_max_length,
    torch_dtype=torch_dtype,
)
doc_model, doc_tokenizer = get_model_and_tokenizer(
    doc_model_name_or_path,
    model_max_length=doc_model_max_length,
    torch_dtype=torch_dtype,
)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [5]:
query_model.to("cuda")
doc_model.to("cuda")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128258, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [6]:
genx_transformer = GenXTransformer(
    query_model,
    doc_model,
    query_tokenizer,
    doc_tokenizer,
)

In [7]:
store = SequencePrefixTreeIndexStore(
    genx_transformer,
    id_len=5,
    universe=set(range(genx_transformer.doc_tokenizer.vocab_size)),
    mode="document_search",
    insertion_depth=4,
)
store.clear_store()
store.set_verbose_for_all(False)

In [8]:
scifact_train_dataloader = get_scifact_dataloader(
    "./data/scifact/claims_train.jsonl",
    "./data/scifact/corpus.jsonl",
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

In [9]:
for batch in scifact_train_dataloader:
    # Use dict to automatically handle duplicates (keeps last occurrence)
    unique_docs = {}
    for doc_id, text in zip(batch["doc_id"], batch["text"]):
        unique_docs[doc_id] = text

    docs_to_be_inserted = []
    for doc_id, text in unique_docs.items():
        docs_to_be_inserted.append(Document(text, {"doc_id": doc_id}))

    store.insert(docs_to_be_inserted)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [64]:
cited_doc_ids = []
results = []
for batch in scifact_train_dataloader:
    # Use dict to automatically handle duplicates (keeps last occurrence)
    unique_queries = {}
    for query, doc_id in zip(batch["query"], batch["doc_id"]):
        if query not in unique_queries:
            unique_queries[query] = []
        unique_queries[query].append(doc_id)

    queries_to_be_queried = []
    doc_ids_for_a_query = []
    for query, doc_ids in unique_queries.items():
        queries_to_be_queried.append(Document(query, {"doc_ids": doc_ids}))
        doc_ids_for_a_query.append(doc_ids)

    result = store.query(queries_to_be_queried)
    results.extend(result)
    cited_doc_ids.extend(doc_ids_for_a_query)

In [65]:
len(results), len(cited_doc_ids)

(807, 807)

In [66]:
results[0], cited_doc_ids[0]

([], [31715818])

In [76]:
# Count predicted abstracts, correctly predicted abstracts, and gold abstracts
total_predicted = 0
total_correctly_predicted = 0
total_gold = 0

for idx in range(len(results)):
    result = results[idx]
    cited_doc_id = cited_doc_ids[idx]

    # Extract predicted document IDs
    if len(result) > 0:
        temp = []
        for item in result:
            temp.extend(item['doc_ids'])
        predicted = set(temp)
    else:
        predicted = set()

    # Count metrics
    total_predicted += len(predicted)
    total_gold += len(cited_doc_id)

    # Count correctly predicted abstracts (intersection)
    correctly_predicted = predicted.intersection(set(cited_doc_id))
    total_correctly_predicted += len(correctly_predicted)

print(f"Total predicted abstracts: {total_predicted}")
print(f"Total correctly predicted abstracts: {total_correctly_predicted}")
print(f"Total gold abstracts: {total_gold}")

# Calculate precision and recall
if total_predicted > 0:
    precision = total_correctly_predicted / total_predicted
    print(f"Precision: {precision:.4f}")
else:
    print("Precision: 0.0000 (no predictions)")

if total_gold > 0:
    recall = total_correctly_predicted / total_gold
    print(f"Recall: {recall:.4f}")
else:
    print("Recall: 0.0000 (no gold abstracts)")

# Calculate F1 score
if precision + recall > 0:
    f1 = 2 * (precision * recall) / (precision + recall)
    print(f"F1 Score: {f1:.4f}")
else:
    print("F1 Score: 0.0000")

Total predicted abstracts: 5238
Total correctly predicted abstracts: 0
Total gold abstracts: 919
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000


In [12]:
# https://github.com/allenai/scifact/blob/master/doc/evaluation.md#abstract-level-scoring

# Precision: (# correctly predicted abstracts) / (# predicted abstracts)
# Recall: (# correctly predicted abstracts) / (# gold abstracts)