<a href="https://colab.research.google.com/github/yonikremer/final_project/blob/master/project_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo - this is a demo for the project

Here you can try out the sampling method I introduced in my research.
I highly recommend looking up [my research project](https://github.com/yonikremer/final_project) for more information.

# Set Up

In [1]:
%pip install -q transformers;
%pip install -q torch;

[K     |████████████████████████████████| 4.7 MB 22.8 MB/s 
[K     |████████████████████████████████| 6.6 MB 57.4 MB/s 
[K     |████████████████████████████████| 120 kB 72.9 MB/s 
[?25h

In [1]:
from typing import List, Tuple, Dict, Optional, Union
import random
from copy import deepcopy
import timeit
from abc import ABC, abstractmethod
from collections.abc import Callable
from math import ceil

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BatchEncoding

import torch
from torch import cuda, device, LongTensor, no_grad
from torch.nn import Softmax

In [3]:
# The generators defined below move the model and its inputs to the GPU,
# so stop here when no CUDA device is available.
# The check goes through `torch.cuda` rather than the bare name `cuda` because this cell
# rebinds `cuda` to a device object on its last line; querying `cuda.is_available()`
# would raise AttributeError as soon as the cell is run a second time in the same kernel.
if not torch.cuda.is_available():
    print("Warning: CUDA not available")
    print("Running on CPU wil be very slow")
    print("it's highly recommended to use colab's GPU runtime")
    # Raising halts "Run all" at this cell; exit() would additionally shut the kernel down.
    raise RuntimeError("CUDA is not available; switch to a GPU runtime and re-run the notebook")

# NOTE(review): from here on `cuda` is a torch.device, not the torch.cuda module.
cuda = device('cuda')

# Important Notes:


*   Please make sure the model you choose is not too big for your hardware
*   gpt2 is the smallest version of gpt-2, gpt2-xl is the largest
*   bloom is the largest version of bloom



# The code I use to sample:

## Object Oriented way

In [4]:
class TextGenerator(Callable, ABC):
    """Generates text given a model, a prompt and some sampling parameters.

    The base class loads the tokenizer and the model and converts a token sequence
    into the probabilities of the next ``group_size`` tokens (``get_prob_mat``).
    Subclasses decide how to pick tokens from those probabilities in ``__call__``.
    """

    def __init__(self, model_name: str, group_size: int, temp: float = 1.0):
        """Loads the tokenizer and the model and stores the sampling parameters.

        Args:
            model_name: name handed to the ``from_pretrained`` loaders.
            group_size: number of tokens predicted from a single call to the model.
            temp: temperature; the logits are divided by it before the softmax.

        Raises:
            ValueError: if ``group_size`` is smaller than 1, or if the tokenizer
                has neither a padding token nor an end-of-sequence token.
        """
        if group_size < 1:
            raise ValueError(f"group_size must be at least 1, got {group_size}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        self.vocab_size = AutoConfig.from_pretrained(model_name).vocab_size
        self.temp = temp
        self.group_size = group_size
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Some tokenizers (GPT-2's, for example) define no padding token;
            # without a fallback the id tensor below could not be built from None.
            pad_id = self.tokenizer.eos_token_id
        if pad_id is None:
            raise ValueError(f"The tokenizer of {model_name} has no padding or end-of-sequence token")
        # group_size - 1 placeholders are appended to the input so that the model
        # produces an output position for every token of the group.
        self.padding_tokens = [pad_id] * (self.group_size - 1)


    def get_prob_mat(self, prompt: Optional[str], token_list: Optional[List[int]] = None):
        """Returns the probability matrix as a list of lists of floats

        Args:
            prompt: text to tokenize; only used when ``token_list`` is None.
            token_list: token ids of the text so far; takes precedence over ``prompt``.

        Returns:
            One row per predicted position (normally ``group_size`` rows, taken from
            the last positions of the model output); every row holds one
            probability per vocabulary entry.

        Raises:
            ValueError: if both ``prompt`` and ``token_list`` are None.
        """
        if token_list is None:
            if prompt is None:
                raise ValueError("Either prompt or token_list must be given")
            tokenized_prompt = self.tokenizer(prompt, return_tensors="pt")
            if isinstance(tokenized_prompt, list):
                token_list = tokenized_prompt
            else:
                # The ids are returned as a tensor with a batch dimension;
                # flatten them into a plain list so the padding can be appended.
                token_list = torch.as_tensor(tokenized_prompt["input_ids"]).flatten().tolist()
        else:
            token_list = list(token_list)

        longer_token_list = token_list + self.padding_tokens
        # Follow the model's device instead of hard-coding .cuda().
        model_device = self.model.device
        input_ids = LongTensor([longer_token_list]).to(model_device)
        # The placeholder tokens are attended to like real tokens (mask of ones).
        attention_mask = torch.ones((1, len(longer_token_list)), device=model_device)

        # Only the logits are needed, so no labels are passed (they would only
        # make the model compute a loss that is never read).
        with no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        logits_pt_tensor = logits.squeeze(0) / self.temp

        prob_tensor = Softmax(dim=1)(logits_pt_tensor)
        num_rows = prob_tensor.shape[0]
        if self.group_size > num_rows:
            print("Warning: the group size is bigger than the length of the model's output")
            print("If the length of the model input (in tokens) is n, n will be length of the model's output")
            print(f"the predicted text will be {self.group_size - num_rows} tokens shorter")
        # Keep the last group_size rows, or every row when fewer are available.
        rows_to_keep = min(self.group_size, num_rows)
        return prob_tensor[num_rows - rows_to_keep:, :].tolist()


    @abstractmethod
    def __call__(self, prompt: str, num_new_tokens: int) -> str:
        """Returns the prompt followed by ``num_new_tokens`` generated tokens, decoded to text."""
        pass

In [5]:
class SampleGen(TextGenerator):
    """Generates text by sampling every group of tokens from a truncated distribution.

    Exactly one of ``top_k`` / ``top_p`` has to be given (the other one None).
    A token that already appears in the text is given probability 0 before
    sampling, so it is only reachable when it stays among the candidates with weight 0.
    """

    def __init__(self, model_name: str, group_size: int, top_k: Optional[int], top_p: Optional[float], temp: float = 1.0):
        """Stores the truncation parameters and derives the sampling method from them.

        Args:
            model_name: name handed to the ``from_pretrained`` loaders.
            group_size: number of tokens predicted from a single call to the model.
            top_k: keep the k most probable tokens, or None to use top-p.
            top_p: keep the most probable tokens whose probabilities add up to
                at most p (always at least one token), or None to use top-k.
            temp: temperature; the logits are divided by it before the softmax.
        """
        super().__init__(model_name, group_size, temp)
        random.seed(0)  # seeds the module-level generator used by random.choices below
        self.top_k = top_k
        self.top_p = top_p
        if top_p is None and top_k is not None:
            self.sampling_method = "top k"
        elif top_k is None and top_p is not None:
            self.sampling_method = "top p"
        else:
            # Always define the attribute so __call__ can report the invalid
            # combination with a ValueError instead of failing with AttributeError.
            self.sampling_method = None


    def _sample_group(self, prob_mat: List[List[float]], org_used_tokens: List[int], select_candidates) -> List[int]:
        """Samples one token per row of ``prob_mat``.

        ``select_candidates`` receives the (token, probability) pairs of a row sorted
        by descending probability and returns the pairs that may be sampled.
        Neither ``prob_mat`` nor ``org_used_tokens`` is modified.
        """
        used_tokens = set(org_used_tokens)
        answer = []
        for curr_token_prob_list in prob_mat:
            # Used tokens take part in the ranking with probability 0.
            ranked = sorted(
                ((token, 0.0 if token in used_tokens else prob)
                 for token, prob in enumerate(curr_token_prob_list)),
                key=lambda item: item[1],
                reverse=True,
            )
            candidates = select_candidates(ranked)
            prob_sum = sum(prob for _, prob in candidates)
            tokens = [token for token, _ in candidates]
            weights = [prob / prob_sum for _, prob in candidates]

            sampled_token: int = random.choices(tokens, weights=weights, k=1)[0]
            answer.append(sampled_token)
            # A token sampled for one position is not available for later positions.
            used_tokens.add(sampled_token)
        return answer


    def grouped_top_p_sampling(self, prob_mat: List[List[float]], org_used_tokens: List[int]):
        """Samples one token per row among the most probable tokens whose total probability stays within ``top_p``.

        The most probable token of a row is always a candidate, even if its
        probability alone exceeds ``top_p``.
        """
        def nucleus(ranked):
            kept = []
            prob_sum = 0.0
            for rank, (token, prob) in enumerate(ranked):
                if rank > 0 and prob_sum + prob > self.top_p:
                    break
                prob_sum += prob
                kept.append((token, prob))
            return kept

        return self._sample_group(prob_mat, org_used_tokens, nucleus)


    def grouped_top_k_sampling(self, prob_mat: List[List[float]], org_used_tokens: List[int]):
        """Samples one token per row among the ``top_k`` most probable tokens of that row."""
        return self._sample_group(prob_mat, org_used_tokens, lambda ranked: ranked[:self.top_k])


    def __call__(self, prompt: str, num_new_tokens: int) -> str:
        """Returns ``prompt`` extended by ``num_new_tokens`` sampled tokens, decoded to text.

        Raises:
            ValueError: if not exactly one of top_k / top_p was given.
        """
        if self.sampling_method == "top k":
            sample_group = self.grouped_top_k_sampling
        elif self.sampling_method == "top p":
            sample_group = self.grouped_top_p_sampling
        else:
            raise ValueError("Either top_p ot top_k should be None, but not both")

        num_groups = ceil(num_new_tokens / self.group_size)
        tokenized_prompt_ten = self.tokenizer(prompt, return_tensors="pt")
        if isinstance(tokenized_prompt_ten, (dict, BatchEncoding)) and "input_ids" in tokenized_prompt_ten.keys():
            tokenized_prompt_ten = tokenized_prompt_ten["input_ids"]

        # reshape(-1) always yields a list; squeeze() would turn a one-token
        # prompt into a 0-d tensor whose tolist() is a bare int.
        prompt_token_list = tokenized_prompt_ten.reshape(-1).tolist()
        curr_token_list = list(prompt_token_list)
        for _ in range(num_groups):
            prob_mat = self.get_prob_mat(None, curr_token_list)
            curr_token_list.extend(sample_group(prob_mat, curr_token_list))
        # The last group may overshoot the requested length; cut the surplus.
        final_num_tokens = len(prompt_token_list) + num_new_tokens
        return self.tokenizer.decode(curr_token_list[:final_num_tokens])

In [6]:
class TreeGen(SampleGen):
    """Generates text by expanding, group by group, every allowed continuation and keeping the most probable one.

    For each position of a group at most ``top_k`` tokens whose probabilities add up
    to at most ``top_p`` are considered (always at least one). As in SampleGen's
    sampling methods, a token that already appears in the text is not generated again.
    """

    def __init__(self, model_name: str, group_size: int, top_k: int, top_p: float, temp: float = 1.0):
        """Stores both truncation parameters; unlike SampleGen, both are used together."""
        # SampleGen.__init__ expects (model_name, group_size, top_k, top_p, temp);
        # passing only three arguments would raise TypeError.
        super().__init__(model_name, group_size, top_k, top_p, temp)
        self.top_k = top_k
        self.top_p = top_p


    @staticmethod
    def doesnt_have_duplicates(my_list):
        """Return if there isn't a repetition in the list
        complexity: O(n) where n is the length of the list"""
        return len(my_list) == len(set(my_list))


    @staticmethod
    def combinations(mat):
        """Returns all the lists such that list[j] is in mat[j] and no item appears twice in a list.

        The lists are ordered with the first position varying slowest.
        An empty ``mat`` yields a single empty list.
        """
        sequences = [[]]
        # Extend every partial sequence by one position at a time, so the work done
        # for the earlier positions is never repeated.
        for options in mat:
            sequences = [seq + [item] for seq in sequences for item in options if item not in seq]
        return sequences


    @staticmethod
    def seq_prob(tokens, prob_mat, org_prompt_prob) -> float:
        """Given the probability matrix and a list of tokens
        returns the probability of the sequence
        prob_mat[a][b] is the probability of the token with id b the a-th token in the sequence"""
        probability = org_prompt_prob
        for position, curr_token in enumerate(tokens):
            probability *= prob_mat[position][curr_token]
        return probability


    @staticmethod
    def flatten(l: Union[list, tuple]) -> List:
        """Gets a list where some elements might be lists (or tuples)
        and adds every item in the inner list to the outer list.
        example: [1, [2, 3], 4, [[5]]] -> [1, 2, 3, 4, 5]"""
        new_list = []
        for item in l:
            if isinstance(item, (list, tuple)):
                new_list.extend(TreeGen.flatten(item))
            else:
                new_list.append(item)
        return new_list


    @staticmethod
    def remove_duplicates(completions: List[List[int]], probs: List[float]) -> Dict[Tuple[int, ...], float]:
        """Given a list of tokenized answers and the probability of each completion,
        removes every repeated completion and every completion that have repeated tokens"""
        filtered_completions: Dict[Tuple[int, ...], float] = dict()
        for curr_comp, curr_prob in zip(completions, probs):
            try:
                has_unique_tokens = len(curr_comp) == len(set(curr_comp))
            except TypeError as e:
                # set() fails on unhashable items, e.g. a completion that was not flattened.
                raise TypeError(f"completion contains unhashable items: {curr_comp!r}") from e
            if has_unique_tokens:
                filtered_completions[tuple(curr_comp)] = curr_prob
        return filtered_completions


    def tree_grouped_sampling(self, prob_mat: List[List[float]], used_tokens: Optional[List[int]] = None) -> List[List[int]]:
        """Given a matrix of probabilities, returns every candidate group of tokens.

        prob_mat[i][j] is the probability of token j being the i-th token in the group.
        For each position the tokens are visited from the most to the least probable;
        at most top_k tokens whose probabilities add up to at most top_p are kept,
        and at least one token is kept whenever an unused token exists.
        A token kept for one position is not offered for a later position.

        Args:
            prob_mat: one row of probabilities per position in the group.
            used_tokens: tokens that must not be offered at all (for example the
                tokens of the text so far); None means no such tokens.

        Returns:
            Every combination of one kept token per position.
        """
        already_predicted = set(used_tokens) if used_tokens is not None else set()
        possible_tokens = []
        for token_prob in prob_mat:
            # Stable sort: tokens with equal probability keep their id order.
            ranked = sorted(enumerate(token_prob), key=lambda item: item[1], reverse=True)
            total_prob = 0.0
            curr_indices = []
            for token, prob in ranked:
                budget_spent = total_prob + prob > self.top_p or len(curr_indices) >= self.top_k
                # The limits only apply once one token was kept; otherwise a single
                # token more probable than top_p would leave the position empty.
                if curr_indices and budget_spent:
                    break
                if token not in already_predicted:
                    already_predicted.add(token)
                    total_prob += prob
                    curr_indices.append(token)
            possible_tokens.append(curr_indices)
        return TreeGen.combinations(possible_tokens)


    def rec_gen(self, org_prompt, num_tokens: int, org_prompt_prob: float = 1.0) -> Dict[Tuple[int, ...], float]:
        """Recursively generates the next group of tokens in a tree like behavior

        Args:
            org_prompt: token ids of the text so far (list or tuple, possibly nested).
            num_tokens: number of tokens that still have to be generated.
            org_prompt_prob: probability accumulated for ``org_prompt``.

        Returns:
            Maps every completed sequence (prompt tokens followed by the generated
            tokens) to its probability.

        Raises:
            ValueError: if ``org_prompt`` is neither a list nor a tuple.
        """
        if not isinstance(org_prompt, (list, tuple)):
            raise ValueError(f"org_prompt must be a list or tuple of integers, got {org_prompt!r}")
        tokenized_prompt_list = TreeGen.flatten(org_prompt)
        if num_tokens <= 0:
            # Nothing left to generate; without this guard the recursion below never ends.
            return {tuple(tokenized_prompt_list): org_prompt_prob}

        prob_mat = self.get_prob_mat(None, tokenized_prompt_list)
        # Tokens of the text so far are excluded up front. Filtering whole sequences for
        # repeated tokens afterwards would discard every completion of a prompt that
        # itself repeats a token.
        tokenized_ans_list = self.tree_grouped_sampling(prob_mat, tokenized_prompt_list)

        completion_prob_dict: Dict[Tuple[int, ...], float] = dict()
        for ans in tokenized_ans_list:
            # The last group may be longer than what is still needed; drop the surplus
            # from the answer only, never from the prompt.
            kept_ans = ans[:num_tokens]
            completion = tuple(tokenized_prompt_list + kept_ans)
            if completion not in completion_prob_dict:
                completion_prob_dict[completion] = TreeGen.seq_prob(kept_ans, prob_mat, org_prompt_prob)

        if num_tokens <= self.group_size:
            return completion_prob_dict

        new_completions: Dict[Tuple[int, ...], float] = dict()
        for curr_new_prompt, curr_new_prompt_prob in completion_prob_dict.items():
            new_completions.update(
                self.rec_gen(curr_new_prompt, num_tokens - self.group_size, curr_new_prompt_prob)
            )
        return new_completions


    def __call__(self, prompt: str, num_new_tokens: int) -> str:
        """Returns ``prompt`` extended by the most probable explored completion, decoded to text.

        Raises:
            ValueError: if no completion could be generated.
        """
        tokenized_prompt_ten = self.tokenizer(prompt, return_tensors="pt")
        if isinstance(tokenized_prompt_ten, (dict, BatchEncoding)) and "input_ids" in tokenized_prompt_ten.keys():
            tokenized_prompt_ten = tokenized_prompt_ten["input_ids"]

        tokenized_prompt: List[int] = tokenized_prompt_ten.reshape(-1).tolist()
        final_num_tokens = len(tokenized_prompt) + num_new_tokens

        seq_prob_dict: Dict[Tuple[int, ...], float] = self.rec_gen(tokenized_prompt, num_new_tokens)
        if not seq_prob_dict:
            raise ValueError("No completion could be generated for the given prompt and parameters")
        highest_prob_seq: Tuple[int, ...] = max(seq_prob_dict, key=seq_prob_dict.get)
        return self.tokenizer.decode(list(highest_prob_seq[:final_num_tokens]))

In [7]:
def compare_generators(grouped_generator: TextGenerator, non_grouped_generator: TextGenerator, prompt: str, num_tokens: int):
    """Runs the same prompt through both generators and prints each result with its wall-clock time.

    The non-grouped generator runs first, the grouped generator second.
    """
    def timed_generation(generator):
        # Returns the generated text and the seconds the call took.
        began = timeit.default_timer()
        generated: str = generator(prompt, num_tokens)
        return generated, timeit.default_timer() - began

    print("Your prompt:")
    print(prompt)

    non_grouped_ans, non_grouped_time = timed_generation(non_grouped_generator)
    print(f"Text generated by Non grouped sampling in {non_grouped_time} seconds:")
    print(non_grouped_ans)

    grouped_ans, grouped_time = timed_generation(grouped_generator)
    print(f"Text generated by grouped sampling in {grouped_time} seconds:")
    print(grouped_ans)

# Use grouped sampling:

In [9]:
model_name = "bigscience/bloom-1b1" #@param ["facebook/opt-125m", "facebook/opt-1.3b", "facebook/opt-350m", "gpt2", "gpt2-medium", "gpt2-large", "bigscience/bloom-1b1", "bigscience/bloom-560m"]
prompt = "I was so sad to graduate high school that I almost " #@param {type:"string"}
num_tokens = 4 #@param {type:"integer", min:1}
top_p = 0.4 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 10000000000000000000000 #@param {type:"integer"}
group_size = 2 #@param {type:"integer"}
temperature = 1 #@param {type:"number", min:0.000000001}
sampling_type = 'from distribution' #@param ["from distribution", "as tree"]


vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

num_groups = ceil(num_tokens / group_size)
# Upper bound on the tokens considered per position: the m most probable tokens carry
# at least m / vocab_size of the probability mass, so top_p admits at most
# vocab_size * top_p of them. Kept as an integer of at least 1 so the call counts
# printed below are whole numbers.
actual_top_k = max(1, min(top_k, int(vocab_size * top_p)))

if sampling_type == "as tree":
    # One model call per node of the tree: actual_top_k ** i nodes at depth i.
    max_num_calls = sum((actual_top_k ** i) for i in range(num_groups))
    previous_methods_max_num_calls = sum((actual_top_k ** i) for i in range(num_tokens))
    grouped_generator = TreeGen(model_name, group_size, top_k, top_p, temperature)
    non_grouped_generator = TreeGen(model_name, 1, top_k, top_p, temperature)
else:
    max_num_calls = num_groups
    previous_methods_max_num_calls = num_tokens
    # SampleGen expects exactly one of top_k / top_p; the other one is passed as None.
    if top_p == 1.0:
        # top_p does not restrict anything, so only top_k is used
        top_p = None
    elif top_p == 0.0:
        # only the single most probable token can be chosen
        top_k = 1
        top_p = None
    elif top_k >= vocab_size:
        # top_k does not restrict anything, so only top_p is used
        top_k = None
    else:
        raise ValueError("When using sampling from distribution, you must use either top k or top p, but not both")
    grouped_generator = SampleGen(model_name, group_size, top_k, top_p, temperature)
    non_grouped_generator = SampleGen(model_name, 1, top_k, top_p, temperature)

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times to the model")
print(f"Previous methods will need up to {previous_methods_max_num_calls} call to the model to generate the same text")

compare_generators(grouped_generator, non_grouped_generator, prompt, num_tokens)

The model will generate 4 tokens (words or parts of words)
It will call the model at most 2 times to the model
Previous methods will need up to 4 call to the model to generate the same text
Your prompt:
I was so sad to graduate high school that I almost 
Text generated by Non grouped sampling in 2.255551921999995 seconds:
I was so sad to graduate high school that I almost  couldn’t stand it
Text generated by grouped sampling in 1.18751023599998 seconds:
I was so sad to graduate high school that I almost  missed did itt


# Change the hyper-parameters to see what will happen!

In [22]:
# Reuse the tokenizer held by the non-grouped generator; the cell below uses it to tokenize the prompt.
global_tokenizer = non_grouped_generator.tokenizer


([170675], 'PAD')

In [35]:
# Tokenize the prompt and display the size of the second dimension of its input ids.
tokenized_prompt_ten = global_tokenizer(prompt, return_tensors="pt")
if isinstance(tokenized_prompt_ten, (dict, BatchEncoding)) and "input_ids" in tokenized_prompt_ten.keys():
    # Keep only the ids tensor of the encoding.
    tokenized_prompt_ten = tokenized_prompt_ten["input_ids"]
tokenized_prompt_ten.shape[1]

12