In [1]:
from rosemary import jpt_parse_args, jpt_setup, jpt_in_notebook

jpt_setup()

if jpt_in_notebook():
    import os

    # Restrict the visible GPUs before `torch` is imported in the next cell.
    # Multi-GPU alternative: '0,1,2,3,4,5'.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    

  warn(f'Install `torch` for functionalities dependent on torch')


In [2]:
from functools import partial
import os
import sys
import numpy as np
import time

import pickle
from tqdm import tqdm 

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForSeq2Seq

from open_instruct.finetune_trainer import encode_with_prompt_completion_format, encode_with_messages_format

import sys
sys.path.insert(0, "/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/fast-map-dpp")
from dpp import dpp


[2023-09-23 11:25:11,087] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# Checkpoint used both for the tokenizer and for the causal LM.
model_name_or_path = '../results/baselines/huggyllama/llama-7b'

# Instruction data to embed / score.
# Alternative mixture: '../data/processed/all.jsonl'.
train_file = '../data/processed/flan_v2/flan_v2_data.jsonl'

# Directory where the cached model outputs and the selected subsets are pickled.
save_dir = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts'
save_path = os.path.join(save_dir, 'note_explore_dpp_llama-7b_flan_v2_outputs.pkl')

# True: run the model and (re)write `save_path`; False: load the cached pickle.
gen_embeddings = False

In [4]:

# Read the jsonl file as a single 'train' split and report how many examples it holds.
data_files = dict(train=train_file)
raw_datasets = load_dataset("json", data_files=data_files)
print(len(raw_datasets['train']))


Found cached dataset json (/gpfs/u/scratch/PTFM/PTFMqngp/huggingface_cache/datasets/json/default-0fa1872b4ac9a29f/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

100000


In [5]:

# Load the causal LM in half precision; device placement is delegated to
# `device_map='auto'`.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.float16, device_map='auto',
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
tokenizer.padding_side = 'left'
# When the checkpoint ships without a pad token, reuse the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id


Using pad_token, but it is not set yet.


In [7]:
# Bind the tokenizer and the length limit once so `map` only has to pass the example.
encode_function = partial(encode_with_messages_format, tokenizer=tokenizer, max_seq_length=2048)

# Encode every example individually (not batched) using 16 worker processes.
lm_datasets = raw_datasets.map(
    encode_function,
    desc="Tokenizing and reformatting instruction data",
    num_proc=16,
    batched=False,
)

# Expose only the model inputs, as torch tensors.
train_dataset = lm_datasets['train']
train_dataset.set_format(
    type="torch",
    columns=['input_ids', 'labels', 'attention_mask'],
    output_all_columns=False,
)
train_dataset

Loading cached processed dataset at /gpfs/u/scratch/PTFM/PTFMqngp/huggingface_cache/datasets/json/default-0fa1872b4ac9a29f/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-3a8173652244b47d_*_of_00016.arrow


Dataset({
    features: ['dataset', 'id', 'messages', 'input_ids', 'labels', 'attention_mask'],
    num_rows: 100000
})

In [8]:
if gen_embeddings:

    # One example per batch: no padding is involved, so no collate_fn is needed
    # and the mean over the sequence axis below only covers real tokens.
    # collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding='longest')
    loader = DataLoader(train_dataset, shuffle=False, batch_size=1)

    device = 'cuda'

    # Per-example results, appended in dataset order (shuffle=False).
    text_embeddings = []
    log_probs = []

    for batch in tqdm(loader, total=len(loader)):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

        with torch.inference_mode():
            outputs = model(**batch, output_hidden_states=True)

        # Mean-pool the last hidden layer over the sequence axis:
        # (bsz, seq_len, hidden_size) -> (bsz, hidden_size)
        text_embedding = outputs['hidden_states'][-1].mean(1)

        # Negative language-modelling loss for this example.
        # NOTE(review): Hugging Face causal-LM heads normally use a mean-reduced
        # cross-entropy over the label tokens, which would make this the
        # *average* per-token log-prob rather than a sum -- confirm for the
        # installed transformers version before treating exp(log_prob) as a
        # sequence probability.
        log_prob = -outputs['loss']

        text_embeddings.append(text_embedding.detach().cpu().numpy())
        log_probs.append(log_prob.detach().cpu().numpy())

In [9]:
if not gen_embeddings:
    # Reuse the arrays pickled by an earlier run.
    with open(save_path, 'rb') as f:
        d = pickle.load(f)
else:
    # Stack the per-example results row-wise and persist them for later runs.
    d = {
        'text_embeddings': np.vstack(text_embeddings),
        'log_probs': np.vstack(log_probs),
    }
    with open(save_path, 'wb') as f:
        pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)

