In [8]:
import os
import itertools
from rosemary import jpt_in_notebook
from llm.submit import submit_job, multiline_to_singleline

shell_scripts_template = """
echo "Running on $SLURM_JOB_NODELIST"
echo "======"

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
master_port=10002
RDZV_ENDPOINT=$master_addr:$master_port

source ~/.profile
conda activate open-instruct
cd /gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/open-instruct/scripts

set -e
set -x
echo "======"
srun {cmd}

[ ! -f "{log_dir}/$SLURM_JOB_ID*.out" ] && mv {log_dir}/$SLURM_JOB_ID*.out {save_dir}
"""
log_dir = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/'


test_run = 1
test_run = bool(test_run)

# model_name = 'llama-7b'
# model_name = 'llama-7b+lora:r=256:a=256' # includes lora grad norm
# # model_name = 'llama-7b_ft=hmv1' # means llama-7b finetuned on tulu humanmix already.
# sort_by_list = [
#     'random_s=0', 
#     'log_prob', 'logit_margin', 'el2n_agg=mean',  'el2n_agg=l2n',
#     'grad_loraB_l2n',
#     'dppmap_k=Kcos', 'dppmap_k=Kcosp', 'dppmap_k=Kcos1np', 
#     'kmeansl2_nc=3000', 'kmeanscd_nc=3000',
# ]
# dataset_list = ['flan2022_1m']



model_name = 'pythia-1b-deduped'
model_name = 'pythia-1b-deduped+lora:r=256:a=256'
dataset_list = ['cot', 'dolly', 'flan_v2', 'lima', 'oasst1']
# sort_by_list = ['random_s=0', 
#                 'log_prob', 'logit_margin', 'el2n_agg=mean', 'el2n_agg=l2n', 
#                 'kmeansl2_nc=3000', 'kmeanscd_nc=3000',
#                 'grad_loraB_l2n',
#                 'grad_all_l2n', 'grad_qkv_l2n', 'grad_mlp_l2n', 'grad_last_l2n',
#                ]
sort_by_list = ['grad_loraB_l2n']

from note_pruning_analysis import lm_output_dir, data_inds_dir
save_dir = os.path.join(data_inds_dir, model_name)
lm_output_dir = os.path.join(lm_output_dir, model_name)

options_list = itertools.product(dataset_list, sort_by_list)

print('test_run =',test_run)
cmds = []
for dataset, sort_by in options_list:
    cmd = f"""
     python note_pruning.py \
        --dataset {dataset} \
        --sort_by {sort_by} \
        --lm_output_dir {lm_output_dir} \
        --save_dir {save_dir} \
    """.strip()
    cmd = multiline_to_singleline(cmd)
    shell_scripts = shell_scripts_template.format(
        cmd=cmd, log_dir=log_dir, save_dir=save_dir)
    out = submit_job(
        shell_scripts, 
        job_name=f'prune.{dataset}.{sort_by}', 
        nodes=1,
        num_cpus=64, # 32
        cpu_mem=256, # 128
        num_gpus=1,
        gpu_type='v100',
        test_run=test_run,
        job_duration=6,
    )
    cmds.append(cmd)
    print(cmd)
        
print('#cmds: ', len(cmds))


test_run = True

Submiting job with:
{
    "job_name": "prune.cot.grad_loraB_l2n",
    "nodes": 1,
    "num_cpus": 64,
    "cpu_mem": 256,
    "num_gpus": 1,
    "gpu_type": "v100",
    "test_run": true,
    "queue": "el8",
    "num_jobs": 1
}
python note_pruning.py --dataset cot --sort_by grad_loraB_l2n --lm_output_dir /gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outputs/pythia-1b-deduped+lora:r=256:a=256 --save_dir /gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/data_inds/pythia-1b-deduped+lora:r=256:a=256

Submiting job with:
{
    "job_name": "prune.dolly.grad_loraB_l2n",
    "nodes": 1,
    "num_cpus": 64,
    "cpu_mem": 256,
    "num_gpus": 1,
    "gpu_type": "v100",
    "test_run": true,
    "queue": "el8",
    "num_jobs": 1
}
python note_pruning.py --dataset dolly --sort_by grad_loraB_l2n --lm_output_dir /gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outpu

