<a href="https://colab.research.google.com/github/yonikremer/final_project/blob/master/project_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo - this is a demo for the project

Here you can try out the sampling method I introduced in my research.
I highly recommend looking up [my research project](https://github.com/yonikremer/final_project) for more information.

# Set Up

Please run the following cell once, then restart the kernel and run the other cells in order.

In [1]:
%pip install -q transformers;
%pip install -q torch;

[K     |████████████████████████████████| 4.7 MB 29.3 MB/s 
[K     |████████████████████████████████| 6.6 MB 55.5 MB/s 
[K     |████████████████████████████████| 120 kB 71.7 MB/s 
[?25h

In [1]:
import random
import timeit
from copy import deepcopy
from functools import reduce
from typing import List, Tuple, Iterable, Set, Dict, Optional

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedModel, BatchEncoding

import torch
from torch import cuda, device, LongTensor
from torch.nn import Softmax

In [3]:
# Require a CUDA GPU: warn and stop when none is available, otherwise create the device object.
# The check goes through `torch.cuda` rather than the imported name `cuda`, because `cuda` is
# rebound to a device object below; `cuda.is_available()` would fail if this cell were run again.
if not torch.cuda.is_available():
    print("Warning: CUDA not available")
    print("Running on CPU wil be very slow")
    print("it's highly recommended to use colab's GPU runtime")
    exit()
else:
    cuda = device('cuda')

# Important Notes:


*   Please make sure the model you choose is not too big for your hardware
*   gpt2 is the smallest version of gpt-2, gpt2-xl is the largest
*   bloom is the largest version of bloom



# Selecting a Model

In [4]:
#@title Select a model from the list

# Model identifier; it is passed to the from_pretrained calls that load the tokenizer, the model and the config.
model_name = "bigscience/bloom-560m" #@param ["facebook/opt-125m", "facebook/opt-1.3b", "facebook/opt-350m", "gpt2", "gpt2-medium", "gpt2-large", "bigscience/bloom-1b7", "bigscience/bloom-1b1", "bigscience/bloom-560m"]

In [5]:
# Load the tokenizer that matches the selected model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal language model and move it to the GPU.
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
# Vocabulary size as declared in the model's configuration.
vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

Downloading tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/710 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

In [6]:
# Seed Python's `random` module; the distribution sampling functions draw tokens with random.choices.
random.seed(0)

In [7]:
def print_bases(obj):
    """Print the class of ``obj`` and then each class reached by following ``__base__``.

    The walk stops as soon as ``object`` is reached, so ``object`` itself is never printed.
    """
    klass = type(obj)
    while True:
        if klass == object:
            break
        print(klass)
        klass = klass.__base__

# Print the class of the loaded model followed by its chain of base classes.
print_bases(model)

<class 'transformers.models.bloom.modeling_bloom.BloomForCausalLM'>
<class 'transformers.models.bloom.modeling_bloom.BloomPreTrainedModel'>
<class 'transformers.modeling_utils.PreTrainedModel'>
<class 'torch.nn.modules.module.Module'>


# The code I use to sample:

## Get probability matrix

In [12]:
def get_prob_mat(prompt: Optional[str], group_size: int, token_list: Optional[List[int]] = None, temp: float = 1.0):
    """Returns the probability matrix as a list of lists of floats

    Runs the global ``model`` once on the prompt, divides the logits by ``temp``,
    applies a softmax over the last dimension and returns the rows that belong
    to the last ``group_size`` positions of the sequence.

    Args:
        prompt: text that is encoded with the global ``tokenizer``.
            Only used when ``token_list`` is None.
        group_size: number of trailing positions whose distributions are returned.
        token_list: token ids to feed to the model directly; when given,
            ``prompt`` is ignored.
        temp: temperature the logits are divided by; must be positive.

    Returns:
        A list with ``min(group_size, sequence length)`` rows; every row is the
        list of probabilities the softmax produced for one position.

    Raises:
        NotImplementedError: if the global ``model`` is not a ``PreTrainedModel``.
        ValueError: if ``temp`` is not positive, or if both ``prompt`` and
            ``token_list`` are None.
    """
    if not isinstance(model, PreTrainedModel):
        raise NotImplementedError("Only AutoModelForCausalLM is supported")
    if temp <= 0:
        # Dividing the logits by zero (or a negative number) does not give a usable distribution.
        raise ValueError("temp must be a positive number")

    if token_list is None:
        if prompt is None:
            raise ValueError("Either prompt or token_list must be given")
        inputs = tokenizer(prompt, return_tensors="pt")
    else:
        attention_mask = torch.ones([1, len(token_list)])
        inputs = {"input_ids": LongTensor([token_list]), "attention_mask": attention_mask}
    inputs = {name: tensor.cuda() for name, tensor in inputs.items()}

    # Only the logits are needed: no labels (no loss computation) and no gradient tracking.
    with torch.no_grad():
        logits_pt_tensor = model(**inputs).logits.squeeze(0) / temp

    prob_tensor = Softmax(dim=1)(logits_pt_tensor)
    sequence_length = prob_tensor.shape[0]
    if group_size > sequence_length:
        print("Warning: the group size is bigger than the length of the model's output")
        print("If the length of the model input (in tokens) is n, n will be length of the model's output")
        print(f"the predicted text will be {group_size - sequence_length} tokens shorter")
    num_rows = min(group_size, sequence_length)
    # An explicit start index keeps group_size == 0 returning no rows ([-0:] would return all of them).
    prob_mat = prob_tensor[sequence_length - num_rows:, :].tolist()

    return prob_mat

## distribution sampling  

In [13]:
def grouped_top_p_sampling(prob_mat: List[List[float]], top_p: float, org_used_tokens: List[int]):
    """Sample one token per row of ``prob_mat`` with nucleus (top p) sampling, without repeating tokens.

    For every row the tokens are ranked by probability; tokens are added to the
    candidate set, best first, until adding the next one would push the summed
    probability above ``top_p`` (the best token is always kept). One token is then
    drawn from the candidates with weights proportional to their probabilities.
    A sampled token is excluded from all following rows.

    Args:
        prob_mat: one list of probabilities per position; index ``t`` of a row is the
            probability of token id ``t``. The argument is not modified.
        top_p: upper bound on the summed probability of the candidate set.
        org_used_tokens: token ids that must not be sampled. The argument is not modified.

    Returns:
        The list of sampled token ids, one per row of ``prob_mat``.

    Raises:
        ValueError: if a row has no token with positive probability left to sample.
    """
    used_tokens = set(org_used_tokens)
    answer = []
    for curr_token_prob_list in prob_mat:
        # Work on a copy so the caller's matrix keeps its original probabilities.
        row = list(curr_token_prob_list)
        for used_token in used_tokens:
            row[used_token] = 0.0

        # Stable sort: tokens with equal probability keep ascending token-id order.
        ranked_tokens = sorted(range(len(row)), key=row.__getitem__, reverse=True)
        candidate_tokens: List[int] = []
        candidate_probs: List[float] = []
        prob_sum: float = 0.0
        for rank, curr_token in enumerate(ranked_tokens):
            curr_prob = row[curr_token]
            # rank > 0 guarantees that the most probable token is always a candidate.
            if rank > 0 and prob_sum + curr_prob > top_p:
                break
            prob_sum += curr_prob
            candidate_tokens.append(curr_token)
            candidate_probs.append(curr_prob)

        if prob_sum <= 0.0:
            raise ValueError("No token with a positive probability is left to sample")

        weights = [prob / prob_sum for prob in candidate_probs]
        sampled_token: int = random.choices(candidate_tokens, weights=weights, k=1)[0]
        answer.append(sampled_token)
        used_tokens.add(sampled_token)
    return answer


def grouped_top_k_sampling(prob_mat: List[List[float]], top_k: int, org_used_tokens: List[int]):
    """Sample one token per row of ``prob_mat`` with top k sampling, without repeating tokens.

    For every row the ``top_k`` most probable tokens that were not used yet are the
    candidates; one of them is drawn with weights proportional to their probabilities.
    A sampled token is excluded from all following rows.

    Args:
        prob_mat: one list of probabilities per position; index ``t`` of a row is the
            probability of token id ``t``. The argument is not modified.
        top_k: number of candidates per row; must be at least 1.
        org_used_tokens: token ids that must not be sampled. The argument is not modified.

    Returns:
        The list of sampled token ids, one per row of ``prob_mat``.

    Raises:
        ValueError: if ``top_k`` is smaller than 1, or if a row has no token with
            positive probability left to sample.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    used_tokens = set(org_used_tokens)
    answer = []
    for curr_token_prob_list in prob_mat:
        # Work on a copy so the caller's matrix keeps its original probabilities.
        row = list(curr_token_prob_list)
        for used_token in used_tokens:
            row[used_token] = 0.0

        # Stable sort: tokens with equal probability keep ascending token-id order.
        top_k_tokens = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:top_k]
        top_k_probs = [row[token] for token in top_k_tokens]
        prob_sum: float = sum(top_k_probs)
        if prob_sum <= 0.0:
            raise ValueError("No token with a positive probability is left to sample")

        weights = [prob / prob_sum for prob in top_k_probs]
        sampled_token: int = random.choices(top_k_tokens, weights=weights, k=1)[0]
        answer.append(sampled_token)
        used_tokens.add(sampled_token)
    return answer


def generate_distribution(prompt, top_p, top_k, num_groups, group_size, temp: float = 1.0):
    """Complete ``prompt`` by sampling ``num_groups`` groups of tokens from the model's distribution.

    Each iteration calls ``get_prob_mat`` once on the tokens generated so far and
    samples a group of new tokens from the returned rows, using top k sampling
    when ``top_p`` is None and top p sampling when ``top_k`` is None. Tokens that
    already appear in the text (prompt included) are never sampled again.

    Args:
        prompt: the text to complete.
        top_p: nucleus threshold, or None to use top k sampling.
        top_k: number of candidates per position, or None to use top p sampling.
        num_groups: number of model calls / sampled groups.
        group_size: number of positions requested from ``get_prob_mat`` per group.
        temp: temperature forwarded to ``get_prob_mat``.

    Returns:
        The decoded text: the prompt followed by the generated tokens.

    Raises:
        ValueError: unless exactly one of ``top_p`` and ``top_k`` is None.
    """
    # Fail before tokenizing or calling the model when the sampling mode is ambiguous.
    if (top_p is None) == (top_k is None):
        raise ValueError("Either top_p ot top_k should be None, but not both")

    tokenized_prompt_ten = tokenizer(prompt, return_tensors="pt")
    if isinstance(tokenized_prompt_ten, (dict, BatchEncoding)):
        if "input_ids" in tokenized_prompt_ten.keys():
            tokenized_prompt_ten = tokenized_prompt_ten["input_ids"]

    # squeeze(0) drops only the batch dimension, so a one-token prompt still yields a list
    # (a bare squeeze() would turn it into a single int).
    curr_token_list = tokenized_prompt_ten.squeeze(0).tolist()
    for _ in range(num_groups):
        prob_mat = get_prob_mat(None, group_size, curr_token_list, temp)
        if top_p is None:
            new_tokens = grouped_top_k_sampling(prob_mat, top_k, curr_token_list)
        else:
            new_tokens = grouped_top_p_sampling(prob_mat, top_p, curr_token_list)
        curr_token_list.extend(new_tokens)
    final_ans = tokenizer.decode(curr_token_list)
    return final_ans

In [20]:
# Example run: 4 groups of 3 tokens each.
# NOTE(review): the positional arguments after the prompt are (top_p, top_k), so this call uses
# top_p=4 and top_k=None. A top_p above 1 presumably never cuts the candidate list, i.e. tokens are
# drawn from the whole distribution - possibly top_k=4 was intended; confirm.
generate_distribution("Translate english to french: \n english: I am the best swimmer in the world \n french:", 4, None, 4, 3)

'Translate english to french: \n english: I am the best swimmer in the world \n french: fr:m Jeereonsieur suisw\n\n leormA:oser'

## Brute force sampling

In [None]:
def combinations(mat):
    """Returns all the lists such that list[j] is in mat[j] and no value appears twice in the list.

    The sequences are built one position at a time; a partial sequence is extended
    only with values it does not contain yet, so sequences with repetitions are
    dropped as early as possible. The results are ordered with the first position
    changing slowest. An empty ``mat`` gives ``[[]]`` (one empty sequence).

    complexity: at most prod([len(mat[i]) for i in range(len(mat))]) sequences are built
    """
    partial_sequences = [[]]
    for options in mat:
        partial_sequences = [
            prefix + [option]
            for prefix in partial_sequences
            for option in options
            if option not in prefix
        ]
    return partial_sequences


def doesnt_have_duplicates(my_list):
    """Tell whether every item of ``my_list`` appears in it exactly once.

    complexity: O(n) where n is the length of the list"""
    distinct_items = set(my_list)
    return len(distinct_items) == len(my_list)


def seq_prob(tokens, prob_mat, org_prompt_prob):
    """Probability of a token sequence given a probability matrix.

    Starts from ``org_prompt_prob`` and multiplies it, position by position, by
    ``prob_mat[position][token]`` - the probability of that token at that position."""
    running_prob = org_prompt_prob
    for position, token_id in enumerate(tokens):
        running_prob *= prob_mat[position][token_id]
    return running_prob


def grouped_sampling(prob_mat: List[List[float]], top_p, top_k) -> List[List[int]]:
    """Given a matrix of probabilities, returns every candidate group of tokens.

    The matrix has one row per position in the group and one column per token:
    prob_mat[i][j] is the probability of token j at the i-th position of the group.
    For each position the tokens are visited from the most to the least probable and
    collected as candidates until ``top_k`` candidates were collected or the next
    token would push the summed probability of the candidates above ``top_p``.
    The first available token of a position is collected regardless of ``top_p``,
    so every position gets at least one candidate (as long as ``top_k`` >= 1 and an
    unused token exists). A token that is a candidate for one position is never a
    candidate for a later position.

    Returns a list of token lists: every way of picking one candidate per position,
    as produced by ``combinations``.
    Sorting dominates the cost of selecting the candidates:
    O(group_size * vocab_size * log(vocab_size)); the number of returned
    sequences is the product of the candidate counts of all positions."""
    posible_tokens = []
    already_predicted = set()
    for token_prob in prob_mat:  # group_size times
        # Stable sort: tokens with equal probability keep ascending token-id order.
        ranked_tokens = sorted(range(len(token_prob)), key=token_prob.__getitem__, reverse=True)
        total_prob = 0
        curr_indices = []
        for token in ranked_tokens:
            prob = token_prob[token]
            if len(curr_indices) == top_k:
                break
            # The top_p limit only applies once the position has a candidate, which keeps
            # settings such as top_p = 0 (greedy) from producing no candidates at all.
            if curr_indices and total_prob + prob > top_p:
                break
            if token not in already_predicted:
                already_predicted.add(token)
                total_prob += prob
                curr_indices.append(token)
        posible_tokens.append(curr_indices)
    new_sequences: List[List[int]] = combinations(posible_tokens)
    return new_sequences


def flatten(l: list) -> List:
    """Return a flat list with the items of ``l`` in order, expanding nested lists.

    Only ``list`` instances are expanded; any other item is copied as is.
    example: [1, [2, 3], 4, [[5]]] -> [1, 2, 3, 4, 5]
    Complexity: O(len(the flatten list) + the number of diffrent lists))"""
    flat_items = []
    # Stack of iterators: the top one walks the list that is currently being expanded.
    pending = [iter(l)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, list):
                # Descend into the nested list; the parent iterator keeps its position.
                pending.append(iter(element))
                break
            flat_items.append(element)
        else:
            # The current list is exhausted - resume its parent.
            pending.pop()
    return flat_items


def generate_beam(org_prompt, top_p, top_k, num_groups, group_size, org_prompt_prob = 1.0, temp = 1.0) -> Dict[Tuple[int], float]:
    """Complete the prompt with every candidate group of tokens, recursively, ``num_groups`` times.

    org_prompt: str, or a (possibly nested) list / tuple of token ids
    top_p: float (0.0 - 1.0)
    top_k: int > 0
    num_groups: int > 0
    group_size: int > 0
    org_prompt_prob: float <= 1, the probability of the prompt itself
    temp: temperature forwarded to get_prob_mat

    Returns a dictionary that maps each completed token sequence (prompt included,
    as a tuple of token ids) to its probability. Sequences in which a token id
    appears more than once are discarded by remove_duplicates.
    Raises ValueError if org_prompt is neither a string nor a list / tuple.
    """
    if isinstance(org_prompt, str):
        encoded_prompt = tokenizer(org_prompt, return_tensors="pt")
        # input_ids carries a batch dimension; flatten removes it.
        tokenized_prompt_list = flatten(encoded_prompt["input_ids"].tolist())
    elif isinstance(org_prompt, (list, tuple)):
        tokenized_prompt_list = flatten(org_prompt)
    else:
        raise ValueError("org_prompt must be a string or list of integers")

    # Pass the token ids themselves, so the probabilities belong to exactly this sequence
    # (decoding and re-encoding a token list is not guaranteed to give the same ids back).
    prob_mat = get_prob_mat(None, group_size, tokenized_prompt_list, temp)
    tokenized_ans_list = grouped_sampling(prob_mat, top_p, top_k)
    prob_list: List[float] = [seq_prob(seq, prob_mat, org_prompt_prob) for seq in tokenized_ans_list]
    new_prompts: List[List[int]] = [tokenized_prompt_list + ans for ans in tokenized_ans_list]
    complition_prob_dict = remove_duplicates(new_prompts, prob_list)
    if num_groups <= 1:
        return complition_prob_dict

    new_complitions: Dict[Tuple[int], float] = dict()
    for curr_new_prompt, curr_new_prompt_prob in complition_prob_dict.items():
        # Every surviving completion becomes the prompt of the next group.
        curr_completions: Dict[Tuple[int], float] = generate_beam(
            curr_new_prompt, top_p, top_k, num_groups - 1, group_size, curr_new_prompt_prob, temp
        )
        new_complitions.update(curr_completions)
    return new_complitions


def remove_duplicates(completions: List[List[int]], probs: List[float]) -> Dict[Tuple[int], float]:
    """Keep only the completions in which no token appears twice, keyed by their token tuple.

    Each kept completion is mapped to the probability paired with it in ``probs``;
    when the same completion appears more than once, the last probability wins.
    A completion whose tokens cannot be checked is printed before the TypeError is re-raised."""
    unique_completions: Dict[Tuple[int], float] = {}
    for token_ids, probability in zip(completions, probs):
        try:
            all_distinct = len(token_ids) == len(set(token_ids))
        except TypeError as error:
            print(token_ids)
            raise error
        if not all_distinct:
            continue
        unique_completions[tuple(token_ids)] = probability
    return unique_completions


# Use grouped sampling:

## How to use previous methods:

| brute force         | nucleus (top p)                 | top k                              | greedy         |
|---------------------|---------------------------------|------------------------------------|----------------|
| top p=0             | 0 < top p < 1                   | top p = 1                          | top p = 0      |
| top k >= vocab_size | top k >= the model's vocab size | 1 < top k < the model's vocab size | top k = 1      |
| group_size = 1      | group_size = 1                  | group_size = 1                     | group_size = 1 |

In [1]:
# Generation parameters (Colab form fields).
prompt = "def logger(original_function: callable, file_name: str): \\n \"\"\"A decerator that logs the argument and key word arguments of  original_function to file_name.log\"\"\"\\n" #@param {type:"string"}
top_p = 1 #@param {type:"slider", min:0.05, max:1.0, step:0.05}
top_k = 5 #@param {type:"integer"}
group_size = 8 #@param {type:"integer"}
num_groups = 2 #@param {type:"integer"}
# The logits are divided by the temperature, so it must be positive (the form's minimum);
# 1.0 leaves the model's distribution unchanged.
temperature = 1.0 #@param {type:"number", min:0.000000001}
sampling_type = 'from distribution' #@param ["from distribution", "brute force"]


# Upper bound on the number of candidate tokens per position.
actual_top_k = min(top_k, vocab_size * top_p)
num_tokens_genrated = num_groups * group_size

if sampling_type == "brute force":
    # Brute force branches on every candidate: one call per node of the search tree.
    max_num_calls = sum((actual_top_k ** i) for i in range(num_groups))    
    pervious_methods_max_num_calls = sum((actual_top_k ** i) for i in range(num_tokens_genrated))
else:
    # Sampling from the distribution needs one call per group (per token for previous methods).
    max_num_calls = num_groups
    pervious_methods_max_num_calls = num_tokens_genrated

print(f"The model will generate {num_tokens_genrated} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times to the model")
print(f"Previous methods will need up to {pervious_methods_max_num_calls} call to the model to genrate the same text")

NameError: ignored

In [None]:
# Generate the text twice with the selected sampling type - once in groups (grouped sampling)
# and once token by token (group size 1, the way previous methods work) - and time both runs.
if sampling_type == "brute force":  # must match the option text of the sampling_type form field
    start_grouped_generation = timeit.default_timer()
    grouped_completions = generate_beam(prompt, top_p, top_k, num_groups, group_size, temp=temperature)
    stop_grouped_generation = timeit.default_timer()

    start_old_generation = timeit.default_timer()
    old_completions = generate_beam(prompt, top_p, top_k, num_tokens_genrated, 1, temp=temperature)
    stop_old_generation = timeit.default_timer()

    if len(grouped_completions.keys()) > 0 and len(old_completions.keys()) > 0:
        # The best completion is the one with the highest probability.
        grouped_best_token_list = max(grouped_completions, key=grouped_completions.get)
        old_best_token_list = max(old_completions, key=old_completions.get)
    else:
        raise ValueError("""
            The model could not generate a complition with the constrains of the sampling method, U should try:
            1) Increase top p and/or top k
            2) Use diffrent (larger) model
            3) Decrease the number of groups
            4) Decrease the group size
            """)
    # generate_beam returns token ids, which still have to be decoded.
    grouped_best_str = tokenizer.decode(grouped_best_token_list)
    old_best_str = tokenizer.decode(old_best_token_list)
else:
    # Exactly one of top_p / top_k may be passed to generate_distribution:
    # use top p when it actually restricts the distribution, otherwise top k.
    use_top_p = 0 < top_p < 1
    distribution_top_p = top_p if use_top_p else None
    distribution_top_k = None if use_top_p else top_k

    # generate_distribution already returns decoded text.
    start_grouped_generation = timeit.default_timer()
    grouped_best_str = generate_distribution(prompt, distribution_top_p, distribution_top_k, num_groups, group_size, temperature)
    stop_grouped_generation = timeit.default_timer()

    start_old_generation = timeit.default_timer()
    old_best_str = generate_distribution(prompt, distribution_top_p, distribution_top_k, num_tokens_genrated, 1, temperature)
    stop_old_generation = timeit.default_timer()


print("The best text generated by grouped sampling is:")
print(grouped_best_str)
print(f"Grouped sampling took {stop_grouped_generation - start_grouped_generation:.2f} seconds")
print("The best text generated by previous methods sampling is:")
print(old_best_str)
print(f"Previous methods took {stop_old_generation - start_old_generation:.2f} seconds")

If the length of the model input (in tokens) is n, n will be length of the model's output
the predicted text will be 1 tokens shorter


KeyboardInterrupt: ignored

In [None]:
# List every completion found by the brute force search, ordered by ascending probability.
# The check must match the option text of the sampling_type form field ("brute force"), and
# the completions live in grouped_completions, which the brute force generation cell fills.
if sampling_type == "brute force":
    sorted_complitions = sorted(grouped_completions, key = grouped_completions.get)
    print("Here are all the complitions created by predicted probability")
    for tokens in sorted_complitions:
        print(tokenizer.decode(tokens))

# Change the hyperparameters to see what will happen!

In [9]:
import numpy as np
def create_look_ahead_mask(seq_len: int):
    """Build a square look-ahead mask of shape (seq_len, seq_len).

    Entry [i, j] is 1.0 when j > i (a position after i) and 0.0 otherwise,
    i.e. the strictly upper triangular part of an all-ones matrix.
    """
    # k=1 keeps only the entries above the main diagonal.
    return np.triu(np.ones(shape=[seq_len, seq_len]), k=1)
# Display the mask for a sequence of 4 positions.
create_look_ahead_mask(4)

array([[0., 1., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 0.]])