In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Callable, Optional

import torch


def get_device() -> str:
    """Pick the preferred torch backend available on this machine.

    Preference order: CUDA GPU, then the MPS backend, then the CPU.

    Returns:
        One of "cuda", "mps" or "cpu", suitable for passing to ``torch.device``.
    """
    # The conditional chain is evaluated left to right, so the MPS check only
    # runs when CUDA is unavailable.
    return (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )


device = torch.device(get_device())
print(f"Using device: {device}")

Using device: cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pre-trained GPT-2 tokenizer and language model from the same
# "gpt2" checkpoint; the model is moved onto the device selected above.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)


def generate_n_tokens(
    input_ids: torch.Tensor,
    n: int,
    sampling_function: Callable[[torch.Tensor], torch.Tensor],
    language_model: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """Autoregressively append ``n`` tokens to ``input_ids``.

    At each step the model is run on the whole sequence generated so far, the
    logits for the last position are handed to ``sampling_function``, and the
    token id it returns is appended to each row.

    Args:
        input_ids: token ids of shape (batch, seq_len), e.g. the result of
            ``tokenizer.encode(..., return_tensors="pt")``. Not modified.
        n: number of tokens to generate; ``n <= 0`` returns a copy of the input.
        sampling_function: maps last-position logits of shape (batch, vocab) to
            one token id per row, shape (batch,); a (batch, 1) result is also
            accepted.
        language_model: model to decode with. Defaults to the notebook-level
            ``model``. Calling it on ids must return an object whose
            ``.logits`` has shape (batch, seq_len, vocab).

    Returns:
        Tensor of shape (batch, seq_len + max(n, 0)): the prompt followed by the
        generated token ids, with the same dtype and device as ``input_ids``.
    """
    lm = model if language_model is None else language_model
    generated = input_ids.clone()
    # Gradients are never needed during generation; disabling them once for the
    # whole loop (sampling included) avoids building autograd state.
    with torch.no_grad():
        for _ in range(n):
            # NOTE: the full prefix is re-encoded every step (no KV cache), so
            # cost grows quadratically with n -- fine for short demos.
            logits = lm(generated).logits[:, -1, :]
            next_token = sampling_function(logits)
            # Normalise to (batch, 1) with the dtype/device of `generated` so
            # torch.cat neither fails on shape nor silently promotes the dtype.
            next_token = next_token.reshape(generated.size(0), 1).to(
                device=generated.device, dtype=generated.dtype
            )
            generated = torch.cat([generated, next_token], dim=-1)
    return generated

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Sample vocabulary: ten placeholder tokens, "token1" ... "token10".
sample_vocab = [f"token{i}" for i in range(1, 11)]
vocabulary_size = len(sample_vocab)

# Sample logits: a batch of 4 rows over the 10-token vocabulary, each row
# illustrating a different case for the decoding strategies:
#   row 0 - increasing, so "token10" has the highest logit;
#   row 1 - decreasing, so "token1" has the highest logit;
#   row 2 - uniform, so every token is equally likely;
#   row 3 - a single spike at "token5".
sample_logits = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
        [1.0, 1.0, 1.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ]
)


# Function to convert token indices to vocabulary tokens
def indices_to_tokens(indices, vocab=None):
    """Map token indices to their vocabulary strings.

    Args:
        indices: iterable of integer indices -- a list of ints or a 1-D integer
            tensor (e.g. ``greedy_results``).
        vocab: sequence to look the indices up in; defaults to ``sample_vocab``.

    Returns:
        list of str, one entry per index.
    """
    if vocab is None:
        vocab = sample_vocab
    # int() accepts plain ints as well as 0-d tensor elements.
    return [vocab[int(i)] for i in indices]

In [5]:
from stochastic_decoding import greedy_search

# Test greedy search: it should pick the highest logit in every row.
greedy_results = greedy_search(sample_logits)
greedy_tokens = indices_to_tokens(greedy_results)
print("Greedy Search Results:", greedy_tokens)

Greedy Search Results: ['token10', 'token1', 'token1', 'token5']


Greedy search always picks the highest-valued logit in each row (sequence), so you should get:

```python
Greedy Search Results: ['token10', 'token1', 'token1', 'token5']
```

In [6]:
from stochastic_decoding import top_k_sampling, sample_from_logits

# Test top-k sampling: k=1 should match greedy search, k=3 adds some randomness.
# After the loop, `k`, `top_k_logits` and `top_k_results` hold the k=3 run.
for k in (1, 3):
    top_k_logits = top_k_sampling(sample_logits, k)
    top_k_results = sample_from_logits(top_k_logits)
    print(f"Top-{k} Sampling Results:", indices_to_tokens(top_k_results))

Top-1 Sampling Results: ['token10', 'token1', 'token1', 'token5']
Top-3 Sampling Results: ['token8', 'token1', 'token3', 'token5']


With k = 1, top-k sampling reduces to greedy search, so you should get:

```python
Top-1 Sampling Results: ['token10', 'token1', 'token1', 'token5']
```

When k is 3 there is a little more variation, but most likely the first token will be token10, the second token1, the last token5, and the third will be random. Why do you think that is?

In [7]:
from stochastic_decoding import top_p_sampling

# Test top-p sampling with a tiny and a large nucleus.
# After the loop, `p`, `top_p_logits` and `top_p_results` hold the p=0.9 run.
for p in (0.05, 0.9):
    top_p_logits = top_p_sampling(sample_logits, p)
    top_p_results = sample_from_logits(top_p_logits)
    print(f"Top-p Sampling Results (p={p}):", indices_to_tokens(top_p_results))

Top-p Sampling Results (p=0.05): ['token10', 'token1', 'token1', 'token5']
Top-p Sampling Results (p=0.9): ['token10', 'token1', 'token6', 'token5']


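Top-p (nucleus) sampling keeps the smallest set of highest-probability tokens whose cumulative probability reaches p. In the first example p = 0.05, and in every row the single most likely token already carries at least 5% of the probability mass, so only that token survives and we have effectively reduced this to greedy search. The exception is the third row: all of its tokens are equally likely, so which one survives depends on how your implementation breaks ties. I got: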
```python
Top-p Sampling Results (p=0.05): ['token10', 'token1', 'token1', 'token5']
```
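In the second example p = 0.9, so the nucleus is much larger. In the first two rows the top token holds only about 63% of the probability mass, so the nucleus also includes the next two tokens (about 23% and 9%). In the uniform third row it includes almost every token. In the last row, token5 alone already holds over 99% of the mass. Your output will therefore vary: the first token will most often be token10, the second most often token1, the fourth always token5, and the third is essentially random.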

In [8]:
from stochastic_decoding import temperature_sampling

# Test temperature sampling with a very low (near-greedy) and a high temperature.
# After the loop, `temperature`, `temp_logits` and `temp_results` hold the T=5 run.
for temperature in (0.1, 5):
    temp_logits = temperature_sampling(sample_logits, temperature)
    temp_results = sample_from_logits(temp_logits)
    print(
        f"Temperature Sampling Results (T={temperature}):", indices_to_tokens(temp_results)
    )

Temperature Sampling Results (T=0.1): ['token10', 'token1', 'token8', 'token5']
Temperature Sampling Results (T=5): ['token5', 'token1', 'token8', 'token4']


Since a temperature below 1 sharpens the distribution (the highest-probability token becomes even more likely and the rest less likely), a very small temperature degenerates into greedy search. You should therefore get the same first, second, and fourth tokens as with greedy search. All logits for the third token are equal, so temperature has no effect on them and a random token is sampled for it.

```python
Temperature Sampling Results (T=0.1): ['token10', 'token1', 'token5', 'token5']
```

Note that a temperature greater than 1 flattens the distribution: probability mass shifts from the most likely tokens toward the less likely ones, which makes the output more random (this is sometimes referred to as the model's "creativity").

In [9]:
# Generate n tokens using different sampling strategies
n_tokens = 40

# Prepare input
text = "Once upon a time, there was a"
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

# One sampler per strategy. The settings (small k, tiny p, T > 1) are chosen to
# expose each strategy's typical failure mode, as discussed below.
# Dict insertion order fixes the generation order, and with it the sequence of
# random draws.
sampling_strategies = {
    "Greedy": greedy_search,
    "Top-k": lambda x: sample_from_logits(top_k_sampling(x, k=5)),
    "Top-p": lambda x: sample_from_logits(top_p_sampling(x, p=0.05)),
    "Temperature": lambda x: sample_from_logits(
        temperature_sampling(x, temperature=1.5)
    ),
}
strategy_outputs = {
    name: generate_n_tokens(input_ids, n_tokens, sampler)
    for name, sampler in sampling_strategies.items()
}
# Also bind each result to its own name for use in later cells.
greedy_output, top_k_output, top_p_output, temp_output = strategy_outputs.values()

# Decode outputs
for name, output_ids in strategy_outputs.items():
    print(f"{name}:", tokenizer.decode(output_ids[0], clean_up_tokenization_spaces=True))

Greedy: Once upon a time, there was a man who was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power
Top-k: Once upon a time, there was a time where I was a little bit of a recluse. I would go into an apartment and sit at a desk and read, and then, as I read, I would sit at my desk and
Top-p: Once upon a time, there was a great deal of talk about the future of the game.

"I think we're going to have a great year," he said. "We're going to have a great year. We're
Temperature: Once upon a time, there was a face-saving player ploy executed politically against Latin American pop factories terrified to expand their American impact Indonesia fuelled havocess ones pockets, increased DH Anne Universal Washington translates resin to to undergo well designed ads laced with


The issue with greedy search is that it tends to get stuck in a loop; for instance, I got:

> Greedy: Once upon a time, there was a man who was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power. He was a man of great wealth and power

If your top k is too restrictive (too low), you end up with very little variety (notice that we set k to 5). The result is a lot of repeated ideas, and sometimes the text gets stuck in a loop:

> Top-k: Once upon a time, there was a certain amount of excitement. It was like the moment you're going to get a new car, you're going to have an opportunity to see the car. And you're going to be able to see

If your top p is too low, you get the same problem as with top-k above:

> Top-p: Once upon a time, there was a man who was a member of the Church of England, and who had been a member of the Church of England for a long time. He was a man of great faith, and of great integrity.

Since a high temperature flattens the distribution, the model tends to say things that make less sense together (unlikely tokens are more likely to be sampled). For example, I got the following:

> Temperature: Once upon a time, there was a dark delicious pit held pumpkin still in Judaism, giving decorations in a royal participation one service hero path. Meanwhile unleashed shrines of even examination demons and vexes turned diabetes addicts restless vulnerable instead of officially beautiful


In [10]:
# Temperature is often combined with top-p or top-k: the filter first removes
# all unlikely next tokens, then the temperature makes some of the remaining,
# somewhat likely tokens more likely to be sampled.
# Try playing around with the temperature, p and k and see how good an output you can get!

# Generate n tokens using different sampling strategies
n_tokens = 40

# Prepare input
text = "Once upon a time, there was a"
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

p = 0.8
k = 20
temperature = 1.5


def temp_top_k(x):
    """Sample the next token(s) from logits ``x``: top-k filter, then temperature.

    ``k`` and ``temperature`` are looked up in the notebook globals at call
    time, so changing them later changes what this sampler does.
    """
    kept = top_k_sampling(x, k=k)
    rescaled = temperature_sampling(kept, temperature=temperature)
    return sample_from_logits(rescaled)


def temp_top_p(x):
    """Sample the next token(s) from logits ``x``: top-p filter, then temperature.

    ``p`` and ``temperature`` are looked up in the notebook globals at call
    time, so changing them later changes what this sampler does.
    """
    kept = top_p_sampling(x, p=p)
    rescaled = temperature_sampling(kept, temperature=temperature)
    return sample_from_logits(rescaled)


# Top-p is generated before top-k; this fixes the order of the random draws.
temp_top_p_output = generate_n_tokens(input_ids, n_tokens, temp_top_p)
temp_top_k_output = generate_n_tokens(input_ids, n_tokens, temp_top_k)

# Decode outputs
for label, output_ids in (
    ("Temperature and Top-k:", temp_top_k_output),
    ("Temperature and Top-p:", temp_top_p_output),
):
    print(label, tokenizer.decode(output_ids[0], clean_up_tokenization_spaces=True))

Temperature and Top-k: Once upon a time, there was a woman whom she thought was not a man: the beautiful lady, the lovely one. The other of these men was not one but a woman. I was very proud to know who this lady was and
Temperature and Top-p: Once upon a time, there was a remnant of loyalty of Indian workers who would not follow Congress. For two reasons, they fear retaliation and bitterness.

The election has caused a lot of pressure on some, including high-ranking bureaucrats