In [26]:
# Peek at which per-example statistics a cached LM output contains.
# NOTE(review): `get_lm_output` is not imported in any visible cell, and `dataset` /
# `model_name` are whatever an earlier cell last bound them to. This cell will not
# survive Restart & Run All as-is.

# Reference location of the pythia-1b-deduped outputs (not used below).
path = (
    '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/'
    'model_outputs/pythia-1b-deduped'
)

# Re-binds `lm_output_dir` to the module's base directory.
from note_pruning_analysis import lm_output_dir, lm_inds_dir

output = get_lm_output(
    dataset,
    model_name,
    return_text_embedding=False,
)
output.keys()


dict_keys(['log_prob', 'el2n_agg=mean', 'el2n_agg=l2n', 'logit_margin', 'grad_all_l2n', 'grad_qkv_l2n', 'grad_mlp_l2n', 'grad_last_l2n'])

In [1]:
from rosemary import jpt_parse_args, jpt_setup, jpt_in_notebook
jpt_setup()

if jpt_in_notebook():
    import os
    # Restrict this kernel to GPU 1 (use e.g. '0,1,2,3,4,5' to expose more).
    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised,
    # which is why this cell runs before `import torch`.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    

  warn(f'Install `torch` for functionalities dependent on torch')


In [3]:
import os
import sys
import numpy as np
import time
import re
import random
import pickle
from tqdm import tqdm 

import pyarrow
import torch
import transformers

from note_pruning import (
    save_to_pickle,
    save_sorted_inds,
    sort_kmeans_dist_to_cluster_centers,
    sort_dpp_map,)

In [3]:
# Configuration for the single-dataset ranking run below.
test_run = False  # True: the loading cell keeps only the first 10k examples

# Other datasets used here: 'tulu_v1_human_mix', 'tulu_v2_human_mix', 'flan_v2', 'lima'.
dataset = 'flan2022_1m'

# Ranking criterion. Other values handled by the ranking cell include 'log_prob',
# 'el2n_agg=mean', 'el2n_agg=l2n', 'logit_margin', 'kmeansl2_nc=3000',
# 'kmeanscd_nc=3000', 'dppmap_k=Kcos', 'dppmap_k=Kcos1np'.
sort_by = 'random_s=0'

# Model whose cached outputs are used for ranking.
# Alternatives: 'llama-7b', 'llama-7b_ft=hmv1'.
model_name = 'llama-7b+lora:r=256:a=256'

# Both paths hang off the same scripts directory; joining the components yields
# exactly '<scripts>/data_inds/<model_name>/<dataset>' and
# '<scripts>/model_outputs/<model_name>'.
scripts_dir = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts'
lm_output_dir = os.path.join(scripts_dir, 'model_outputs', model_name)
save_dir = os.path.join(scripts_dir, 'data_inds', model_name, dataset)
os.makedirs(save_dir, exist_ok=True)

In [4]:

# Load the cached LM outputs for `dataset`.
# (`save_path` is the pickle being *read* here; the name is kept for later cells.)
save_path = os.path.join(lm_output_dir, f'{dataset}.pkl')
with open(save_path, 'rb') as f:
    d = pickle.load(f)

# Debug mode: keep only the first 10k entries of every stored value.
if test_run:
    d = {key: values[:10000] for key, values in d.items()}

text_embedding = d['text_embedding']
N = text_embedding.shape[0]

# Some log-prob entries are NaN; replace them with the mean of the non-NaN ones.
log_prob = d['log_prob']
log_prob = np.nan_to_num(log_prob, nan=np.nanmean(log_prob)).squeeze()

print(N)

1000000


In [18]:

# Rank the examples of `dataset` by `sort_by` and save the resulting ordering.
# Uses `d`, `N`, `text_embedding`, `log_prob` (loading cell) and `save_dir` (config cell).
# Score-based criteria produce a score array `S` (saved in both sort directions);
# 'random*' and 'dpp*' produce an explicit index order `inds`.
sort_by = 'random_s=2'