# Some entries are NaN: substitute the mean of the non-NaN log-probs.
mean_log_prob = np.nanmean(d['log_probs'])
d['log_probs'] = np.nan_to_num(d['log_probs'], nan=mean_log_prob)
text_embeddings = d['text_embeddings']
log_probs = d['log_probs'].squeeze()

In [51]:
## sort by increasing/decreasing logprob.

# Write the example indices ordered by log-prob, once ascending and once
# descending. A dedicated `subset_path` is used so that the notebook-wide
# `save_path` (the embedding cache read by the load cell) is not overwritten.
inds = np.argsort(log_probs).tolist()
for suffix, step in (('incr', 1), ('decr', -1)):
    # step=1 keeps the ascending order; step=-1 reverses it for the second file.
    inds = inds[::step]
    subset_path = os.path.join(save_dir, f'note_explore_dpp_llama-7b_flan_v2_subsets_prob_{suffix}.pkl')
    with open(subset_path, 'wb') as f:
        pickle.dump({'K': inds}, f, protocol=pickle.HIGHEST_PROTOCOL)

In [17]:
# save_path = os.path.join(save_dir, 'note_explore_dpp_llama-7b_flan_v2_subsets_prob_increasing.pkl')

# with open(save_path, 'rb') as f:
#     out = pickle.load(f)
    


(10000,)

In [14]:

# Number of examples used by the kernel / DPP cells below. Defaults to every
# cached example; set a smaller value (e.g. 10000) for a quick experiment.
N = len(log_probs)
text_embeddings = text_embeddings[:N,:]
log_probs = log_probs[:N]

In [None]:
# Get an idea of log-prob vs. length, to see whether there is a length bias
# and a need to normalize by length. Notes from inspecting the plot:
# - no obvious correlation between #tokens/seq and log P_LM(seq).
# - if anything, sequences with more tokens have a larger log probability.

# Imported here because this cell runs before the later plotting cell that
# imports matplotlib.
import matplotlib.pyplot as plt

# Token count of each of the first N examples.
num_tokens = [len(train_dataset[i]['input_ids']) for i in range(N)]

fig, ax = plt.subplots(1,1,figsize=(5,5))
ax.scatter(num_tokens, log_probs)
ax.set_xlabel('#Tokens / Seq')
ax.set_ylabel('log P_LM(Seq)')
# ax.set_xlim((0,250))

In [None]:
## what are low/high p_LM(seq) examples?
#  - low: input prompt has exotic knowledge, or need to come up with exotic knowledge in the answer.
#  - high: output already exists in the input prompt (copy 1 sentence, remove punctuation, summarization)

K = 10
# Which tail of the log-prob ranking to print. True shows the K highest
# log-probs (what the cell printed before, because slicing the top K again
# with [:K] was a no-op); False shows the K lowest.
show_top = True
inds = np.argsort(log_probs)
inds = inds[-K:] if show_top else inds[:K]

# Iterate over what was actually selected so fewer than K examples cannot over-index.
for k, ind in enumerate(inds.tolist()):
    print('\n'+'==='*20+f' k={k},ind={str(ind)},logprob={log_probs[ind]:.3f}')
    print(tokenizer.decode(train_dataset[ind]['input_ids']))

