In [2]:
!pip install nltk

Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting joblib (from nltk)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached joblib-1.4.2-py3-none-any.whl (301 kB)
Installing collected packages: joblib, nltk
Successfully installed joblib-1.4.2 nltk-3.9.1


In [6]:
# Sanity check: list the working directory, which holds the local helper modules imported
# below (intra_diversity.py, utils.py) and the model_outputs/ folder the samples are read from.
!ls

diversity_gen.py    model_outputs     try_combined_decoding.ipynb
diversity.ipynb     __pycache__       try_load_model_logits.ipynb
intra_diversity.py  README.md	      utils.py
logit_lens.ipynb    requirements.txt


In [None]:
from intra_diversity import *
import pickle
from transformers import AutoTokenizer
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
from utils import *

In [8]:
# SFT ("post") and base-model outputs for 100 WritingPrompts validation examples.
# NOTE(review): these were produced with teacher forcing -- presumably the dataset story is fed
# to each model and the per-position logits are stored under 'generated_logits', rather than the
# model sampling its own text. TODO confirm against the generation script.
post_samples_path = (
    './model_outputs/'
    'ContextualAI_archangel_sft_pythia2-8b_euclaise_writingprompts_validation_samples_100.pkl'
)
base_samples_path = (
    './model_outputs/'
    'EleutherAI_pythia-2.8b_euclaise_writingprompts_validation_samples_100.pkl'
)


def _load_pickle(path):
    """Return the object deserialized from the pickle file at `path`.

    SECURITY: pickle.load can execute arbitrary code -- only use it on files you produced yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


post_samples = _load_pickle(post_samples_path)
base_samples = _load_pickle(base_samples_path)

# Pythia-2.8b tokenizer; presumably also valid for the SFT checkpoint fine-tuned from it (TODO confirm).
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-2.8b')

: 

In [None]:
def get_combined_top_k_for_one_token(sample_a, sample_b, token_pos, k=10, sample_a_name='Sample A', sample_b_name='Sample B',
                                     sharp_top_n=3, sharp_threshold=0.9):
    """
    Choose a next-token id at `token_pos` by comparing two samples' next-token distributions.

    Each sample's 'generated_logits' is indexed by `token_pos` and soft-maxed over its last
    dimension (assumes a (seq_len, vocab) tensor -- TODO confirm). A distribution counts as
    "sharp" when its `sharp_top_n` most likely tokens hold at least `sharp_threshold` of the mass.

    Selection rule (uniform over the candidate ids, NOT probability-weighted):
      * A and B both sharp  -> pick from A's top-k
      * A sharp, B not      -> pick from the union of both top-k sets (between k and 2k ids)
      * A not sharp         -> pick from A's top-k

    Args:
        sample_a, sample_b: dicts holding a 'generated_logits' torch tensor.
        token_pos: position whose next-token distribution is compared.
        k: number of top tokens taken from each sample (clamped to the vocabulary size).
        sample_a_name, sample_b_name: labels used in the printed diagnostic.
        sharp_top_n: how many of the most likely tokens are summed for the sharpness test.
        sharp_threshold: probability mass those tokens must reach to count as sharp.

    Returns:
        int: the chosen token id.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Index the position first, then softmax: identical to soft-maxing every row and indexing,
    # but only one row is normalized. float()/cpu() keeps later ops safe for half/bf16 or CUDA tensors.
    probs_a = sample_a['generated_logits'][token_pos].detach().float().cpu().softmax(dim=-1)
    probs_b = sample_b['generated_logits'][token_pos].detach().float().cpu().softmax(dim=-1)

    def _top_ids(probs, n):
        # Ids of the n most likely tokens as plain Python ints (n clamped to the vocab size).
        return probs.topk(min(n, probs.shape[-1])).indices.tolist()

    def _is_sharp(probs):
        # Sharpness must use the highest-probability tokens, not the id-sorted combined list
        # (which would sum arbitrary low-id tokens and can be shorter than sharp_top_n).
        mass = probs.topk(min(sharp_top_n, probs.shape[-1])).values.sum().item()
        return mass >= sharp_threshold

    top_k_a = _top_ids(probs_a, k)
    top_k_b = _top_ids(probs_b, k)
    # Union of both top-k sets, ordered by token id.
    combined_top_k = sorted(set(top_k_a) | set(top_k_b))

    a_is_sharp = _is_sharp(probs_a)
    b_is_sharp = _is_sharp(probs_b)

    if a_is_sharp and b_is_sharp:
        print(f"Both {sample_a_name} and {sample_b_name} are sharp.")
        # both confident: sample from A (the SFT version)
        candidates = top_k_a
    elif a_is_sharp:
        print(f"{sample_a_name} is sharp, but {sample_b_name} is not.")
        # only A is confident: widen the pool with B's candidates for diversity
        candidates = combined_top_k
    elif b_is_sharp:
        print(f"{sample_a_name} is not sharp, but {sample_b_name} is.")
        candidates = top_k_a
    else:
        print(f"Neither {sample_a_name} nor {sample_b_name} is sharp.")
        candidates = top_k_a

    # random.choice on a list of ints yields a single token id (random.choices on the
    # (1, k) tensor used previously returned the whole row of k ids).
    return random.choice(candidates)
    

