# Implementation of Nucleus Sampling

In [2]:
import torch
import random
import numpy as np
import torch.nn.functional as F

# Seed NumPy, PyTorch and Python's built-in RNG with the same fixed value so
# that the random draws in this notebook start from a known state.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

In [19]:
# Dimensions of the toy problem.
N = 3  # how many rows (batch items) are decoded at once
V = 10  # how many tokens the vocabulary holds
# One lowercase letter per vocabulary entry, starting at "a" (code point 97).
vocab = list(map(chr, range(97, 97 + V)))
print("vocab", vocab)

p = 0.9  # cumulative-probability threshold that delimits the nucleus
m = p  # redistribute nucleus by (set to the same value as p)

# Random scores standing in for the logits of one decoding step,
# turned into a probability distribution over the vocabulary per row.
logits = torch.randn(N, V)
probs = logits.softmax(dim=-1)
print("probs", probs)

vocab ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
probs tensor([[0.0648, 0.1163, 0.1452, 0.0747, 0.0461, 0.0364, 0.4072, 0.0279, 0.0247,
         0.0569],
        [0.0151, 0.0926, 0.0621, 0.1482, 0.0767, 0.1408, 0.0872, 0.0694, 0.2323,
         0.0756],
        [0.2202, 0.0785, 0.0172, 0.3600, 0.0431, 0.1062, 0.0430, 0.0454, 0.0774,
         0.0089]])


In [24]:
# Order every row from most to least likely token; the second result records
# which vocabulary index ended up at each sorted position.
sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)

In [28]:
# Show the vocabulary index found at each rank of the descending sort (one row per batch item).
sorted_indices

tensor([[6, 2, 1, 3, 0, 9, 4, 5, 7, 8],
        [8, 3, 5, 1, 6, 4, 9, 7, 2, 0],
        [3, 0, 5, 1, 8, 7, 4, 6, 2, 9]])

In [26]:
# Running total of probability mass along each row, from the most likely
# token to the least likely one.
cumulative_sum = sorted_probs.cumsum(dim=-1)
cumulative_sum

tensor([[0.4072, 0.5523, 0.6686, 0.7433, 0.8081, 0.8650, 0.9111, 0.9474, 0.9753,
         1.0000],
        [0.2323, 0.3805, 0.5213, 0.6139, 0.7011, 0.7778, 0.8534, 0.9228, 0.9849,
         1.0000],
        [0.3600, 0.5802, 0.6864, 0.7650, 0.8424, 0.8878, 0.9309, 0.9739, 0.9911,
         1.0000]])

In [40]:
# Mark the sorted positions whose accumulated mass is already above the threshold p.
out_of_nucleus = torch.gt(cumulative_sum, p)
out_of_nucleus

tensor([[False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True]])

In [41]:
# The token at which the running mass first exceeds p belongs to the nucleus,
# yet the `cumulative_sum > p` comparison flags it as excluded. Moving every
# flag one position to the right and clearing the first column lets that
# token back in (and guarantees the top token is always kept).
out_of_nucleus = torch.roll(out_of_nucleus, shifts=1, dims=-1)
out_of_nucleus[:, 0] = False
out_of_nucleus

tensor([[False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True]])

In [42]:
# Re-scaling according to eqn 3 of the paper starts by removing, in place, the
# probability of every token outside the nucleus. The tensor is printed
# before and after so the effect is visible.
print(sorted_probs)
sorted_probs.masked_fill_(out_of_nucleus, 0)
print(sorted_probs)

tensor([[0.4072, 0.1452, 0.1163, 0.0747, 0.0648, 0.0569, 0.0461, 0.0364, 0.0279,
         0.0247],
        [0.2323, 0.1482, 0.1408, 0.0926, 0.0872, 0.0767, 0.0756, 0.0694, 0.0621,
         0.0151],
        [0.3600, 0.2202, 0.1062, 0.0785, 0.0774, 0.0454, 0.0431, 0.0430, 0.0172,
         0.0089]])
tensor([[0.4072, 0.1452, 0.1163, 0.0747, 0.0648, 0.0569, 0.0461, 0.0000, 0.0000,
         0.0000],
        [0.2323, 0.1482, 0.1408, 0.0926, 0.0872, 0.0767, 0.0756, 0.0694, 0.0000,
         0.0000],
        [0.3600, 0.2202, 0.1062, 0.0785, 0.0774, 0.0454, 0.0431, 0.0000, 0.0000,
         0.0000]])


In [44]:
# Divide each row by its own total so the row sums to one again.
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
print(sorted_probs)

tensor([[0.4469, 0.1593, 0.1276, 0.0820, 0.0712, 0.0624, 0.0506, 0.0000, 0.0000,
         0.0000],
        [0.2518, 0.1606, 0.1525, 0.1003, 0.0945, 0.0831, 0.0820, 0.0752, 0.0000,
         0.0000],
        [0.3867, 0.2366, 0.1141, 0.0844, 0.0832, 0.0488, 0.0463, 0.0000, 0.0000,
         0.0000]])


In [52]:
# Draw one position per row, using the row's values as sampling weights.
# The result indexes into the sorted order, not into the vocabulary.
sorted_selected_indices = sorted_probs.multinomial(num_samples=1)
sorted_selected_indices

