
Goal
- Check whether P_LM correlates with output token length.
- Analyze the differences/similarities between the sorted indices produced by different `sort_by` values, e.g., whether they correlate with each other and how much their first-k items overlap.
- Visualize text statistics for the data subsets obtained from pruning.






In [1]:
from rosemary import jpt_parse_args, jpt_setup, jpt_in_notebook; jpt_setup()

  warn(f'Install `torch` for functionalities dependent on torch')


In [None]:
import os
import sys
import numpy as np
import time
import re
from functools import partial

import matplotlib.pyplot as plt

import pickle
from tqdm import tqdm 

import pyarrow
import torch
import transformers
from transformers import AutoTokenizer

from datasets import load_dataset

In [None]:
_LM_OUTPUT_DIR = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/llama-7b_outputs'


def get_lm_output(dataset, lm_output_dir=_LM_OUTPUT_DIR):
    """Load the pickled LM outputs for `dataset` and impute missing log-probs.

    Args:
        dataset: dataset name; the file read is ``{lm_output_dir}/{dataset}.pkl``.
        lm_output_dir: directory holding the per-dataset pickles. Defaults to the
            llama-7b output directory this notebook was written against.

    Returns:
        The unpickled object (indexed like a dict; this notebook reads
        'log_probs' and 'text_embeddings'). NaNs in ``output['log_probs']`` are
        replaced by the mean of its non-NaN entries. Note that ``np.nan_to_num``
        also maps +/-inf to the largest finite float magnitudes.
    """
    save_path = os.path.join(lm_output_dir, f'{dataset}.pkl')
    # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
    with open(save_path, 'rb') as f:
        output = pickle.load(f)
    log_probs = output['log_probs']
    output['log_probs'] = np.nan_to_num(log_probs, nan=np.nanmean(log_probs))
    return output


_PROCESSED_DIR = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/data/processed'


def get_dataset(dataset, processed_dir=_PROCESSED_DIR):
    """Load the processed training split for `dataset` from a JSONL file.

    Args:
        dataset: dataset name. Names containing 'tulu' are read from
            ``{processed_dir}/tulu/{dataset}.jsonl``; every other name from
            ``{processed_dir}/{dataset}/{dataset}_data.jsonl``.
        processed_dir: root of the processed data. Defaults to the
            open-instruct processed-data directory this notebook was written against.

    Returns:
        The 'train' split returned by ``load_dataset('json', ...)``.
    """
    if 'tulu' in dataset:
        train_file = os.path.join(processed_dir, 'tulu', f'{dataset}.jsonl')
    else:
        train_file = os.path.join(processed_dir, dataset, f'{dataset}_data.jsonl')
    return load_dataset('json', data_files={'train': train_file})['train']



# Notebook-global tokenizer built from `model_name_or_path`.
# `model_name_or_path` is only assigned in a later cell, so fall back to the same
# default path here. Otherwise, running the notebook top-to-bottom on a fresh
# kernel raises NameError before the rest of this cell runs.
if 'model_name_or_path' not in globals():
    model_name_or_path = '../results/baselines/huggyllama/llama-7b'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    # No pad token configured: reuse EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token


def get_dataset_token_lengths(dataset, model_name_or_path, inds=None):
    """Count input (prompt) and output (response) token lengths per example.

    Examples are encoded with ``open_instruct``'s ``encode_with_messages_format``
    (``max_seq_length=2048``). Label positions equal to -100 are counted as input
    tokens; all remaining label positions are counted as output tokens.

    Args:
        dataset: dataset name passed to ``get_dataset``.
        model_name_or_path: tokenizer checkpoint used for encoding.
        inds: optional row indices; when given, only those rows are encoded,
            in that order.

    Returns:
        dict with 'input_len' and 'output_len', one entry per selected example.
    """
    from open_instruct.finetune_trainer import encode_with_messages_format

    # Build the tokenizer from the argument (same settings as the notebook's
    # global tokenizer) so the result actually depends on `model_name_or_path`.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = get_dataset(dataset)
    if inds is not None:
        ds = ds.select(inds)
    encode_fn = partial(encode_with_messages_format, tokenizer=tokenizer, max_seq_length=2048)
    ds = ds.map(encode_fn, batched=False, num_proc=16)
    ds.set_format(type='np')

    def count_token_lengths(d):
        labels = d['labels']
        # Positions labelled -100 are treated as input; the rest as output.
        input_len = int((labels == -100).sum())
        output_len = labels.shape[0] - input_len
        return {'input_len': input_len, 'output_len': output_len}

    ds = ds.map(count_token_lengths, num_proc=16)
    return {'input_len': ds['input_len'], 'output_len': ds['output_len']}