In [None]:
def get_single_top_k(sample, token_pos, k=10):
    """
    Pick a next-token id uniformly at random from one sample's top-k tokens at `token_pos`.

    Args:
        sample: dict holding a 'generated_logits' torch tensor, indexed by position first
            (assumes (seq_len, vocab) -- TODO confirm).
        token_pos: position whose next-token distribution is used.
        k: number of most likely tokens to choose from (clamped to the vocabulary size).

    Returns:
        int: the chosen token id.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    probs = sample['generated_logits'][token_pos].detach().float().softmax(dim=-1)
    top_k = probs.topk(min(k, probs.shape[-1])).indices.tolist()
    # Choose one id from the list; random.choices on the (1, k) tensor used previously
    # returned the entire row of k ids instead of a single token.
    return random.choice(top_k)

In [None]:
def diverse_one_token(sample_sft, sample_base, token_pos, tokenizer, k=10):
    """
    For one data point, choose the next token at `token_pos` three ways and print a comparison.

    Uses the teacher-forcing logits stored in `sample_sft` and `sample_base`:
      * 'combined' -- get_combined_top_k_for_one_token(SFT, Base)
      * 'sft'      -- get_single_top_k on the SFT sample
      * 'base'     -- get_single_top_k on the base sample

    Args:
        sample_sft, sample_base: dicts with a 'generated_logits' tensor; `sample_sft` must also
            carry the reference text under "story".
        token_pos: number of story tokens used as the prefix / position of the logits to sample from.
        tokenizer: tokenizer used to encode the story and decode token ids.
        k: top-k size passed to every sampling method.

    Returns:
        dict: sampling method -> story prefix followed by the chosen token's text.
    """
    # NOTE(review): the prefix is the first `token_pos` tokens of the re-tokenized story. If
    # generated_logits[i] is the usual causal-LM prediction for the token *after* input i, the
    # matching prefix would be [:token_pos + 1] -- confirm against how the logits were saved.
    story_prefix_ids = tokenizer(sample_sft["story"]).input_ids[:token_pos]
    story_prefix = tokenizer.decode(story_prefix_ids)

    # Dict literal preserves call order (combined, sft, base), so random draws happen in that order.
    chosen_tokens = {
        'combined': get_combined_top_k_for_one_token(sample_sft, sample_base, token_pos, k=k,
                                                     sample_a_name='SFT', sample_b_name='Base'),
        'sft': get_single_top_k(sample_sft, token_pos, k=k),
        'base': get_single_top_k(sample_base, token_pos, k=k),
    }
    # Decode each chosen token once; reused for the table and the returned sequences.
    next_token_texts = {method: tokenizer.decode(token) for method, token in chosen_tokens.items()}

    # print table: sample method, prefix, next token
    print(f"{'sample method':<20} {'prefix':<50} {'next token':<20}")
    for method, token_text in next_token_texts.items():
        print(f"{method:<20} {story_prefix:<50} {token_text:<20}")

    return {method: story_prefix + token_text for method, token_text in next_token_texts.items()}