# Try ReFactX

To avoid ingesting the full 800-million-fact tree, this notebook uses a small in-memory prefix tree of 31,584 facts about famous artists and directors.
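The idea behind a fact prefix tree can be illustrated with a toy version. The sketch below is a hypothetical illustration and not the `ctrie`/`DictIndex` implementation: it assumes each fact is already tokenized into a list of token ids, stores the facts in nested dicts keyed by token id, and marks complete facts with an assumed sentinel key `END_OF_TRIPLE = -1`. Looking up a prefix returns the token ids that may come next, which is exactly what a constrained decoder needs at each step.

```python
from typing import Dict, List, Optional, Sequence

# Hypothetical sentinel key marking "a complete fact ends here" (assumption).
END_OF_TRIPLE = -1

Trie = Dict[int, dict]


def build_trie(facts: Sequence[Sequence[int]]) -> Trie:
    """Insert every tokenized fact into a nested dict keyed by token id."""
    root: Trie = {}
    for tokens in facts:
        node = root
        for tok in tokens:
            node = node.setdefault(tok, {})
        node[END_OF_TRIPLE] = {}  # mark the end of a full fact
    return root


def _walk(trie: Trie, prefix: Sequence[int]) -> Optional[Trie]:
    """Follow `prefix` down the trie; return None if it leaves the trie."""
    node = trie
    for tok in prefix:
        node = node.get(tok)
        if node is None:
            return None
    return node


def allowed_next(trie: Trie, prefix: Sequence[int]) -> List[int]:
    """Token ids that can follow `prefix` (empty if the prefix is unknown)."""
    node = _walk(trie, prefix)
    if node is None:
        return []
    return sorted(t for t in node if t != END_OF_TRIPLE)


def is_complete(trie: Trie, prefix: Sequence[int]) -> bool:
    """True if `prefix` spells out a whole fact stored in the trie."""
    node = _walk(trie, prefix)
    return node is not None and END_OF_TRIPLE in node


# Usage with made-up token ids:
trie = build_trie([[10, 20, 30], [10, 20, 31], [10, 40]])
print(allowed_next(trie, [10]))      # [20, 40]
print(allowed_next(trie, [10, 20]))  # [30, 31]
```

Nested dicts keep the sketch short; a structure holding hundreds of millions of facts would need a more compact, likely disk-backed, representation, which is why this notebook loads only a small subset.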

In [2]:
import torch
import importlib
import time
from transformers.generation.logits_process import LogitsProcessorList
from transformers import TextStreamer
from base_dataset_config import PROMPT_TEMPLATE
from ctrie import (ConstrainedLogitsProcessor, ConstrainedStateList,
                   ConstrainedState, DictIndex)

In [3]:
MODEL = 'qwen25_3B_model'    # config module that exposes `model_config`
INDEX = 'simple_index_qwen'  # config module that exposes `index_config`

In [5]:
index_module = importlib.import_module(INDEX)
index_config = getattr(index_module, 'index_config')

model_module = importlib.import_module(MODEL)
model_config = getattr(model_module, 'model_config')
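The cell above selects the model and the index by module name: each name is imported with `importlib.import_module` and a well-known attribute (`index_config` or `model_config`) is read from it. The sketch below shows the same pattern in isolation. To make it runnable without the ReFactX repository it registers a hypothetical in-memory module, `my_index_config`, in `sys.modules`; that module name and its contents are assumptions for illustration only.

```python
import importlib
import sys
import types

# Hypothetical stand-in for a config module such as `simple_index_qwen`:
# register an in-memory module so the example runs without the repository.
fake_module = types.ModuleType("my_index_config")
fake_module.index_config = {"name": "toy index"}  # assumed shape, illustration only
sys.modules["my_index_config"] = fake_module


def load_config(module_name: str, attribute: str):
    """Import `module_name` and return its `attribute`, with a clear error if missing."""
    module = importlib.import_module(module_name)  # served from sys.modules if present
    try:
        return getattr(module, attribute)
    except AttributeError as exc:
        raise AttributeError(
            f"module {module_name!r} does not define {attribute!r}"
        ) from exc


cfg = load_config("my_index_config", "index_config")
print(cfg["name"])
```

Keeping each model and index in its own module lets the notebook switch configurations by changing only the `MODEL` and `INDEX` strings.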

Loading Qwen/Qwen2.5-3B-Instruct

In [6]:
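# Prints tokens as they are generated; skip_prompt defaults to False, so the prompt is echoed as well.
streamer = TextStreamer(model_config.tokenizer)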

In [7]:
question = 'Is Johnny Depp older than Brad Pitt?'

prompted_texts = [model_config.apply_prompt_template(PROMPT_TEMPLATE, question)]

In [8]:
#print(prompted_texts[0])

In [9]:
inputs = model_config.tokenizer(prompted_texts, return_tensors='pt', padding=True, padding_side='right')
inputs = inputs.to(model_config.model.device)
print(inputs['input_ids'].shape)

torch.Size([1, 756])
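With `padding=True, padding_side='right'`, shorter sequences in a batch are extended on the right with a pad token and masked out in `attention_mask`. Since this notebook tokenizes a single prompt, nothing is padded and the shape simply reports 756 prompt tokens. The pure-Python sketch below mimics that behaviour for a small batch; the pad id `0` is an assumption (real tokenizers use their own `pad_token_id`).

```python
from typing import List, Tuple


def pad_right(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id lists to a common length and build the attention mask."""
    width = max(len(seq) for seq in batch)
    input_ids = [seq + [pad_id] * (width - len(seq)) for seq in batch]
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in batch]
    return input_ids, attention_mask


ids, mask = pad_right([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [8, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```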


In [10]:
model_config.model.device

device(type='cuda', index=1)

In [11]:
num_beams = 1

auto_streamer = streamer if num_beams == 1 else None

# One constrained state per beam.
states = ConstrainedStateList(
    [ConstrainedState(
        begin_pattern=model_config.switch_pattern,
        end_pattern=model_config.newline_token,
        cache_index=DictIndex(end_of_triple=index_config.index.end_of_triple),
        subtree_cache=DictIndex(end_of_triple=index_config.index.end_of_triple),
        oneleaf_cache=DictIndex(end_of_triple=index_config.index.end_of_triple),
    ) for _ in range(num_beams)],
    num_beams=num_beams,
    batch_size=1,
    pad_token_id=model_config.tokenizer.eos_token_id)

constrained_processor = ConstrainedLogitsProcessor(
    index=index_config.index,
    end_token=model_config.newline_token,
    states=states,
    tokenizer=model_config.tokenizer)
logits_processor_list = LogitsProcessorList([constrained_processor])

model_config.model.eval()
start = time.time()

with torch.no_grad():
    out = model_config.model.generate(
        **inputs,
        logits_processor=logits_processor_list,
        max_new_tokens=800,
        streamer=auto_streamer,
        do_sample=False,
        # clear sampling parameters (unused with do_sample=False)
        temperature=None,
        top_k=None,
        top_p=None,
        min_p=None,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        use_cache=True,
        kwargs={'constrained_state': states},  # passing state
    )

print('Elapsed', time.time() - start)

<|im_start|>system
You are a helpful question-answering assistant that bases its answers on facts from a knowledge base and always respects the prompt.

The process to answer questions:

    You receive an input question.

    You determine the reasoning path needed to answer the question based on the information available.

    You determine the kind of answer you are asked. It can be a yes/no, a single entity, or a list of entities. Pay attention to the questions whose answer is a list of entities (e.g. Which countries share a border with Spain?): you need to find all the answer entities and include them all in the final answer.

    You get relevant facts with the "Fact:" command. You can rely on these facts and use them a proof for your answer.
    While getting facts you continue the reasoning explaining it step by step.

    Often description or short description may be useful for answering questions.

    You conclude with a concise answer that depending on the question can be a
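The generation cell attaches a `ConstrainedState` with `begin_pattern=model_config.switch_pattern` and `end_pattern=model_config.newline_token` to a `ConstrainedLogitsProcessor`: the model writes freely until it starts a fact, is then restricted to continuations present in the prefix tree, and is released at the newline. The class below is a toy, hypothetical sketch of that state machine, not the `ctrie` implementation. It works on plain Python lists for one sequence, whereas a real Hugging Face `LogitsProcessor.__call__(input_ids, scores)` receives tensors of shape `(batch_size, vocab_size)`. The trie is a nested dict in which an empty dict marks a complete fact (an assumption of this sketch).

```python
from typing import Dict, List, Sequence

NEG_INF = float("-inf")


class ToyConstrainedState:
    """Free decoding until `begin_pattern` appears, then trie-constrained
    decoding until `end_token` is emitted (hypothetical sketch)."""

    def __init__(self, trie: Dict[int, dict], begin_pattern: Sequence[int], end_token: int):
        self.trie = trie
        self.begin_pattern = list(begin_pattern)
        self.end_token = end_token
        self.constrained = False
        self.node = trie  # current position in the trie while constrained

    def update(self, generated: Sequence[int]) -> None:
        """Advance the state after a token has been appended to `generated`."""
        last = generated[-1]
        if self.constrained:
            if last == self.end_token:  # fact finished: back to free text
                self.constrained = False
                self.node = self.trie
            else:
                self.node = self.node[last]
        elif list(generated[-len(self.begin_pattern):]) == self.begin_pattern:
            self.constrained = True  # fact started: constrain from the trie root
            self.node = self.trie

    def mask(self, scores: List[float]) -> List[float]:
        """Set to -inf the scores of tokens the trie does not allow next."""
        if not self.constrained:
            return scores
        allowed = set(self.node) or {self.end_token}  # leaf: only the end token
        return [s if i in allowed else NEG_INF for i, s in enumerate(scores)]


def greedy_step(state: ToyConstrainedState, generated: List[int], scores: List[float]) -> int:
    """Mask the scores, pick the best token, and update the state."""
    masked = state.mask(scores)
    token = max(range(len(masked)), key=masked.__getitem__)
    generated.append(token)
    state.update(generated)
    return token


# Usage with made-up ids: 1 opens a fact, 9 closes it, facts are [3, 4] and [3, 5].
state = ToyConstrainedState({3: {4: {}, 5: {}}}, begin_pattern=[1], end_token=9)
generated = [1]
state.update(generated)
scores = [0.0] * 10
scores[7] = 5.0  # the unconstrained model would prefer token 7
print(greedy_step(state, generated, scores))  # 3: the only token the trie allows
```

Masking with `-inf` rather than renormalizing keeps the step compatible with greedy, beam, and sampling decoders alike, since a `-inf` score gets zero probability under softmax.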

### Visualize ReFactX output

In [12]:
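_from = len(inputs.input_ids[0])  # skip the prompt tokens; set to 0 to print the prompt too
for i in range(out.shape[0]):
    # separator, sum of the generated token ids, number of generated tokens
    print('-'*30, sum(out[i][_from:]), len(out[i][_from:]))
    print(model_config.tokenizer.decode(out[i][_from:]))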

------------------------------ tensor(902001, device='cuda:1') 184
Reasoning: To determine if Johnny Depp is older than Brad Pitt, I need to find their respective birth dates. Once I have both dates, I can compare them to see which one is older. Let's start with finding their birth dates.
Fact: <Johnny Depp> <date of birth> <1963-06-09T00:00:00Z> .
Fact: <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> .
I found proof that Johnny Depp was born on June 9, 1963, and Brad Pitt was born on December 18, 1963. Now I need to compare these dates to determine who is older.

Answer: Johnny Depp is not older than Brad Pitt.<|im_end|>
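The decoded output follows a fixed layout: each retrieved fact sits on a `Fact:` line as `<subject> <predicate> <object> .`, and the conclusion sits on an `Answer:` line. The sketch below pulls both out of the decoded text with regular expressions. The patterns are assumptions based on the output shown here and would break on entity names containing `>`. Passing `skip_special_tokens=True` to `tokenizer.decode` would also drop `<|im_end|>` before parsing.

```python
import re
from typing import List, Tuple

# Matches lines such as: Fact: <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> .
FACT_RE = re.compile(r"^Fact:\s*<([^>]*)>\s*<([^>]*)>\s*<([^>]*)>\s*\.\s*$", re.MULTILINE)


def extract_facts(text: str) -> List[Tuple[str, str, str]]:
    """Return (subject, predicate, object) for every Fact: line."""
    return FACT_RE.findall(text)


def extract_answer(text: str) -> str:
    """Return the text of the first Answer: line, without the end-of-turn token."""
    for line in text.splitlines():
        if line.startswith("Answer:"):
            return line[len("Answer:"):].replace("<|im_end|>", "").strip()
    return ""


sample = (
    "Fact: <Johnny Depp> <date of birth> <1963-06-09T00:00:00Z> .\n"
    "Fact: <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> .\n"
    "Answer: Johnny Depp is not older than Brad Pitt.<|im_end|>"
)
print(extract_facts(sample))
print(extract_answer(sample))
```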


### Generated Facts

In [13]:
for i, triple in enumerate(states[0].generated_triples):
    # [:-1] drops the last decoded character, presumably the newline that closes each fact
    print(i, model_config.tokenizer.decode(triple)[:-1])

0  <Johnny Depp> <date of birth> <1963-06-09T00:00:00Z> .
1  <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> .
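Because the generated facts are structured, the comparison the question asks for can also be checked deterministically. The two retrieved dates place Johnny Depp's birth (June 9, 1963) before Brad Pitt's (December 18, 1963), so the facts indicate that Johnny Depp is the older of the two, whereas the generated `Answer:` line states the opposite. The sketch below parses the timestamp format seen in these facts (assumed to always be `YYYY-MM-DDTHH:MM:SSZ`) and picks the earliest birth date; the triples are copied from the output above.

```python
from datetime import datetime

# Timestamp layout observed in the generated facts (assumed fixed).
FACT_TS_FORMAT = "%Y-%m-%dT%H:%M:%SZ"


def parse_ts(value: str) -> datetime:
    """Parse a fact timestamp such as 1963-06-09T00:00:00Z."""
    return datetime.strptime(value, FACT_TS_FORMAT)


triples = [
    ("Johnny Depp", "date of birth", "1963-06-09T00:00:00Z"),
    ("Brad Pitt", "date of birth", "1963-12-18T00:00:00Z"),
]

# Map each subject to its birth date, then take the earliest one (the oldest person).
births = {s: parse_ts(o) for s, p, o in triples if p == "date of birth"}
older = min(births, key=births.get)
print(older)  # Johnny Depp
```

A check like this could be run on `states[0].generated_triples` after decoding to flag answers that contradict the facts the model itself retrieved.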