In [91]:
# Build three N x N kernel matrices from the embeddings and log-probs:
#   K_cos              cosine similarity between embeddings
#   K_cos_prob         cosine similarity scaled on both sides by p = exp(log_prob)
#   K_cos_oneminusprob cosine similarity scaled on both sides by (1 - p)
# The torch branch is always taken; the `else` branch is a NumPy version that
# is currently unreachable.
if True:
    T = torch.from_numpy(text_embeddings).to('cuda').to(torch.float32)
    logP = torch.from_numpy(log_probs).to('cuda').to(torch.float32)
    
    # Unit-normalise each row so that T @ T.T gives cosine similarities.
    T = torch.nn.functional.normalize(T, dim=-1)
    ## out-of-memory
    # S = T@T.T
    ## block-wise matmul to reduce peak memory usage.
    # Each block of up to 10000 rows is multiplied on the GPU and its result is
    # moved to the CPU before the next block is computed.
    L = []
    for Tn in torch.split(T, 10000):
        L.append((Tn@T.T).to('cpu'))
    S = torch.vstack(L)
    P = logP.exp().to('cpu')

    # The reshapes require N to equal the number of rows in S / entries in P.
    K_cos = S
    K_cos_prob = P.reshape(N,1)*S*P.reshape(1,N)
    K_cos_oneminusprob = (1-P).reshape(N,1)*S*(1-P).reshape(1,N)
    
    # Convert to NumPy for the downstream cells.
    # NOTE(review): S is already on the CPU, so K_cos presumably shares memory
    # with S here -- in-place edits of K_cos would then also change S; confirm.
    K_cos = K_cos.to('cpu').numpy()
    K_cos_prob = K_cos_prob.to('cpu').numpy()
    K_cos_oneminusprob = K_cos_oneminusprob.to('cpu').numpy()
    
else:
    # NumPy version; note that it normalises `text_embeddings` in place.
    text_embeddings /= np.linalg.norm(text_embeddings, axis=-1, keepdims=True)
    similarities = np.dot(text_embeddings, text_embeddings.T) # cosine sim
    probs = np.exp(log_probs)

    K_cos = similarities 
    K_cos_prob = probs.reshape(N, 1) * similarities * probs.reshape(1, N)
    K_cos_oneminusprob = (1-probs).reshape(N, 1) * similarities * (1-probs).reshape(1, N)

In [95]:

def cholesky_jitter(K, jitter=1e-5):
    """Add `jitter` to every diagonal entry of `K` and return `K`.

    The update is done in place: the array passed in is modified and the very
    same object is returned, so no copy of the matrix is made.

    Args:
        K: array whose dimensions all have the same length (e.g. a square matrix).
        jitter: value added to each diagonal element.

    Returns:
        The input array `K`, after modification.
    """
    diagonal = np.diag_indices_from(K)
    K[diagonal] += jitter
    return K

def cholesky_jitter_variable(K, jitters=(0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1)):
    """Find the first jitter in `jitters` that makes `K` Cholesky-factorisable.

    Each candidate is applied relative to the ORIGINAL diagonal, so the total
    amount added is exactly the candidate value (earlier failed attempts do
    not accumulate). `K` is modified in place, as with `cholesky_jitter`.

    Args:
        K: square matrix; its diagonal is overwritten in place.
        jitters: candidate diagonal offsets, tried in the given order.

    Returns:
        `K` (the same object) with the first successful jitter added to its
        diagonal. The chosen jitter is printed.

    Raises:
        ValueError: if no candidate yields a successful Cholesky
            factorisation; the original diagonal of `K` is restored first.
    """
    diag = np.diag_indices_from(K)
    # Remember the untouched diagonal so every attempt starts from it.
    base_diagonal = K[diag].copy()

    for v in jitters:
        K[diag] = base_diagonal + v
        try:
            np.linalg.cholesky(K)
        except np.linalg.LinAlgError:
            # Not positive definite with this jitter; try the next candidate.
            continue
        print(v)
        return K

    # Nothing worked: leave the caller's matrix as it was handed in.
    K[diag] = base_diagonal
    raise ValueError(f'no jitter in {list(jitters)} makes the matrix positive definite')