tensor([[3],
        [0],
        [1]])

In [54]:
# Look up the probability stored at each sampled sorted position.
token_probs = sorted_probs.gather(dim=-1, index=sorted_selected_indices)
token_probs

tensor([[0.0820],
        [0.2518],
        [0.2366]])

In [55]:
# Translate each sampled sorted position back into its vocabulary index.
token_indices = sorted_indices.gather(dim=-1, index=sorted_selected_indices)
token_indices

tensor([[3],
        [8],
        [0]])

In [80]:
# Collapse both column tensors into plain Python lists: sampled token ids
# first, then the probability of each sampled token.
(token_indices.reshape(-1).tolist(), token_probs.reshape(-1).tolist())

([3, 8, 0], [0.08198275417089462, 0.2517637014389038, 0.23660142719745636])

In [99]:
def nucleus_sample(logits: torch.Tensor, p: float, generator: "torch.Generator | None" = None):
    """Draw one token per batch row with nucleus (top-p) sampling.

    For every row, tokens are ordered by descending probability and the
    nucleus is the shortest prefix whose cumulative probability exceeds
    ``p`` (the whole vocabulary if it never does). Tokens outside the
    nucleus get probability zero, the remaining ones are renormalized to
    sum to one, and a single token is sampled from that distribution.

    :param logits: unnormalized scores of shape ``(batch, vocab_size)``
    :param p: cumulative-probability threshold that delimits the nucleus
    :param generator: optional ``torch.Generator`` used for the draw;
        ``None`` (the default) falls back to PyTorch's global RNG
    :return: ``(token_indices, token_probs)`` -- two lists with one entry
        per batch row: the sampled vocabulary index and that token's
        probability under the renormalized nucleus distribution
    :raises AssertionError: if ``logits`` is not 2-dimensional
    """
    assert logits.dim() == 2, "expected a matrix (batch, vocab_size)"

    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_sum = torch.cumsum(sorted_probs, dim=-1)

    # `cumulative_sum > p` first becomes True at the token that pushes the
    # running mass past p, but that token is still part of the nucleus.
    # Shifting the mask one position to the right keeps it; column 0 stays
    # False so the most likely token can never be dropped.
    exceeds_p = cumulative_sum > p
    out_of_nucleus = torch.zeros_like(exceeds_p)
    out_of_nucleus[:, 1:] = exceeds_p[:, :-1]

    # Eq. 3 from the nucleus paper: zero everything outside the nucleus and
    # rescale what is left so each row sums to one.
    nucleus_probs = sorted_probs.masked_fill(out_of_nucleus, 0.0)
    nucleus_probs = nucleus_probs / nucleus_probs.sum(dim=-1, keepdim=True)

    # Sampling happens in sorted order, so map the drawn positions back to
    # vocabulary indices afterwards.
    sorted_selected_indices = torch.multinomial(nucleus_probs, 1, generator=generator)
    token_probs = torch.gather(nucleus_probs, dim=-1, index=sorted_selected_indices)
    token_indices = torch.gather(sorted_indices, dim=-1, index=sorted_selected_indices)

    return token_indices.flatten().tolist(), token_probs.flatten().tolist()

# LabML Implementation

In [100]:
class NucleusSampler:
    """
    ## Nucleus Sampler

    Callable that samples one token per row of logits, restricted to the
    nucleus: the most likely tokens up to and including the first one whose
    cumulative probability reaches $p$.
    """

    def __init__(self, p: float, generator: "torch.Generator | None" = None):
        """
        :param p: is the sum of probabilities of tokens to pick $p$
        :param generator: is the optional `torch.Generator` used when sampling;
            `None` falls back to PyTorch's global RNG
        """
        self.p = p
        self.generator = generator
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = torch.nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from logits with Nucleus Sampling

        :param logits: scores with the vocabulary on the last dimension
        :return: the sampled vocabulary indexes, with the last dimension
            squeezed away
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)

        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p
        # Prepend ones so that we add one token after the minimum number
        # of tokens with cumulative probability less than $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

        # Zero the probability of every token outside the nucleus. No explicit
        # renormalization is needed because `torch.multinomial` treats each
        # row as unnormalized weights.
        sorted_probs[~nucleus] = 0.

        # Sample one position per row in the sorted order
        sampled_sorted_indexes = torch.multinomial(sorted_probs, 1, generator=self.generator)

        # Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes)

        # Drop the trailing sample dimension
        return res.squeeze(-1)

In [109]:
# Dedicated generator with a fixed seed, separate from PyTorch's global RNG.
# `manual_seed` hands back the generator itself, so the call can be chained.
g_cpu = torch.Generator().manual_seed(0)

In [140]:
# `nn` must be importable by the time NucleusSampler is instantiated below.
from torch import nn
# Build the class-based sampler with the same threshold and the seeded generator.
nucleus_sample_labml = NucleusSampler(p, generator=g_cpu)
# Both calls draw from the same `g_cpu`, so the second call runs on whatever
# generator state the first call left behind.
print(nucleus_sample_labml(logits))
print(nucleus_sample(logits, p, generator=g_cpu))

tensor([3, 3, 4])
([1, 3, 0], [0.12760135531425476, 0.1605621576309204, 0.23660142719745636])