# Per-example scalar scores read directly from the cached LM output `d`.
SCORE_KEYS = [
    'log_prob',
    'el2n_agg=mean',
    'el2n_agg=l2n',
    'logit_margin',
    'grad_loraB_l2n',
    'grad_all_l2n',
    'grad_qkv_l2n',
    'grad_mlp_l2n',
    'grad_last_l2n',
]

t0 = time.time()
if sort_by in SCORE_KEYS:
    # Some scores are NaN; impute them with the mean score before sorting.
    S = np.nan_to_num(d[sort_by], nan=np.nanmean(d[sort_by])).squeeze()
elif sort_by.startswith('random'):
    match = re.search(r's=(\d+)', sort_by)
    if match is None:
        raise ValueError(f'Expected a seed such as "random_s=0", got sort_by={sort_by!r}')
    seed = int(match.group(1))
    inds = list(range(N))
    # A local Random(seed) gives the same permutation as random.seed(seed) followed by
    # random.shuffle, without resetting the global `random` state for other cells.
    random.Random(seed).shuffle(inds)
elif sort_by.startswith('kmeans'):
    dist_fn = 'l2' if sort_by.startswith('kmeansl2') else 'cd'
    match = re.search(r'(?<=\=)\d+', sort_by)
    n_clusters = int(match.group()) if match else None
    S = sort_kmeans_dist_to_cluster_centers(text_embedding, n_clusters, dist_fn=dist_fn)
elif sort_by.startswith('dpp'):
    match = re.search(r'k=(\w+)', sort_by)
    kernel_type = match.group(1) if match else None
    inds = sort_dpp_map(text_embedding, log_prob, kernel_type=kernel_type)
else:
    # Without this, an unknown criterion would fall through silently and the save
    # step below would write a stale `S` / `inds` left over from an earlier run.
    raise ValueError(f'Unsupported sort_by={sort_by!r}')
t1 = time.time()
print(f'Rank datapoints with {sort_by} took {t1-t0:.2f} seconds.')

if any(sort_by.startswith(x) for x in ['dpp', 'random']):
    save_to_pickle(
        save_path=os.path.join(save_dir, f'{sort_by}.pkl'),
        output={'inds': inds})
else:
    save_sorted_inds(save_dir, S, sort_by, reverse=False)
    save_sorted_inds(save_dir, S, sort_by, reverse=True)

Rank datapoints with random_s=2 took 1.07 seconds.
save inds (length = 1000000) to /gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/data_inds/llama-7b+lora:r=256:a=256/flan2022_1m/random_s=2.pkl


[(145659, 552866),
 (826371, 261831),
 (431768, 505341),
 (822321, 20384),
 (402, 257637),
 (295572, 320721),
 (392523, 613443),
 (546540, 265513),
 (73810, 695315),
 (341544, 409663),
 (43184, 23722),
 (121847, 609697),
 (249174, 872026),
 (471953, 628022),
 (997753, 847615),
 (13161, 855738),
 (556097, 91765),
 (672306, 955575),
 (158228, 329080),
 (567277, 471333),
 (375384, 768796),
 (958220, 356521),
 (477198, 905235),
 (44901, 957753),
 (794131, 848508),
 (281887, 925108),
 (493409, 440672),
 (311079, 957097),
 (26662, 487193),
 (242831, 176813),
 (704039, 34575),
 (597049, 591496),
 (846875, 580703),
 (691301, 9042),
 (613094, 34262),
 (271036, 461601),
 (772065, 305721),
 (733423, 834198),
 (91220, 805294),
 (551477, 569507),
 (360449, 584360),
 (720887, 195381),
 (785226, 13153),
 (469446, 547379),
 (619490, 699900),
 (60110, 631617),
 (987622, 833445),
 (910765, 887740),
 (160139, 106438),
 (372608, 225119),
 (645936, 458641),
 (485672, 771946),
 (995665, 850643),
 (34055, 41

In [15]:
a = np.random.rand(10000,4096).astype(np.float32)
b = np.random.rand(10000,4096).astype(np.float32)
%timeit np.sum(a*b,axis=-1)

55.1 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
a = np.random.rand(10000,4096).astype(np.float64)
b = np.random.rand(10000,4096).astype(np.float64)
%timeit np.sum(a*b,axis=-1)

92.3 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
