## Installation

- Check your system and make sure it satisfies the following requirements:

  - This package should work on any Linux distribution. Sorry, no macOS or Windows support :)
  - Supported architectures are x86_64 and i686.
  - Your system needs to be little-endian. This should be the case for most modern machines.
  - Please make sure you have Python >= 3.8 (specifically CPython, not PyPy or another implementation).

- Install this package: `pip install infini-gram`

- Download the infini-gram index that you would like to query. For the sake of performance, it is strongly recommended that you put the index on an SSD. See the details in the "Pre-built Indexes" section below.
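The requirements listed above can be checked before installing. The following is a minimal sketch that uses only the Python standard library. The function name `check_requirements` is hypothetical and not part of infini-gram, and the exact string reported by `platform.machine()` on 32-bit systems is an assumption.

```python
import platform
import sys


def check_requirements():
    """Return a dict mapping each requirement listed above to True or False."""
    return {
        "linux": platform.system() == "Linux",
        # Assumption: 32-bit x86 machines report "i686" here.
        "arch_x86_64_or_i686": platform.machine() in ("x86_64", "i686"),
        "little_endian": sys.byteorder == "little",
        "python_3_8_plus": sys.version_info >= (3, 8),
        "cpython": platform.python_implementation() == "CPython",
    }


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name}: {'ok' if ok else 'NOT satisfied'}")
```

A failed check does not necessarily mean the package cannot be installed; it only flags a mismatch with the list above.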

## Pre-built Indexes

We have made the following indexes publicly available on AWS S3.

Smaller indexes are stored in the `s3://infini-gram-lite` bucket and can be downloaded for free and without an AWS account. These indexes are `v4_pileval_llama`, `v4_pileval_gpt2`, and `v4_dolmasample_olmo`. To download one, run the following command:

```bash
aws s3 cp --no-sign-request --recursive {S3_URL} {LOCAL_INDEX_PATH}
```

## Usage

Before submitting any type of query, you need to instantiate the engine with the index you would like to query. As an example, below we create an engine with the index for Pile-val (the validation set of Pile), which was created using the Llama-2 tokenizer.

Let's load our tokenizer first. It should match the tokenizer of the index you load
into the infini-gram engine.

In [93]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False)

Now let's load our Infini-gram model using the `tokenizer`.

In [94]:
from infini_gram.engine import InfiniGramEngine
# replace `index_dir` with the local directory where you store the index
engine = InfiniGramEngine(index_dir='index/v4_pileval_llama',
                          eos_token_id=tokenizer.eos_token_id)

## Count an n-gram (count() and count_cnf())

This query type counts the number of times the query string appears in the corpus. For example, to find the number of occurrences of the n-gram `natural language processing` in the Pile-val corpus:

In [95]:
input_ids = tokenizer.encode('natural language processing')
print(input_ids)
# [5613, 4086, 9068]

[5613, 4086, 9068]


In [96]:
out = engine.count(input_ids=input_ids)
print(out)
# {'count': 76, 'approx': False}

{'count': 76, 'approx': False}


The `approx` field indicates whether the count is approximate. For simple queries with a single n-gram term, this is always False (the count is always exact). As you will see later, count for complex queries may be approximate.

### Empty Query

If you submit an empty query, the engine returns the total number of tokens in the corpus.
The empty query is just `[]`:

In [97]:
print(engine.count(input_ids=[]))
#{'count': 393769120, 'approx': False}

{'count': 393769120, 'approx': False}
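To make the returned quantity concrete, here is a brute-force sketch of the same count on a tiny in-memory token list. It only illustrates what is being counted, not how the engine computes it from its index. The name `count_ngram` and the toy corpus are made up for this example.

```python
def count_ngram(corpus, ngram):
    """Count occurrences of `ngram` (a list of token ids) in `corpus` by brute force."""
    if not ngram:
        # Mirror the behavior described above: an empty query yields the corpus size.
        return len(corpus)
    n = len(ngram)
    return sum(1 for i in range(len(corpus) - n + 1) if corpus[i:i + n] == ngram)


toy_corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
print(count_ngram(toy_corpus, [1, 2, 3]))  # 2
print(count_ngram(toy_corpus, [1, 2]))     # 3
print(count_ngram(toy_corpus, []))         # 9
```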


You can also make more complex queries by connecting multiple n-grams with the `AND` and `OR` operators, in the CNF format, in which case the engine counts the number of times this logical constraint is satisfied. A CNF query is a triply-nested list. The top level is a list of disjunctive clauses (which are connected with the `AND` operator). Each disjunctive clause is a list of n-gram terms (which are connected with the `OR` operator). Each n-gram term has the same format as `input_ids` above, i.e., a list of token ids.

In [98]:

# natural language processing OR artificial intelligence
cnf = [
    [tokenizer.encode('natural language processing'), tokenizer.encode('artificial intelligence')]
]
print(cnf)
# [[[5613, 4086, 9068], [23116, 21082]]]

print(engine.count_cnf(cnf=cnf))
# {'count': 499, 'approx': False}



[[[5613, 4086, 9068], [23116, 21082]]]
{'count': 499, 'approx': False}


In [99]:
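def tokenize_cnf(phrase_cnf):
    """Tokenize a CNF query given as nested lists of phrase strings."""
    cnf = []
    for disj in phrase_cnf:  # each clause is a list of OR-ed phrases
        cnf.append([tokenizer.encode(phrase) for phrase in disj])
    return cnf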

In [100]:
print(tokenize_cnf([['natural language processing', 'artificial intelligence']]))

[[[5613, 4086, 9068], [23116, 21082]]]


In [101]:
# (natural language processing OR artificial intelligence) AND deep learning
cnf = [
   [tokenizer.encode('natural language processing'), tokenizer.encode('artificial intelligence')],
   [tokenizer.encode('deep learning')]
]
print(cnf)
# [[[5613, 4086, 9068], [23116, 21082]], [[6483, 6509]]]

print(engine.count_cnf(cnf=cnf))
#{'count': 19, 'approx': False}


[[[5613, 4086, 9068], [23116, 21082]], [[6483, 6509]]]
{'count': 19, 'approx': False}


In [102]:
engine.count_cnf(tokenize_cnf([['natural language processing', 'artificial intelligence'], ['deep learning']]))

{'count': 19, 'approx': False}
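The triply-nested CNF structure used above can be made explicit with a small pure-Python sketch. The helpers `contains`, `satisfies_cnf`, and `cnf_to_string` are hypothetical names for this illustration. They check a single token list against a CNF query and do not reproduce how the engine decides where a constraint is satisfied in the corpus.

```python
def contains(tokens, ngram):
    """True if `ngram` occurs as a contiguous subsequence of `tokens`."""
    n = len(ngram)
    return any(tokens[i:i + n] == ngram for i in range(len(tokens) - n + 1))


def satisfies_cnf(tokens, cnf):
    """True if every clause has at least one n-gram term present in `tokens`."""
    return all(any(contains(tokens, term) for term in clause) for clause in cnf)


def cnf_to_string(cnf):
    """Render a CNF query as text, e.g. '([1, 2] OR [3]) AND ([4])'."""
    return " AND ".join(
        "(" + " OR ".join(str(term) for term in clause) + ")" for clause in cnf
    )


cnf = [[[5613, 4086, 9068], [23116, 21082]], [[6483, 6509]]]
print(cnf_to_string(cnf))
# ([5613, 4086, 9068] OR [23116, 21082]) AND ([6483, 6509])

doc_a = [7, 23116, 21082, 8, 6483, 6509]  # matches clause 1 (second term) and clause 2
doc_b = [5613, 4086, 9068, 9]             # matches clause 1 only
print(satisfies_cnf(doc_a, cnf))  # True
print(satisfies_cnf(doc_b, cnf))  # False
```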

## Probability of an n-gram

This query type computes the n-gram LM probability of a token conditioned on a preceding prompt.

For example, to compute P(processing | natural language):

In [103]:
input_ids = tokenizer.encode('natural language processing')
print(input_ids)
# [5613, 4086, 9068]

natural_id = tokenizer.encode('natural')[-1]
language_id = tokenizer.encode('language')[-1]
processing_id = tokenizer.encode('processing')[-1]


print(engine.prob(prompt_ids=input_ids[:-1], cont_id=input_ids[-1]))
#{'prompt_cnt': 257, 'cont_cnt': 76, 'prob': 0.29571984435797666}

res = engine.prob(prompt_ids=[natural_id, language_id], cont_id=processing_id)
print(res)





[5613, 4086, 9068]
{'prompt_cnt': 257, 'cont_cnt': 76, 'prob': 0.29571984435797666}
{'prompt_cnt': 257, 'cont_cnt': 76, 'prob': 0.29571984435797666}
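In the output above, `prob` is simply `cont_cnt / prompt_cnt` (76 / 257). The brute-force sketch below reproduces that relationship on a toy token list. `count_ngram` and `toy_prob` are hypothetical helpers, and returning `None` when the prompt is absent is this sketch's own choice, not the engine's documented behavior.

```python
def count_ngram(corpus, ngram):
    """Brute-force n-gram count; an empty n-gram counts every token."""
    n = len(ngram)
    if n == 0:
        return len(corpus)
    return sum(1 for i in range(len(corpus) - n + 1) if corpus[i:i + n] == ngram)


def toy_prob(corpus, prompt_ids, cont_id):
    """Estimate P(cont_id | prompt_ids) as a ratio of two counts."""
    prompt_cnt = count_ngram(corpus, prompt_ids)
    cont_cnt = count_ngram(corpus, prompt_ids + [cont_id])
    prob = cont_cnt / prompt_cnt if prompt_cnt > 0 else None
    return {'prompt_cnt': prompt_cnt, 'cont_cnt': cont_cnt, 'prob': prob}


toy_corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
print(toy_prob(toy_corpus, [1, 2], 3))
# {'prompt_cnt': 3, 'cont_cnt': 2, 'prob': 0.6666666666666666}

# The same ratio with the counts from the Pile-val example above:
print(76 / 257)
```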


Let's make a helper function for computing the joint probability of a multi-token continuation.
`engine.prob()` returns the probability of a single continuation token given the preceding prompt tokens, so longer continuations are handled with the chain rule: Prob[c d | a b] = Prob[c | a b] * Prob[d | a b c].
In code:

In [104]:
def compute_conditional_prob(prompt_ids, cont_ids):
    """
    Compute the conditional probability of a continuation sequence given a prompt sequence.

    Example:

        ids = tokenizer.encode('natural language processing is fun')
        natural_id, language_id, processing_id, is_id, fun_id = ids
        compute_conditional_prob([natural_id, language_id], [processing_id, is_id, fun_id])

    :param prompt_ids: list of token ids representing the prompt sequence
    :param cont_ids: list of token ids representing the continuation sequence
    :return: a dict with the prompt ids, the continuation ids, and the conditional probability
    """
    prob = 1
    context = list(prompt_ids)  # work on a copy so the caller's list is not mutated
    for tok in cont_ids:
        p = engine.prob(prompt_ids=context, cont_id=tok)['prob']
        prob *= p
        context.append(tok)
    # Note: despite their names, 'prompt_cnt' and 'cont_cnt' hold token ids here, not counts.
    return {'prompt_cnt': list(prompt_ids), 'cont_cnt': cont_ids, 'prob': prob}


In [105]:
natural_id, language_id, processing_id = tokenizer.encode('natural language processing')
print([natural_id, language_id, processing_id])
res = compute_conditional_prob([], [natural_id, language_id, processing_id])
print(res)

[5613, 4086, 9068]
{'prompt_cnt': [], 'cont_cnt': [5613, 4086, 9068], 'prob': 1.930065008652786e-07}
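With an empty prompt, the chain-rule product telescopes: each step's `prompt_cnt` cancels the previous step's `cont_cnt`, leaving the n-gram count divided by the total number of tokens. This agrees with the numbers in this document (76 occurrences out of 393769120 tokens is about 1.93e-07). The sketch below checks the identity by brute force on a toy token list. `count_ngram` and `toy_joint_prob` are hypothetical helpers, and returning `0.0` for unseen continuations is this sketch's own choice.

```python
import math


def count_ngram(corpus, ngram):
    """Brute-force n-gram count; an empty n-gram counts every token."""
    n = len(ngram)
    if n == 0:
        return len(corpus)
    return sum(1 for i in range(len(corpus) - n + 1) if corpus[i:i + n] == ngram)


def toy_joint_prob(corpus, prompt_ids, cont_ids):
    """Multiply per-token conditional probabilities (chain rule)."""
    prob = 1.0
    context = list(prompt_ids)  # copy, so the caller's list is untouched
    for tok in cont_ids:
        prompt_cnt = count_ngram(corpus, context)
        cont_cnt = count_ngram(corpus, context + [tok])
        if cont_cnt == 0:
            return 0.0  # unseen continuation (or unseen prompt)
        prob *= cont_cnt / prompt_cnt
        context.append(tok)
    return prob


toy_corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
joint = toy_joint_prob(toy_corpus, [], [1, 2, 3])
direct = count_ngram(toy_corpus, [1, 2, 3]) / len(toy_corpus)
print(math.isclose(joint, direct))  # True

# The same ratio with the counts reported earlier for Pile-val:
print(76 / 393769120)
```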


## Next-token distribution (ntd())

This query type computes the n-gram LM next-token distribution conditioned on a preceding prompt.

For example, this will return the token distribution following `natural language`:

In [115]:
input_ids = tokenizer.encode('natural language')
print(input_ids)
# [5613, 4086]

# engine.ntd(prompt_ids=input_ids)
# {'prompt_cnt': 257, 'result_by_token_id':
# {13: {'cont_cnt': 1, 'prob': 0.0038910505836575876},
# 297: {'cont_cnt': 1, 'prob': 0.0038910505836575876},
# ...,
# 30003: {'cont_cnt': 1, 'prob': 0.0038910505836575876}}, 'approx': False}



[5613, 4086]


`result_by_token_id` is a `dict` that maps each token id to its count (`cont_cnt`) and
its probability (`prob`) as a continuation of the prompt.

If the prompt cannot be found in the corpus, you will get an empty distribution:

In [118]:

input_ids = tokenizer.encode('I love natural language processing')
print(input_ids)
# [306, 5360, 5613, 4086, 9068]

print(engine.ntd(prompt_ids=input_ids[:-1]))
# {'prompt_cnt': 0, 'result_by_token_id': {}, 'approx': False}

[306, 5360, 5613, 4086, 9068]
{'prompt_cnt': 0, 'result_by_token_id': {}, 'approx': False}
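The shape of this result can be reproduced with a brute-force sketch on a toy token list: collect every token that follows the prompt and normalize the counts. `toy_ntd` is a hypothetical helper. As a simplification, it counts only prompt occurrences that are followed by a token, and it omits the `approx` field.

```python
from collections import Counter


def toy_ntd(corpus, prompt_ids):
    """Next-token distribution after `prompt_ids`, computed by brute force."""
    n = len(prompt_ids)
    # Collect the token that follows each occurrence of the prompt.
    followers = Counter(
        corpus[i + n]
        for i in range(len(corpus) - n)
        if corpus[i:i + n] == prompt_ids
    )
    prompt_cnt = sum(followers.values())
    return {
        'prompt_cnt': prompt_cnt,
        'result_by_token_id': {
            tok: {'cont_cnt': cnt, 'prob': cnt / prompt_cnt}
            for tok, cnt in followers.items()
        },
    }


toy_corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
print(toy_ntd(toy_corpus, [1, 2]))
# {'prompt_cnt': 3, 'result_by_token_id': {3: {'cont_cnt': 2, 'prob': 0.6666666666666666}, 4: {'cont_cnt': 1, 'prob': 0.3333333333333333}}}
print(toy_ntd(toy_corpus, [9]))
# {'prompt_cnt': 0, 'result_by_token_id': {}}
```

The same normalization is visible in the real output above, where each continuation seen once after a prompt with `prompt_cnt` 257 gets probability 1/257, about 0.00389.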


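## ∞-gram probability (infgram_prob())

The `infgram_prob()` query takes the same arguments as `prob()`, and its result additionally includes a `suffix_len` field.

In [131]: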
input_ids = tokenizer.encode('hello world')
print(engine.infgram_prob(prompt_ids=input_ids[:-1], cont_id=input_ids[-1]))
# {'prompt_cnt': 1421, 'cont_cnt': 47, 'prob': 0.033075299085151305, 'suffix_len': 1}

{'prompt_cnt': 1421, 'cont_cnt': 47, 'prob': 0.033075299085151305, 'suffix_len': 1}
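The `suffix_len` field in this output is consistent with the ∞-gram idea described in the infini-gram paper: condition on the longest suffix of the prompt that occurs in the corpus. The sketch below implements that backoff by brute force on a toy token list. It assumes `suffix_len` is the length of the suffix that was used; `count_ngram` and `toy_infgram_prob` are hypothetical helpers and say nothing about the engine's internals.

```python
def count_ngram(corpus, ngram):
    """Brute-force n-gram count; an empty n-gram counts every token."""
    n = len(ngram)
    if n == 0:
        return len(corpus)
    return sum(1 for i in range(len(corpus) - n + 1) if corpus[i:i + n] == ngram)


def toy_infgram_prob(corpus, prompt_ids, cont_id):
    """Back off to the longest prompt suffix that appears in the corpus."""
    if not corpus:
        raise ValueError("corpus must not be empty")
    # Try the full prompt first, then progressively shorter suffixes.
    for start in range(len(prompt_ids) + 1):
        suffix = prompt_ids[start:]
        prompt_cnt = count_ngram(corpus, suffix)
        if prompt_cnt > 0:
            cont_cnt = count_ngram(corpus, suffix + [cont_id])
            return {'prompt_cnt': prompt_cnt, 'cont_cnt': cont_cnt,
                    'prob': cont_cnt / prompt_cnt, 'suffix_len': len(suffix)}


toy_corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
# The full prompt [9, 1, 2] never occurs, so the sketch falls back to [1, 2].
print(toy_infgram_prob(toy_corpus, [9, 1, 2], 3))
# {'prompt_cnt': 3, 'cont_cnt': 2, 'prob': 0.6666666666666666, 'suffix_len': 2}
```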


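## ∞-gram next-token distribution (infgram_ntd())

As with `ntd()`, the result of `infgram_ntd()` contains a `result_by_token_id` field.

In [132]: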
result = engine.infgram_ntd(prompt_ids=input_ids, max_support=10)

In [133]:
tok_ids = result['result_by_token_id'].keys()
print(tok_ids)

dict_keys([13, 322, 1342, 1824, 29871, 29908])


In [134]:
tokenizer.convert_ids_to_tokens(tok_ids)

['<0x0A>', '▁and', '▁example', '▁program', '▁', '"']

## Search documents (search_docs() and search_docs_cnf())

In [135]:
input_ids = tokenizer.encode('natural language processing')
print(input_ids)
# [5613, 4086, 9068]

engine.search_docs(input_ids=input_ids, maxnum=1, max_disp_len=10)

# Example output from another run (the returned document can differ, as the output below shows):
# {'cnt': 76, 'approx': False, 'idxs': [54], 'documents': [{'doc_ix': 142405, 'doc_len': 19238, 'disp_len': 10, 'metadata': '', 'token_ids': [4475, 304, 9045, 2562, 322, 5613, 4086, 9068, 29889, 13]}]}

[5613, 4086, 9068]


{'cnt': 76,
 'approx': False,
 'idxs': [69],
 'documents': [{'doc_ix': 109291,
   'doc_len': 521,
   'disp_len': 10,
   'metadata': '{"path": "val.jsonl", "linenum": 109291, "metadata": {"meta": {"pile_set_name": "Wikipedia (en)"}}}\n',
   'token_ids': [14881,
    9608,
    29879,
    29889,
    512,
    5613,
    4086,
    9068,
    29892,
    670]}]}
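A returned document can be post-processed with plain Python. The sketch below locates the query n-gram inside the displayed `token_ids` and parses the `metadata` string, using values copied from the output above. `find_ngram` is a hypothetical helper, and treating `metadata` as JSON is an inference from this example; since `metadata` can also be an empty string, the sketch guards against that.

```python
import json


def find_ngram(token_ids, ngram):
    """Return the index of the first occurrence of `ngram`, or -1 if absent."""
    n = len(ngram)
    for i in range(len(token_ids) - n + 1):
        if token_ids[i:i + n] == ngram:
            return i
    return -1


# Values copied from the example output above.
document = {
    'doc_ix': 109291,
    'metadata': '{"path": "val.jsonl", "linenum": 109291, "metadata": {"meta": {"pile_set_name": "Wikipedia (en)"}}}\n',
    'token_ids': [14881, 9608, 29879, 29889, 512, 5613, 4086, 9068, 29892, 670],
}

print(find_ngram(document['token_ids'], [5613, 4086, 9068]))  # 5

# Guard against an empty metadata string before parsing.
meta = json.loads(document['metadata']) if document['metadata'].strip() else {}
print(meta['metadata']['meta']['pile_set_name'])  # Wikipedia (en)
```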

In [136]:
# natural language processing AND deep learning
cnf = [
     [tokenizer.encode('natural language processing')],
     [tokenizer.encode('deep learning')],
]
print(cnf)
# [[[5613, 4086, 9068]], [[6483, 6509]]]

engine.search_docs_cnf(cnf=cnf, maxnum=1, max_disp_len=20)
# Example output from another run (the returned document can differ, as the output below shows):
# {'cnt': 6, 'approx': False, 'idxs': [2], 'documents': [{'doc_ix': 191568, 'doc_len': 3171, 'disp_len': 20, 'metadata': '', 'token_ids': [29889, 450, 1034, 13364, 508, 367, 4340, 1304, 304, 7945, 6483, 6509, 2729, 5613, 4086, 9068, 9595, 1316, 408, 10013]}]}


[[[5613, 4086, 9068]], [[6483, 6509]]]


{'cnt': 6,
 'approx': False,
 'idxs': [4],
 'documents': [{'doc_ix': 68215,
   'doc_len': 1527,
   'disp_len': 20,
   'metadata': '{"path": "val.jsonl", "linenum": 68215, "metadata": {"meta": {"pile_set_name": "OpenWebText2"}}}\n',
   'token_ids': [607,
    7199,
    4637,
    2211,
    1667,
    7117,
    29901,
    4933,
    6509,
    29892,
    5613,
    4086,
    9068,
    29892,
    322,
    3061,
    4564,
    4110,
    297,
    427]}]}

You can also use `max_clause_freq` and `max_diff_tokens` to control the behavior of CNF queries.

To retrieve a single document with `get_doc_by_rank()` and turn its tokens back into text:

In [143]:
tok_ids = engine.get_doc_by_rank(s=0, rank=365362993, max_disp_len=1000)['token_ids']
print("".join(tokenizer.convert_ids_to_tokens(tok_ids)).replace('▁', ' '))
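The join-and-replace step above turns SentencePiece-style token strings back into text. A slightly more complete sketch follows, which also decodes ASCII byte-fallback tokens such as `<0x0A>` (a newline). `tokens_to_text` is a hypothetical helper and the token strings in the example are illustrative. Non-ASCII byte tokens are left untouched, so for general use `tokenizer.decode(tok_ids)` is likely the more robust choice.

```python
import re


def tokens_to_text(tokens):
    """Join SentencePiece-style token strings into readable text."""
    text = "".join(tokens).replace("▁", " ")
    # Decode ASCII byte-fallback tokens such as '<0x0A>'; other byte tokens stay as-is.
    return re.sub(
        r"<0x([0-7][0-9A-Fa-f])>",
        lambda m: chr(int(m.group(1), 16)),
        text,
    )


print(repr(tokens_to_text(['▁natural', '▁language', '▁processing', '<0x0A>'])))
# ' natural language processing\n'
```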

