# Try ReFactX

To avoid ingesting the full 800-million-fact tree, this notebook uses a small in-memory prefix tree of 31,584 facts about famous artists and directors.
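
As a rough illustration of what a prefix tree over facts looks like, here is a minimal sketch. It is not the ReFactX index implementation: the `PrefixTree` class and the token ids are hypothetical, chosen only to show how a trie yields the set of tokens that may follow a given prefix.

```python
# Toy prefix tree (trie) over token-id sequences.
# Hypothetical illustration only; the token ids below are made up.
class PrefixTree:
    def __init__(self):
        # Each node is a dict mapping a token id to its child node.
        self.root = {}

    def add(self, token_ids):
        """Insert one tokenized fact into the tree."""
        node = self.root
        for tok in token_ids:
            node = node.setdefault(tok, {})

    def next_tokens(self, prefix):
        """Return the token ids that can follow `prefix` (empty set if none)."""
        node = self.root
        for tok in prefix:
            node = node.get(tok)
            if node is None:
                return set()
        return set(node)


tree = PrefixTree()
tree.add([10, 20, 30])  # three made-up tokenized facts sharing prefixes
tree.add([10, 20, 31])
tree.add([10, 21, 40])

allowed_after_subject = tree.next_tokens([10])        # tokens valid after [10]
allowed_after_predicate = tree.next_tokens([10, 20])  # tokens valid after [10, 20]
```

Facts that share a subject or predicate share a path in the tree, so looking up the valid continuations costs one dictionary access per prefix token.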

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [4]:
import torch
import importlib
import time
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from datamodules.base_dataset_config import PROMPT_TEMPLATE
from refactx import (ConstrainedLogitsProcessor, ConstrainedStateList,
                     PatternConstrainedState, DictIndex, patch_model)
import refactx

In [5]:
MODEL = 'Qwen/Qwen2.5-3B-Instruct'
#MODEL = 'openai/gpt-oss-20b'
INDEX = '../indexes/simple_index.txt.gz'

In [6]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, device_map='auto')

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.92s/it]


In [8]:
index = refactx.load_index(INDEX, tokenizer)

100%|██████████| 316/316 [00:01<00:00, 288.60it/s]


In [9]:
streamer = TextStreamer(tokenizer)

In [10]:
question = 'Is Johnny Depp older than Brad Pitt?'

prompted_texts = [refactx.apply_prompt_template(PROMPT_TEMPLATE, tokenizer, question)]

In [11]:
#print(prompted_texts[0])

In [12]:
inputs = tokenizer(prompted_texts, return_tensors='pt', padding=True, padding_side='right')
inputs = inputs.to(model.device)
print(inputs['input_ids'].shape)

torch.Size([1, 756])


In [13]:
model.device

device(type='cuda', index=0)

In [14]:
# patching is not needed when num_beams == 1
patch_model(model)

In [15]:
refactx.CONSTRAINED_STATES

In [None]:
num_beams = 1

auto_streamer = streamer if num_beams == 1 else None

states = [[PatternConstrainedState(
    pattern='Fact:',
    tokenizer=tokenizer,
    cache_index=DictIndex(),
    subtree_cache=DictIndex(),
)]]

refactx.CONSTRAINED_STATES = ConstrainedStateList(
    states,
    num_beams=num_beams,
    num_batches=1,
)

constrained_processor = ConstrainedLogitsProcessor(
    index=index,
    states=refactx.CONSTRAINED_STATES, tokenizer=tokenizer)
logits_processor_list = LogitsProcessorList([
    constrained_processor
])

model.eval()
start = time.time()

with torch.no_grad():
    out = model.generate(
        **inputs,
        logits_processor=logits_processor_list,
        max_new_tokens=800,
        streamer=auto_streamer,
        do_sample=False,
        temperature=None,
        top_k=None,
        top_p=None,
        min_p=None,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        use_cache=True,
    )

print('Elapsed', time.time() - start)
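
The general idea behind a constrained logits processor can be sketched in a few lines: at each decoding step, tokens that do not continue a valid fact get a score of negative infinity, so greedy decoding can only pick an allowed token. The sketch below is an assumption-laden illustration in plain Python, not the actual `ConstrainedLogitsProcessor`; `mask_logits`, `greedy_pick`, and the scores are hypothetical.

```python
import math


def mask_logits(logits, allowed_token_ids):
    """Return a copy of `logits` with every disallowed token set to -inf."""
    allowed = set(allowed_token_ids)
    return [score if tok in allowed else -math.inf
            for tok, score in enumerate(logits)]


def greedy_pick(logits):
    """Index of the highest-scoring token (greedy decoding, no sampling)."""
    return max(range(len(logits)), key=logits.__getitem__)


# Hypothetical scores over a 4-token vocabulary.
logits = [0.1, 2.5, 1.7, 3.0]

unconstrained_choice = greedy_pick(logits)           # best token overall
masked = mask_logits(logits, allowed_token_ids=[1, 2])
constrained_choice = greedy_pick(masked)             # best token among allowed ones
```

In this toy case the unconstrained choice is token 3, while masking restricts the choice to token 1, the best-scoring token among the allowed ones.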

### Visualize ReFactX output

In [18]:
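_from = inputs.input_ids.shape[1]  # number of prompt tokens to skip
for i in range(out.shape[0]):
    new_tokens = out[i][_from:]
    # separator, sum of the generated token ids, number of generated tokens
    print('-'*30, sum(new_tokens), len(new_tokens))
    print(tokenizer.decode(new_tokens))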

------------------------------ tensor(1058125, device='cuda:0') 247
Reasoning: To determine if Johnny Depp is older than Brad Pitt, I need to find their respective birth dates. Once I have both dates, I can compare them to see which one is older. Let's start with finding their birth dates.
Fact: <Johnny Depp> <date of birth> <1963-06-09T00:00:00Z> . 
Fact: <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> . 

I found the birth dates of both actors:
- Johnny Depp was born on June 9, 1963.
- Brad Pitt was born on December 18, 1963.

Now let's compare their ages.
- Johnny Depp is 60 years old.
- Brad Pitt is 60 years old.

Since they were born on different days but in the same year, we need to consider the exact day to determine who is older. However, given that the difference in age is only a few days, we can conclude that they are the same age.

Answer: No.<|im_end|>


### Generated Facts

In [None]:
for i, triple in enumerate(refactx.CONSTRAINED_STATES[0][0].generated_triples):
    print(i, tokenizer.decode(triple))
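
Because the facts in the output above follow a `Fact: <subject> <predicate> <object> .` layout, they can be parsed and checked programmatically. The following is a minimal sketch, assuming that layout and the timestamp format shown in the generated text; `FACT_RE` and `parse_facts` are hypothetical helpers, not part of `refactx`.

```python
import re
from datetime import datetime

# Matches lines such as: Fact: <subject> <predicate> <object> .
FACT_RE = re.compile(r'Fact:\s*<([^>]*)>\s*<([^>]*)>\s*<([^>]*)>\s*\.')


def parse_facts(text):
    """Extract (subject, predicate, object) tuples from generated text."""
    return [match.groups() for match in FACT_RE.finditer(text)]


# The two fact lines copied from the generated output above.
generated = (
    "Fact: <Johnny Depp> <date of birth> <1963-06-09T00:00:00Z> . \n"
    "Fact: <Brad Pitt> <date of birth> <1963-12-18T00:00:00Z> . \n"
)

facts = parse_facts(generated)

# Map each subject to its parsed birth date.
birth = {
    subject: datetime.strptime(obj, '%Y-%m-%dT%H:%M:%SZ')
    for subject, predicate, obj in facts
    if predicate == 'date of birth'
}

# An earlier birth date means the person is older.
depp_is_older = birth['Johnny Depp'] < birth['Brad Pitt']
```

Here `depp_is_older` evaluates to `True`, since June 9, 1963 precedes December 18, 1963.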