In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F
import random


model_name = 'eryk-mazus/polka-1.1b'
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG so that the token sampling done later in this notebook is
# repeatable when the notebook is rerun from a fresh kernel on the same setup.
SEED = 0
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [2]:
# Sanity check: encode a short string to inspect the raw token ids the tokenizer produces.
tokenizer.encode("ab")

[1, 633]

In [4]:
# Token-string -> token-id mapping, used to look up individual vocabulary entries.
vocab = tokenizer.get_vocab()

In [6]:
# Look up the id of one vocabulary entry that carries the "▁" prefix.
vocab["▁Nau"]

26700

In [57]:
def restricted_sampling(input_ids, predicate):
    """Sample the next token id, restricted to tokens accepted by ``predicate``.

    Runs the module-level ``model`` on ``input_ids``, takes the softmax of the
    logits at the last position, zeroes the probability of every vocabulary
    entry whose token string is missing/empty or rejected by ``predicate``,
    renormalizes, and draws one id from what remains. The sampled token
    string is printed as a progress trace.

    Args:
        input_ids: Sequence of integer token ids forming the context.
        predicate: Callable taking a token string and returning a truthy
            value when that token may be sampled.

    Returns:
        int: The sampled token id.

    Raises:
        RuntimeError: If no token accepted by ``predicate`` has non-zero
            probability, so there is nothing to sample from.
    """
    with torch.no_grad():
        # Build the integer batch directly instead of going through float32.
        context = torch.tensor([input_ids], dtype=torch.long, device=device)
        next_token_logits = model(context).logits[0, -1, :]  # last position only

    probs = F.softmax(next_token_logits, dim=-1)

    # Evaluate the predicate once per vocabulary entry and apply the result as
    # a single boolean mask, rather than writing into the tensor id by id.
    # Ids without a token string are treated as not allowed.
    vocab_tokens = tokenizer.convert_ids_to_tokens(list(range(probs.shape[0])))
    allowed = torch.tensor(
        [bool(token) and bool(predicate(token)) for token in vocab_tokens],
        dtype=torch.bool,
        device=probs.device,
    )
    probs = probs.masked_fill(~allowed, 0.0)

    # Fail with a clear message instead of dividing by zero and then trying to
    # sample from a distribution full of NaNs.
    total = probs.sum()
    if not total > 0:
        raise RuntimeError(
            "restricted_sampling: predicate rejected every token with non-zero probability"
        )
    probs = probs / total

    # Sample from the filtered, renormalized distribution.
    filtered_token_id = torch.multinomial(probs, num_samples=1).item()
    print(tokenizer.convert_ids_to_tokens(filtered_token_id))
    return filtered_token_id

def extend_sentence(sentence_txt, iter=10):
    """Append ``iter`` sampled tokens to the encoding of ``sentence_txt``.

    The first character of ``sentence_txt`` is the target letter. On each
    step the chain decoded so far decides which tokens may follow:

    * if the decoded text ends with a space, only tokens starting with the
      target letter are allowed;
    * otherwise a token is allowed when it starts with "▁" followed by the
      target letter, or when it does not start with "▁" at all.

    Args:
        sentence_txt: Prompt text; its first character is the target letter.
        iter: Number of tokens to sample and append.

    Returns:
        list: Token ids of the encoded prompt followed by the sampled ids.
    """
    letter = sentence_txt[0]
    marked_letter = "▁" + letter

    def starts_with_letter(token):
        return token.startswith(letter)

    def marked_letter_or_unmarked(token):
        return token.startswith(marked_letter) or not token.startswith("▁")

    token_chain = list(tokenizer.encode(sentence_txt))

    for _ in range(iter):
        decoded_so_far = tokenizer.decode(token_chain)
        if decoded_so_far[-1] == " ":
            allowed = starts_with_letter
        else:
            allowed = marked_letter_or_unmarked
        next_id = restricted_sampling(token_chain, allowed)
        token_chain.append(next_id)

    return token_chain

In [58]:
# Extend the prompt by the default number of tokens (iter=10); each sampled token is printed as it is drawn.
sentence = extend_sentence("Ala ma kota")

,▁z
iom
la
,
gu
zach
owego
,
Syn
u


In [59]:
# Decode the whole token chain back to text; special tokens are kept, so the result starts with "<s>".
tokenizer.decode(sentence)

'<s> Ala ma kota, ziomla,guzachowego,Synu'

In [31]:
# Round-trip check: with skip_special_tokens=True the prompt decodes back to the original text.
tokenizer.decode(tokenizer.encode("Ala ma kota"), skip_special_tokens=True)

'Ala ma kota'