# Basic usage

This example shows how to use the `genlm-bytes` library for byte-level language modeling.







In [1]:
from genlm.bytes import ByteBeamState, BeamParams
from genlm.backend import load_model_by_name


First, load a token-level language model by its Hugging Face model name. Depending on whether CUDA is available, the model is loaded with either a Hugging Face (CPU) or vLLM (GPU) backend.

In [2]:
llm = load_model_by_name("gpt2-medium")



Initialize a beam state with a maximum beam width of 5.

In [3]:
beam = await ByteBeamState.initial(llm, BeamParams(K=5))


Populate the beam state with the context. The return value is a new beam state.

In [4]:
beam = await beam.prefill(b"An apple a day keeps the ")
beam

Z: -19.598907929485275
Candidates:
(1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣
(0.0000) -31.03: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣
(0.0000) -36.22: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣
(0.0000) -36.49: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣
(0.0000) -40.52: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t|he|␣

* Each candidate in the beam corresponds to a sequence of complete tokens (shown in purple, separated by `|`) followed by a partial token (shown in green, after the final `|`).
* Each candidate has an associated log weight (the negative number in grey before the colon), which is the log probability of the token sequence together with the partial token.
* The `Z` value is our estimate of the log partition function, that is, an estimate of the log prefix probability of the context under the language model.
* Each candidate also has an associated probability (shown in parentheses on the left, in green), which is its weight normalized by the partition function.
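
The relationship between the log weights, `Z`, and the normalized probabilities can be checked by hand. The sketch below is plain Python and makes no `genlm-bytes` calls. It assumes that `Z` is the log-sum-exp of the candidate log weights, which is consistent with the printed values but is not taken from the library's source. The weights are the rounded values from the output above, so the result agrees with the printed `Z` only to about two decimal places.

```python
import math

# Log weights of the five candidates, copied (rounded) from the output above.
log_weights = [-19.60, -31.03, -36.22, -36.49, -40.52]


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Assumption: log Z is the log of the total weight of all candidates.
log_z = logsumexp(log_weights)

# Normalized probability of each candidate: exp(log_weight - log_z).
probs = [math.exp(w - log_z) for w in log_weights]

print(f"log Z is approximately {log_z:.2f}")
for w, p in zip(log_weights, probs):
    print(f"({p:.4f}) {w:.2f}")
```

Because the best candidate is more than 11 nats ahead of the runner-up, it carries essentially all of the probability mass, which is why the remaining candidates display as `0.0000`.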

We can use the `logp_next` method to get the (log) probability distribution over the next byte.

In [5]:
# Get the log probability distribution over the next byte.
logp_next = await beam.logp_next()
logp_next.pretty().top(5)  # Show the top 5 most probable next bytes

| key | value |
| --- | --- |
| b'd' | -0.5768002911707057 |
| b'b' | -2.8733914084455527 |
| b's' | -2.981722712805219 |
| b'w' | -3.375940367664043 |
| b'm' | -3.5282914648667756 |
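
Log probabilities are easier to read once exponentiated. The following sketch uses only the standard library and the five values shown above, copied into a plain dictionary; it does not assume anything about the object returned by `logp_next`.

```python
import math

# Top-5 next-byte log probabilities, copied from the output above.
top_logps = {
    b"d": -0.5768002911707057,
    b"b": -2.8733914084455527,
    b"s": -2.981722712805219,
    b"w": -3.375940367664043,
    b"m": -3.5282914648667756,
}

# Convert each log probability to a probability.
probs = {byte: math.exp(lp) for byte, lp in top_logps.items()}

# Total probability mass covered by these five bytes.
top5_mass = sum(probs.values())

for byte, p in probs.items():
    print(byte, f"{p:.4f}")
print(f"mass covered by the top 5 bytes: {top5_mass:.4f}")
```

The byte `d` alone accounts for a bit more than half of the mass, and the five bytes together cover roughly three quarters of it; the remainder is spread over all other possible next bytes.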


To advance the beam by the next byte, we first prune it to keep only the top 5 candidates, and then use the `<<` operator to feed in the next byte.

In [6]:
new_beam = await (beam.prune() << 100)  # 100 is the byte value of 'd'
new_beam

Z: -20.17567801749765
Candidates:
(1.0000) -20.18: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -31.93: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣d
(0.0000) -38.71: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -39.28: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣d
(0.0000) -40.25: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣|d
(0.0000) -43.16: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣t|he|␣d
(0.0000) -51.34: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣|d
(0.0000) -54.79: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣|d
(0.0000) -56.64: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣|d
(0.0000) -58.94: <

Since extending the beam by one byte can grow the number of candidates, we can again prune it to keep only the top 5 candidates:

In [7]:
pruned_beam = new_beam.prune()
pruned_beam

Z: -20.175678017602173
Candidates:
(1.0000) -20.18: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -31.93: <|endoftext|>|An|␣apple|␣a|␣day|␣keep|s|␣the|␣d
(0.0000) -38.71: <|endoftext|>|An|␣app|le|␣a|␣day|␣keeps|␣the|␣d
(0.0000) -39.28: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣th|e|␣d
(0.0000) -40.25: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣|d

We can further speed up the algorithm with a more aggressive pruning strategy.

In particular, `BeamParams` has a `prune_threshold` parameter, which sets the minimum probability a candidate must have to be kept in the beam. Higher values lead to more aggressive pruning, which can substantially reduce the number of language model calls we need to make.
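
To make the effect of the threshold concrete, here is a minimal sketch of threshold-plus-top-K pruning in plain Python. The helper `prune_sketch` is hypothetical and is not the library's implementation; it assumes that a candidate's probability is its weight normalized over the current beam, and it reuses the rounded log weights from the beam printed earlier.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def prune_sketch(log_weights, k, threshold=0.0):
    """Hypothetical helper: keep at most k candidates whose normalized
    probability is at least `threshold`, best candidates first."""
    log_z = logsumexp(log_weights)
    ranked = sorted(log_weights, reverse=True)
    kept = [w for w in ranked if math.exp(w - log_z) >= threshold]
    return kept[:k]


# Rounded log weights of the five candidates after prefilling the context.
log_weights = [-19.60, -31.03, -36.22, -36.49, -40.52]

print(prune_sketch(log_weights, k=5))                  # no threshold
print(prune_sketch(log_weights, k=5, threshold=0.05))  # with a 5% threshold
```

Without a threshold all five candidates survive; with a threshold of 0.05 only the dominant candidate remains, so fewer candidates need to be extended at the next step.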


In [8]:
beam = await ByteBeamState.initial(llm, BeamParams(K=5, prune_threshold=0.05))
beam = await beam.prefill(b"An apple a day keeps the ")
beam

Z: -19.598918914794922
Candidates:
(1.0000) -19.60: <|endoftext|>|An|␣apple|␣a|␣day|␣keeps|␣the|␣

In [9]:
logp_next = await beam.logp_next()
logp_next.pretty().top(5)

| key | value |
| --- | --- |
| b'd' | -0.5766762743944795 |
| b'b' | -2.8732729803080233 |
| b's' | -2.9816068063730867 |
| b'w' | -3.3758250127787264 |
| b'm' | -3.528177345847574 |


Putting it all together, we can generate a sequence of bytes by repeatedly selecting a next byte from the log probability distribution and advancing the beam by that byte. 

One selection strategy is to always select the byte with the highest log probability, which is what `greedy` does:

In [10]:
beam = await ByteBeamState.initial(
    llm, BeamParams(K=5, prune_threshold=0.05, verbose=True)
)
sampled = await beam.greedy(b"An apple a day keeps the ", steps=12)
sampled


Z: -2.174436330795288
Candidates:
(1.0000) -2.17: <|endoftext|>|A


Z: -4.501037198697343
Candidates:
(0.9977) -4.50: <|endoftext|>|An
(0.0023) -10.56: <|endoftext|>|A|n


Z: -5.643285751342773
Candidates:
(1.0000) -5.64: <|endoftext|>|An|␣


Z: -7.201362133026123
Candidates:
(1.0000) -7.20: <|endoftext|>|An|␣a


Z: -10.39808464050293
Candidates:
(1.0000) -10.40: <|endoftext|>|An|␣ap


Z: -10.627063751220703
Candidates:
(1.0000) -10.63: <|endoftext|>|An|␣app


Z: -12.216539396150903
Candidates:
(0.9934) -12.22: <|endoftext|>|An|␣appl
(0.0066) -17.23: <|endoftext|>|An|␣app

b'An apple a day keeps the doctor away.'

We can also sample from the log probability distribution over the next byte:

In [11]:
beam = await ByteBeamState.initial(
    llm, BeamParams(K=5, prune_threshold=0.05, verbose=True)
)
sampled = await beam.sample(b"An apple a day keeps the ", steps=12)
sampled


Z: -2.174436330795288
Candidates:
(1.0000) -2.17: <|endoftext|>|A


Z: -4.501037198697343
Candidates:
(0.9977) -4.50: <|endoftext|>|An
(0.0023) -10.56: <|endoftext|>|A|n


Z: -5.643285751342773
Candidates:
(1.0000) -5.64: <|endoftext|>|An|␣


Z: -7.201362133026123
Candidates:
(1.0000) -7.20: <|endoftext|>|An|␣a


Z: -10.39808464050293
Candidates:
(1.0000) -10.40: <|endoftext|>|An|␣ap


Z: -10.627063751220703
Candidates:
(1.0000) -10.63: <|endoftext|>|An|␣app


Z: -12.216539396150903
Candidates:
(0.9934) -12.22: <|endoftext|>|An|␣appl
(0.0066) -17.23: <|endoftext|>|An|␣app

b'An apple a day keeps the Dutch away.\n'
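
The two strategies differ only in how the next byte is chosen from the distribution at each step. The toy sketch below illustrates that difference in plain Python. It is not how `greedy` or `sample` are implemented in the library: the functions `greedy_choice` and `sample_choice` are hypothetical, and the distribution is a stand-in containing only the five most probable next bytes shown earlier, whereas the real distribution covers every possible next byte.

```python
import math
import random

# Stand-in distribution: top-5 next-byte log probabilities from earlier.
next_byte_logps = {
    b"d": -0.5766762743944795,
    b"b": -2.8732729803080233,
    b"s": -2.9816068063730867,
    b"w": -3.3758250127787264,
    b"m": -3.528177345847574,
}


def greedy_choice(logps):
    """Pick the byte with the highest log probability."""
    return max(logps, key=logps.get)


def sample_choice(logps, rng):
    """Draw a byte with probability proportional to exp(log probability)."""
    candidates = list(logps)
    weights = [math.exp(logps[b]) for b in candidates]
    # random.choices accepts unnormalized weights.
    return rng.choices(candidates, weights=weights, k=1)[0]


rng = random.Random(0)  # fixed seed so the run is reproducible
print(greedy_choice(next_byte_logps))

draws = [sample_choice(next_byte_logps, rng) for _ in range(1000)]
print({b: draws.count(b) for b in next_byte_logps})
```

Greedy selection always returns `b'd'` here, which matches the greedy continuation "doctor". Sampling usually returns `b'd'` as well but sometimes picks a less likely byte, which is how a sampled continuation can diverge from the greedy one.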