# Try ReFactX

To avoid ingesting the full 800-million-fact tree, this notebook uses a small in-memory prefix tree of 31,584 facts about famous artists and directors.

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to the project sources are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Add the parent directory to the module search path (presumably where the
# project packages imported in the next cell live; confirm against the repo layout).
import sys
sys.path.append('..')

In [None]:
import torch
import importlib
import time
from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from datamodules.base_dataset_config import PROMPT_TEMPLATE
from refactx import ConstrainedLogitsProcessor, ConstrainedStateList, \
                    PatternConstrainedState, DictIndex, patch_model
import refactx

In [None]:
# Model identifier handed to `from_pretrained` when loading the tokenizer/model.
MODEL = 'Qwen/Qwen2.5-3B-Instruct'
# Alternative model identifier, left disabled.
#MODEL = 'openai/gpt-oss-20b'
# Gzip-compressed index file that is opened with `gzip.open` further down.
INDEX = '../indexes/simple_index.txt.gz'

In [None]:
# Pick CUDA when PyTorch reports an available GPU, otherwise fall back to the CPU.
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
DEVICE  # last expression of the cell, so the chosen device is displayed

In [None]:
# Load the tokenizer that matches MODEL.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# The model has to be loaded here: later cells read `model.device`, call
# `patch_model(model)`, `model.eval()` and `model.generate(...)`, and would fail
# with a NameError if this line stayed commented out.
# `device_map='auto'` delegates weight placement to the loader
# (NOTE(review): this option relies on the `accelerate` package being installed).
model = AutoModelForCausalLM.from_pretrained(MODEL, device_map='auto')

In [None]:
import gzip
import os

# The connection URL carries database credentials, so it is read from the
# environment instead of being stored in the notebook, e.g.
#   export REFACTX_POSTGRES_URL='postgres://USER:PASSWORD@HOST:5432/postgres'
# NOTE(review): the account presumably needs write access to create and fill
# the target table. Any password that was committed here earlier should be
# treated as leaked and rotated.
postgres_url = os.environ.get('REFACTX_POSTGRES_URL')
if not postgres_url:
    raise RuntimeError(
        'Set the REFACTX_POSTGRES_URL environment variable to the Postgres '
        'connection URL before populating the index.')

# Feed the gzip-compressed index file to the Postgres populate routine; the
# rows go to the `testinterpopulate` table, the same table name the index is
# later loaded from.
with gzip.open(INDEX) as reader:
    refactx.populate_postgres_index(reader,
                                    postgres_url,
                                    tokenizer,
                                    'testinterpopulate',
                                    batch_size=5000,
                                    rootkey=-100,
                                    configkey=-200,
                                    switch_parameter=7,
                                    total_number_of_triples=None,
                                    prefix='',
                                    tokenizer_batch_size=5000,
                                    add_special_tokens=False,
                                    count_leaves=True,
                                    debug=False)
   

In [None]:
import os

# The connection URL carries database credentials, so it is read from the
# environment instead of being stored in the notebook, e.g.
#   export REFACTX_POSTGRES_URL='postgres://USER:PASSWORD@HOST:5432/postgres'
postgres_url = os.environ.get('REFACTX_POSTGRES_URL')
if not postgres_url:
    raise RuntimeError(
        'Set the REFACTX_POSTGRES_URL environment variable to the Postgres '
        'connection URL before loading the index.')

# The table to read is selected through the `tablename` query parameter; use
# '&' when the configured URL already carries a query string.
query_separator = '&' if '?' in postgres_url else '?'

index = refactx.load_index(
    f'{postgres_url}{query_separator}tablename=testinterpopulate',
    #tokenizer,
    #configkey=-200,
    cache='default'
)

In [None]:
# Display the configuration reported by the loaded index.
index.get_config()

In [None]:
# Text streamer built on the tokenizer; it is handed to `generate` only when
# decoding with a single beam (see `auto_streamer` in the generation cell).
streamer = TextStreamer(tokenizer)

In [None]:
# Question posed to the model.
question = 'Is Johnny Depp older than Brad Pitt?'

# Wrap the question with the project prompt template; kept as a one-element
# list because the tokenizer below is called on a list of prompts.
prompted_texts = [refactx.apply_prompt_template(PROMPT_TEMPLATE, tokenizer, question)]

In [None]:
#print(prompted_texts[0])

In [None]:
# Tokenize the prompt list into PyTorch tensors and move them to the model's device.
# NOTE(review): with a single prompt no padding is actually added; for batches of
# prompts with different lengths, decoder-only generation usually expects left
# padding rather than padding_side='right' — confirm before batching.
inputs = tokenizer(prompted_texts, return_tensors='pt', padding=True, padding_side='right')
inputs = inputs.to(model.device)
print(inputs['input_ids'].shape)  # presumably (batch, sequence_length) — confirm

In [None]:
# Display the device the model reports being on.
model.device

In [None]:
# Apply the refactx patch to the loaded model.
# NOTE(review): the original remark below suggests the patch is not required
# when decoding with num_beams=1 — confirm against the refactx implementation.
# no need for num_beams=1
patch_model(model)

In [None]:
# Display the module-level constrained-state holder; it is reassigned in the
# generation cell right before `model.generate` is called.
refactx.CONSTRAINED_STATES

In [None]:
# Number of beams used for decoding (1 = no beam search).
num_beams = 1

# Token streaming is only enabled when exactly one beam is decoded.
auto_streamer = None if num_beams != 1 else streamer

# Nested list with a single constrained state, built around the 'Fact:' pattern
# (presumably laid out as [batch][beam] — confirm against ConstrainedStateList).
initial_state = PatternConstrainedState(
    pattern='Fact:',
    tokenizer=tokenizer,
    cache_index=DictIndex(),
    subtree_cache=DictIndex(),
)
states = [[initial_state]]

# Publish the state list on the refactx module and share the same object with
# the logits processor.
refactx.CONSTRAINED_STATES = ConstrainedStateList(
    states,
    num_beams=num_beams,
    num_batches=1,
)

constrained_processor = ConstrainedLogitsProcessor(
    index=index,
    states=refactx.CONSTRAINED_STATES,
    tokenizer=tokenizer,
)
logits_processor_list = LogitsProcessorList([constrained_processor])

# Deterministic decoding: sampling is disabled and the sampling knobs are passed
# as None (presumably to override defaults from the model's generation config).
generation_kwargs = dict(
    logits_processor=logits_processor_list,
    max_new_tokens=800,
    streamer=auto_streamer,
    do_sample=False,
    temperature=None,
    top_k=None,
    top_p=None,
    min_p=None,
    num_beams=num_beams,
    num_return_sequences=num_beams,
    use_cache=True,
)

model.eval()
start = time.time()

# Gradients are not needed for generation.
with torch.no_grad():
    out = model.generate(**inputs, **generation_kwargs)

elapsed = time.time() - start
print('Elapsed', elapsed)

In [None]:
# Display the index cache (inspected here after generation has run).
index.cache

### Visualize ReFactX output

In [None]:
# Number of prompt tokens; generated tokens start at this offset
# (the original note suggests 0 as an alternative, presumably to show the prompt too).
_from = len(inputs.input_ids[0])
for sequence in out:
    generated = sequence[_from:]
    # Separator line followed by the sum of the generated token ids and their count.
    print('-' * 30, sum(generated), len(generated))
    print(tokenizer.decode(generated))

### Generated Facts

In [None]:
# Triples recorded by the state at position [0][0] of the constrained-state
# list (presumably first batch item, first beam — confirm), decoded one per line.
generated_triples = refactx.CONSTRAINED_STATES[0][0].generated_triples
for i, triple in enumerate(generated_triples):
    print(i, tokenizer.decode(triple))