In [None]:
# Notebook check: display the id the tokenizer assigns to the '[SOS]' token.
tokenizer.token_to_id('[SOS]')


import torch
import torch.nn.functional as F
# Prefer the GPU, but fall back to the CPU so the cells still run without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Define the top-k sampling function
def top_k_sampling(logits: torch.Tensor, k: int, temperature: float = 1.0) -> int:
    """
    Perform top-k sampling to select the next token.

    Args:
        logits (torch.Tensor): The logits (unnormalized predictions) for a single
            position, of shape (vocab_size,). Leading singleton dimensions such as
            (1, vocab_size) are accepted and flattened away.
        k (int): The number of top tokens to consider for sampling. Values larger
            than the vocabulary size are clamped to the vocabulary size.
        temperature (float): The temperature value to scale logits. Higher
            temperature results in more randomness. Must be positive.

    Returns:
        int: The vocabulary index of the selected token.

    Raises:
        ValueError: If ``temperature`` is not positive, ``k`` is smaller than 1,
            or ``logits`` is empty or holds more than one distribution.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if logits.dim() == 0 or logits.numel() == 0 or logits.numel() != logits.shape[-1]:
        raise ValueError(
            f"logits must hold exactly one distribution, got shape {tuple(logits.shape)}"
        )

    # Drop singleton batch/sequence axes so the index lookup below is always 1-D;
    # indexing a (1, k) tensor with the sampled position would hit the batch axis.
    flat_logits = logits.reshape(-1)

    # Scale logits by temperature
    scaled_logits = flat_logits / temperature

    # Keep only the k best candidates (never more than the vocabulary holds)
    top_k_logits, top_k_indices = torch.topk(scaled_logits, min(k, scaled_logits.numel()))

    # Re-normalize the surviving logits into a probability distribution
    probabilities = F.softmax(top_k_logits, dim=-1)

    # Sample a position inside the top-k list ...
    sampled_position = torch.multinomial(probabilities, num_samples=1)

    # ... and map it back to the original vocabulary index
    return int(top_k_indices[sampled_position].item())

# Generate a reply to a sample prompt with top-k sampling.
# Relies on `tokenizer`, `transformer` and `decoder_mask` defined in earlier cells.

# Sample input text
input_text = "Often times I find myself thinking scary thoughts and sometimes I even scare myself into thinking that something bad is going to happen to me. Once it starts, the thought continues going through my head and I can't get it out. How can I stop these thoughts?"

# Assuming tokenizer is already defined and initialized
token_ids = tokenizer.encode(input_text).ids

# Define special tokens and padding length
MAX_SRC_LEN = 200  # Assuming a maximum source length of 200 for this example
MAX_TGT_LEN = 100  # Assuming a maximum target length of 100 for the generated sequence

sos_idx = tokenizer.token_to_id('[SOS]')
eos_idx = tokenizer.token_to_id('[EOS]')
pad_idx = tokenizer.token_to_id('[PAD]')

# Number of padding tokens needed after [SOS] + tokens + [EOS]
enc_num_pad = MAX_SRC_LEN - len(token_ids) - 2

# Prepare the encoder input: [SOS] tokens [EOS] [PAD]...
encoder_input = torch.cat(
    [
        torch.tensor([sos_idx], dtype=torch.int64),
        torch.tensor(token_ids, dtype=torch.int64),
        torch.tensor([eos_idx], dtype=torch.int64),
        torch.tensor([pad_idx] * enc_num_pad, dtype=torch.int64),
    ],
    dim=0,
).to(device)

# Mask out the padding positions
encoder_mask = (encoder_input != pad_idx).unsqueeze(0).unsqueeze(0).int().to(device)  # (1, 1, MAX_SRC_LEN)

# Assuming transformer is already defined and initialized
transformer.eval()
with torch.no_grad():
    # Encode the input sequence once and reuse it for every decoding step
    encoder_output = transformer.encode(encoder_input, encoder_mask)

    # Initialize the target sequence with the start token, shape (1, 1)
    tgt = torch.full(
        (1, 1), sos_idx, dtype=encoder_input.dtype, device=encoder_input.device
    )

    for _ in range(1, MAX_TGT_LEN):
        # Create the decoder mask for the current length of the target sequence
        deco_mask = decoder_mask(tgt.size(1)).type_as(encoder_mask).to(device)

        # Decode the target sequence
        out = transformer.decode(encoder_output, encoder_mask, tgt, deco_mask)

        # Scores for the last position of the single batch entry, as a 1-D tensor:
        # top_k_sampling maps the sampled position back through a 1-D index list,
        # so a (1, vocab_size) slice would be indexed along the batch axis instead.
        logits = out[0, -1, :]

        # Apply top-k sampling to select the next token
        next_token = top_k_sampling(logits, k=5, temperature=1.0)  # Adjust k and temperature as needed

        # Append the next token to the target sequence
        tgt = torch.cat(
            [tgt, torch.full((1, 1), next_token, dtype=tgt.dtype, device=tgt.device)],
            dim=1,
        )

        # Stop once the end-of-sequence token has been generated
        if next_token == eos_idx:
            break

# The generated sequence of token IDs
print("Generated sequence of token IDs:", tgt)

# Decode the generated token IDs to text
generated_text = tokenizer.decode(tgt[0].tolist(), skip_special_tokens=True)
print("Generated text:", generated_text)


In [None]:
# Generate a reply to the same prompt with greedy (argmax) decoding.
# Relies on `tokenizer`, `transformer`, `decoder_mask`, `device`, MAX_SRC_LEN and
# MAX_TGT_LEN defined in earlier cells.
input_text = "Often times I find myself thinking scary thoughts and sometimes I even scare myself into thinking that something bad is going to happen to me. Once it starts, the thought continues going through my head and I can't get it out. How can I stop these thoughts?"

token_ids = tokenizer.encode(input_text).ids

# Number of padding tokens needed after [SOS] + tokens + [EOS]
enc_num_pad = MAX_SRC_LEN - len(token_ids) - 2

sos_idx = tokenizer.token_to_id('[SOS]')
eos_idx = tokenizer.token_to_id('[EOS]')
pad_idx = tokenizer.token_to_id('[PAD]')

# Encoder input: [SOS] tokens [EOS] [PAD]...
encoder_input = torch.cat(
    [
        torch.tensor([sos_idx], dtype=torch.int64),
        torch.tensor(token_ids, dtype=torch.int64),
        torch.tensor([eos_idx], dtype=torch.int64),
        torch.tensor([pad_idx] * enc_num_pad, dtype=torch.int64),
    ],
    dim=0,
).to(device)

# Mask out the padding positions
encoder_mask = (encoder_input != pad_idx).unsqueeze(0).unsqueeze(0).int().to(device)  # (1, 1, 200)

transformer.eval()
with torch.no_grad():
    # Encode once and reuse the result for every decoding step
    encoder_output = transformer.encode(encoder_input, encoder_mask)

    # Target sequence starts with the start token, shape (1, 1)
    tgt = torch.full(
        (1, 1), sos_idx, dtype=encoder_input.dtype, device=encoder_input.device
    )

    for _ in range(1, MAX_TGT_LEN):
        deco_mask = decoder_mask(tgt.size(1)).type_as(encoder_mask).to(device)
        out = transformer.decode(encoder_output, encoder_mask, tgt, deco_mask)

        # Pick the best-scoring vocabulary entry for the LAST position only.
        # A bare torch.argmax(out) would return an index into the flattened
        # tensor, which is not a token id.
        # NOTE(review): assumes `out` is (1, tgt_len, vocab_size) scores, as the
        # sampling cell above treats it -- confirm against transformer.decode.
        out_token = torch.argmax(out[:, -1, :], dim=-1).item()

        # Append with the same dtype/device as the ids already in `tgt`
        tgt = torch.cat(
            [tgt, torch.full((1, 1), out_token, dtype=tgt.dtype, device=tgt.device)],
            dim=1,
        )

        # Stop once the end-of-sequence token has been generated
        if out_token == eos_idx:
            break
        
        

In [None]:
# `tgt` has shape (1, seq_len); decode its single sequence as a flat list of ids.
tokenizer.decode(tgt[0].tolist())

In [None]:
# Notebook check: display the shape of the decoder output from the last decoding step.
out.shape

In [None]:
# Notebook check. NOTE: without a `dim` argument, argmax returns an index into the
# flattened tensor, not a per-position token id.
torch.argmax(out)

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate a target sequence by always taking the highest-scoring next token.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project``.
        source: Encoder input tensor; its dtype is reused for the generated ids.
        source_mask: Mask passed to ``encode`` and ``decode`` alongside ``source``.
        tokenizer_src: Unused; kept so existing callers keep working.
        tokenizer_tgt: Tokenizer that provides the ids of '[SOS]' and '[EOS]'.
        max_len: Maximum length of the returned sequence, '[SOS]' included.
        device: Device on which the decoder input is built.

    Returns:
        A 1-D tensor of token ids that starts with '[SOS]' and ends with '[EOS]'
        when that token is produced before ``max_len`` is reached.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)

    # Start the decoder input with the sos token, shape (1, 1)
    decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)

    # `<` rather than `==`: a max_len below 1 must still stop the loop
    while decoder_input.size(1) < max_len:
        # Build the mask for the target generated so far
        tgt_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # Run the decoder and project only the last position
        out = model.decode(encoder_output, source_mask, decoder_input, tgt_mask)
        prob = model.project(out[:, -1])

        # Greedy choice: highest-scoring entry for the single batch element
        next_id = torch.argmax(prob, dim=1).item()
        decoder_input = torch.cat(
            [decoder_input, torch.full((1, 1), next_id, dtype=source.dtype, device=device)],
            dim=1,
        )

        if next_id == eos_idx:
            break

    return decoder_input.squeeze(0)