print('add jitter ensures matrices are psd')
# The adaptive search below needs a Cholesky factorisation per attempt, which
# is O(N^3) and too costly here, so a fixed jitter of 1e-3 is added instead.
# K_cos = cholesky_jitter_variable(K_cos)
# K_cos_prob = cholesky_jitter_variable(K_cos_prob)
# K_cos_oneminusprob = cholesky_jitter_variable(K_cos_oneminusprob)
# Note: cholesky_jitter edits its argument in place, so re-running this cell
# adds the jitter again.
K_cos, K_cos_prob, K_cos_oneminusprob = (
    cholesky_jitter(M, 1e-3) for M in (K_cos, K_cos_prob, K_cos_oneminusprob)
)
print('add jitter ensures matrices are psd done!')


add jitter ensures matrices are psd
add jitter ensures matrices are psd done!


In [None]:
import matplotlib.pyplot as plt
from rosemary import plt_kernel_matrix_one

# One panel per kernel matrix, left to right:
# K_cos, K_cos_prob, K_cos_oneminusprob.
fig, axs = plt.subplots(1,3,figsize=(15,5))
for ax, kernel_matrix in zip(axs, (K_cos, K_cos_prob, K_cos_oneminusprob)):
    plt_kernel_matrix_one(fig, ax, kernel_matrix)

In [None]:


# Run the DPP selection once per kernel and pickle the returned indices
# under a file name derived from the kernel's name.
Ks = {
    'K_cos': K_cos,
    'K_cos_prob': K_cos_prob,
    'K_cos_oneminusprob': K_cos_oneminusprob,
}
for kernel_matrix_name, K in Ks.items():
    s = time.time()
    inds = dpp(K, N) # select till triggers stopping criterion
    print(f'running: {kernel_matrix_name} has len={len(inds)} cost {time.time()-s:.2f} seconds')
    out = {'K': inds}

    save_path = os.path.join(save_dir, f'note_explore_dpp_llama-7b_flan_v2_subsets_{kernel_matrix_name}.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(out, f, protocol=pickle.HIGHEST_PROTOCOL)

In [10]:
# Reload the pickled subset stored under the `K_cos` file name; this rebinds
# the notebook-wide `save_path` and `out`.
save_path = os.path.join(save_dir, 'note_explore_dpp_llama-7b_flan_v2_subsets_K_cos.pkl')

# NOTE: pickle.load can execute arbitrary code; only open files this notebook wrote.
with open(save_path, 'rb') as f:
    out = pickle.load(f)

[29,
 21924,
 82981,
 20682,
 74114,
 91727,
 20684,
 77109,
 98159,
 43958,
 28715,
 21928,
 28954,
 5276,
 56223,
 73964,
 80457,
 84107,
 98131,
 87156,
 292,
 768,
 20458,
 31390,
 89270,
 97040,
 58409,
 5,
 92897,
 97869,
 81966,
 20214,
 2153,
 2992,
 89287,
 20008,
 77237,
 79666,
 4930,
 78565,
 21822,
 98077,
 237,
 20277,
 276,
 85517,
 20860,
 1162,
 20117,
 20254,
 45004,
 22375,
 88522,
 98473,
 80182,
 74573,
 40145,
 20071,
 90197,
 56221,
 23789,
 97989,
 91814,
 56333,
 78946,
 20480,
 56275,
 56588,
 41359,
 79308,
 21404,
 6592,
 60798,
 100,
 21746,
 20190,
 6371,
 432,
 28823,
 80975,
 1854,
 445,
 20467,
 1251,
 73384,
 46940,
 23622,
 308,
 56228,
 77651,
 20387,
 85372,
 395,
 20611,
 92382,
 20043,
 41033,
 98183,
 97964,
 40364,
 220,
 94398,
 318,
 20162,
 98386,
 1909,
 40654,
 44423,
 98427,
 20452,
 98244,
 21724,
 97826,
 98416,
 22619,
 98001,
 98032,
 7584,
 98318,
 90681,
 24370,
 20436,
 83708,
 67088,
 80127,
 56886,
 1080,
 4883,
 76791,
 22714,
 2