_SORTED_INDS_DIR = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/data_inds/llama-7b'


def get_sorted_inds(dataset, sort_by, inds_dir=_SORTED_INDS_DIR):
    """Load the pickled sorted-index data for `dataset` under criterion `sort_by`.

    Args:
        dataset: dataset name (a subdirectory of `inds_dir`).
        sort_by: criterion name; the file read is ``{inds_dir}/{dataset}/{sort_by}.pkl``.
        inds_dir: root directory of the saved indices. Defaults to the llama-7b
            index directory this notebook was written against.

    Returns:
        Whatever object was pickled in that file, unchanged.
    """
    save_path = os.path.join(inds_dir, dataset, f'{sort_by}.pkl')
    # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
    with open(save_path, 'rb') as f:
        return pickle.load(f)

In [None]:
# Default checkpoint and single dataset used for one-off exploration.
model_name_or_path = '../results/baselines/huggyllama/llama-7b'
dataset = 'flan_v2'

# Instruction-tuning datasets to analyze, in plotting order.
dataset_list = [
    'baize',
    'code_alpaca',
    'cot',
    'dolly',
    'flan_v2',
    'gpt4_alpaca',
    'lima',
    'oasst1',
    'open_orca',
    'self_instruct',
    'sharegpt',
    'stanford_alpaca',
    'super_ni',
    'unnatural_instructions',
    'wizardlm',
    'tulu_v1_human_mix',
]


In [None]:
# For each dataset, scatter p_LM = exp(log_prob) against input and output token
# lengths (one subplot per dataset, 3 per row) to look for correlation.
subsample_size = 2000  # examples per dataset; falsy -> use every example

datasets_to_plot = dataset_list[:100]
ncols = 3
nrows = int(np.ceil(len(datasets_to_plot) / ncols))
w = 5
# squeeze=False keeps `axs` 2-D even for a single row/column, so `.flatten()` always works.
fig, axs = plt.subplots(nrows, ncols, figsize=(w*ncols, w*nrows), squeeze=False)


for axi, dataset in enumerate(datasets_to_plot):

    ## get the information

    lm_output = get_lm_output(dataset)
    T = lm_output['text_embeddings']
    logP = lm_output['log_probs']

    inds = None
    if subsample_size:
        # Sample without replacement (randint could repeat rows), and cap the
        # sample at the dataset size so datasets smaller than `subsample_size` work.
        np.random.seed(0)
        num_examples = T.shape[0]
        inds = np.random.choice(num_examples, size=min(subsample_size, num_examples), replace=False)
        T = T[inds]
        logP = logP[inds]

    token_lengths = get_dataset_token_lengths(dataset, model_name_or_path, inds=inds)
    input_len = token_lengths['input_len']
    output_len = token_lengths['output_len']

    ## plot the information
    ax = axs.flatten()[axi]

    ys = np.exp(logP)

    for label, xs in [
            ('input_len', input_len),
            ('output_len', output_len),
        ]:
        ax.scatter(xs, ys, label=label, alpha=.2)

    ax.legend(loc='upper center')

    # Trim the most extreme 0.1% at each end so a few outliers don't squash the plot.
    eps = .001
    ax.set_ylim(np.quantile(ys, eps), np.quantile(ys, 1-eps))
    # Lower x-limit from the elementwise min of the two length series, upper from
    # the elementwise max. Taking the lower bound from the max would clip the
    # short end of whichever series is smaller.
    ax.set_xlim(np.quantile(np.minimum(input_len, output_len), eps),
                np.quantile(np.maximum(input_len, output_len), 1-eps))

    # Only label the y-axis on the left-most column.
    if axi % ncols == 0:
        ax.set_ylabel('$p_{LM}$', fontsize=25)
    ax.set_title(dataset, fontsize=25)

# Hide grid cells that have no dataset (e.g. 16 datasets on a 6x3 grid).
for unused_ax in axs.flatten()[len(datasets_to_plot):]:
    unused_ax.set_visible(False)

fig.tight_layout()

In [None]:
# Re-apply the tight layout to `fig` (repeats the call at the end of the plotting cell).
fig.tight_layout()

In [None]:
# Show the p_LM vs. token-length figure.
# NOTE(review): with a non-GUI backend (e.g. Jupyter inline), Figure.show() may
# only emit a warning; leaving `fig` as the cell's last expression is the usual
# alternative. Confirm for the backend in use.
fig.show()


In [None]:
# Shape of `ys` (exp of log-probs) from the last dataset processed in the plotting loop.
ys.shape