<a href="https://colab.research.google.com/github/yonikremer/final_project/blob/master/bloom_demo_working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo - this is a demo for the project - 

Here yu can try out the sampling method I intreduced in my research.
I highly reccomend to lookup [my research project](https://github.com/yonikremer/final_project) for more information.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yonikremer/final_project/blob/master/bloom_demo_working.ipynb)

# Set Up

Please run the following cell once, than restart the kernal and run the other cells by order

In [1]:
%pip install -q transformers;
%pip install -q torch;

[K     |████████████████████████████████| 4.7 MB 5.0 MB/s 
[K     |████████████████████████████████| 101 kB 12.9 MB/s 
[K     |████████████████████████████████| 6.6 MB 42.8 MB/s 
[?25h

In [1]:
from typing import List, Tuple, Iterable, Set, Dict
from functools import reduce

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import Softmax

# Important Notes:


*   Please make sure the model you choose is not too big for your hardware
*   gpt2 is the smallest version of gpt-2, gpt2-xl is the largest
*   bloom is the largest vesion of bloom



# Selecting a Model

In [None]:
#@title Select a model from the list
["facebook/opt-125m", "gpt2", "gpt2-medium", "gpt2-xl", "gpt2-large", "bloom", "bloom-7b1", "bloom-6b3", "bloom-3b", "bloom-1b7", "bloom-1b1", "bloom-560m"]

model_name = "bloom-560m" #@param ["facebook/opt-125m", "facebook/opt-1.3b", "facebook/opt-13b", "facebook/opt-125m", "facebook/opt-2.7b", "facebook/opt-30b", "facebook/opt-350m", "facebook/opt-6.7b", "facebook/opt-66b", "gpt2", "gpt2-medium", "gpt2-xl", "gpt2-large", "bloom", "bloom-7b1", "bloom-6b3", "bloom-3b", "bloom-1b7", "bloom-1b1", "bloom-560m"]
if "bloom" in model_name:
    model_name = f"bigscience/{model_name}"

In [None]:
tokenizer = AutoModelForCausalLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Downloading tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/710 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

# The code I use to sample:

In [None]:
soft_max_func = Softmax(dim=1)

def get_prob_mat(model_name, str_prompt, group_size):
    """Returns the probability matrix as a list of lists of floats"""
    inputs = tokenizer(str_prompt, return_tensors="pt")
    if "bloom" in model_name:
        logits_pt_tensor = model(**inputs, labels=inputs["input_ids"]).logits.squeeze(0)
    else:
        logits_pt_tensor = model(**inputs)
    prob_tensor = soft_max_func(logits_pt_tensor)
    try:
        prob_tensor = prob_tensor[:group_size, :]
    except IndexError as e:
        raise [e, ValueError("Group size must be smaller than the size of the prompt in tokens")]
    
    prob_mat = [prob_tensor[i, :].tolist() for i in range(group_size)]
    return prob_mat

def combinations(mat):
    """Returns all the lists such that list[j] is in mat[j]
    complexity: prod([len(mat[i]) for i in range(len(mat))])"""
    if len(mat) == 1:
        return [[mat[0][i]] for i in range(len(mat[0]))]
    res = []
    for i in mat[0]:
        for j in combinations(mat[1:]):
            res.append([i] + j)
    return res


def doesnt_have_duplicates(my_list):
    """Return if there isn't a repetition in the list
    complexity: O(n) where n is the length of the list"""
    return len(my_list) == len(set(my_list))


def seq_prob(tokens, prob_mat):
    """Given the probability matrix and a list of tokens
    returns the probability of the sequence
    prob_mat[a][b] is the probability of the token with id b the a-th token in the sequence
    comlexity: O(len(tokens))"""
    probability = 1.0
    sequence_length = len(tokens)
    for i in range(sequence_length):
        curr_token = tokens[i]
        probability *= prob_mat[i][curr_token]
    return probability


def grouped_sampling(prob_mat: List[List[int]], top_p, top_k, group_size) -> Tuple[List[List[int]], List[float]]:
    """given a matrix of probabilities, returns a list of lists of tokens
    the matrixs is of size group_size x vocab_size
    where matrix[i, j] is the probability of token j the i-th token in the group
    samples the tokens such that for each place in the group,
    at most top_k tokens are sampled and at least one token is sampled
    and the added probability of all the tokens is less than or equal top_p
    returns a list of where every item is a tuple of a sequense and probability
    overall complexity of the function in O(group_size * vocab_size * log(vocab_size))"""
    
    # prob_tensor.shape is now (group_size, vocab_size)
    indices = []  # O(1)
    for i in range(group_size): # group_size times
        token_prob = prob_mat[i]  # O(1)
        vocab_size = len(token_prob)  # O(1)
        indexed_prob = list(zip(token_prob, range(vocab_size)))  # O(vocab_size)
        sorted_indexed_prob = sorted(indexed_prob, key=lambda x: x[0], reverse=True)[:top_k] # O(vocab_size*log(vocab_size))
        total_prob: float = sorted_indexed_prob[0][0]  # O(1)
        curr_indices: List[List[int]] = [sorted_indexed_prob[0][1]]
        for prob, token in sorted_indexed_prob:  # top_k times
            if total_prob + prob > top_p:  # O(1)
                break
            total_prob += prob  # O(1)
            curr_indices.append(token)  # O(1)
        indices.append(curr_indices)  # O(1)
    new_sequences: List[List[int]] = combinations(indices)  # theta(prod(len(indices[i]) for i in range(group_size)))
    # len(indices[i]) <= min(top_k, vocab_size)
    # therefore the complexity is O(min(top_k, vocab_size) * group_size)
    filtered_sequences: Iterable[List[int]] = list(filter(doesnt_have_duplicates, new_sequences))
    prob_list: List[Tuple[List[int], float]] = [seq_prob(seq, prob_mat) for seq in filtered_sequences]
    return filtered_sequences, prob_list


def flatten(l: list) -> List[int]:
    # Complexity: O(len(the flatten list))
    new_list = []
    for item in l:
        if isinstance(item, int):
            new_list.append(item)
        elif isinstance(item, list):
            new_list.extend(flatten(item))
    return new_list


def remove_duplicates(completions: List[List[int]], probs: List[float]) -> Dict[Tuple[int], float]:
    """Given a list of tokenized answers and the probability of each complition,
    removes every repeated completion and every complition that have repeated tokens"""
    filtered_completions: Dict[Tuple[int], float] = dict()
    for curr_comp, curr_prob in zip(completions, probs):
        if len(curr_comp) == len(set(curr_comp)):
            curr_comp_tuple = tuple(curr_comp)
            filtered_completions[curr_comp_tuple] = curr_prob
    return filtered_completions


def complete(model_name, org_prompt, top_p, top_k, num_groups, group_size, org_prompt_prob = 1.0) -> Dict[Tuple[int], float]:
    """preprocess the prompt and completes the text
    model name: str
    org_prompt: str"""
    if isinstance(org_prompt, str):
        str_prompt = org_prompt
        tokenized_prompt_ten = tokenizer(str_prompt, return_tensors="pt")["input_ids"]
        tokenized_prompt_list = tokenized_prompt_ten.tolist()
    elif isinstance(org_prompt, list):
        tokenized_prompt_list = flatten(org_prompt)
        str_prompt = tokenizer.decode(tokenized_prompt_list)
    else:
        raise ValueError("org_prompt must be a string or list of integers")
    
    prob_mat = get_prob_mat(model_name, str_prompt, group_size)
    tokenized_ans_list, prob_list = grouped_sampling(prob_mat, top_p, top_k, group_size)
    new_prompts: List[List[int]] = [tokenized_prompt_list + ans for ans in tokenized_ans_list]
    if num_groups == 1:
        return remove_duplicates(new_prompts, prob_list)
    all_new_completions: List[List[int]] = []
    all_new_probs: List[float] = []
    for curr_new_prompt, curr_new_prompt_prob in zip(new_prompts, prob_list):
        curr_new_prompt_prob *= org_prompt_prob
        curr_completions = complete(model_name, curr_new_prompt, top_p, top_k, num_groups - 1, group_size, curr_new_prompt_prob)
        for tokens, prob in curr_completions.items():
            all_new_completions.append(tokens)
            all_new_probs.append(prob)
    return remove_duplicates(all_new_completions, all_new_probs)


# Use grouped sampling:

##How to use previous methods:

| brute force         | nucleus (top p)                 | top k                              | greedy         |
|---------------------|---------------------------------|------------------------------------|----------------|
| top p=0             | 0 < top p < 1                   | top p = 1                          | top p = 0      |
| top k >= vocab_size | top k >= the model's vocab size | 1 < top k < the model's vocab size | top k = 1      |
| group_size          | group_size = 1                  | group_size = 1                     | group_size = 1 |

In [None]:
prompt = 'A large langange model is a ' #@param {type:"string"}
top_p = 1 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 1000000000 #@param {type:"integer"}
group_size = 1 #@param {type:"integer"}
num_groups = 10 #@param {type:"integer"}

max_num_calls = sum((top_k ** i) for i in range(num_groups))
num_tokens_genrated = num_groups * group_size
pervious_num_calls = sum((top_k ** i) for i in range(num_tokens_genrated))

print(f"The model will generate {num_tokens_genrated} tokens (words or parts of words)")
print(f"It will call the model less than {max_num_calls} calls to the model")
print(f"Previous methods will need {max_num_calls} call to the model to genrate the same text")

In [None]:
completions = complete(model_name, prompt, top_p, top_k, num_groups, group_size)


if len(completions.keys()) > 0:
    grouped_sampling_best_comp_tokenized = max(completions, key=completions.get)
else:
    raise ValueError("""
        The model could not generate a complition with the constrains of the sampling method, U should try:
        1) Increase top p and/or top k
        2) Use diffrent (larger) model
        3) Decrease the number of groups
        4) Decrease the group size
        """)
    
grouped_sampling_best_comp = tokenizer.decode(grouped_sampling_best_comp_tokenized)
print("The best text generated by grouped sampling is:")
print(grouped_sampling_best_comp)

ValueError: ignored

In [None]:
sorted_complitions = dict(sorted(completions.keys, key = completions.get))
print("Here are all the complitions created by predicted probability")
for tokens in sorted_complitions:
    print(tokenizer.decode(tokens))

# Change the hyper parameters to see what will happend!