In [47]:
from rosemary import jpt_in_notebook
from llm.submit import submit_job, multiline_to_singleline

shell_scripts_template = """
echo "Running on $SLURM_JOB_NODELIST"
echo "======"

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
master_port=10002
RDZV_ENDPOINT=$master_addr:$master_port

source ~/.profile
conda activate open-instruct
cd /gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/open-instruct/scripts

set -e
set -x
echo "======"
srun {cmd}

[ ! -f "{log_dir}/$SLURM_JOB_ID*.out" ] && mv {log_dir}/$SLURM_JOB_ID*.out {save_dir}
"""
nodes = 1; gpus=6

test_run = 0
test_run = bool(test_run)


model_name = 'llama-7b'; model_name_or_path = '../results/baselines/huggyllama/llama-7b'
# model_name = 'llama-7b_ft=hmv1'; model_name_or_path = '../results/ft1/llama-7b_humanmix'

save_dir = f"/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outputs/{model_name}"
log_dir = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/'

# datasets = ['cot', 'dolly', 'flan_v2', 'lima', 'oasst1']; nodes = 1; gpus=6; cpu_mem = 512
# datasets = ['tulu_v1_human_mix', 'tulu_v2_human_mix']; nodes = 1; gpus=1; cpu_mem = 64
datasets = ['flan_v2']; nodes = 1; gpus=6; cpu_mem = 512
# datasets = ['flan2022_1m']; nodes = 2; gpus=6; cpu_mem = 512

## for testing
# datasets = ['lima']; nodes = 1; gpus=2; cpu_mem = 512



for dataset in datasets:
    cmd = f"""
    torchrun --nnodes={nodes} --nproc_per_node={gpus} \
        --rdzv-id=$SLURM_JOB_ID --rdzv-backend=c10d --rdzv-endpoint=$RDZV_ENDPOINT \
        note_llama_embeddings.py \
        --dataset {dataset} \
        --model_name_or_path {model_name_or_path} \
        --save_dir {save_dir} \
        --use_dist \
        --shuffle
    """
    cmd = multiline_to_singleline(cmd)

    shell_scripts = shell_scripts_template.format(
        cmd=cmd, log_dir=log_dir, save_dir=save_dir)
    out = submit_job(
        shell_scripts, 
        job_name=f'LM_outputs.{dataset}', 
        nodes=nodes,
        num_cpus=32,
        cpu_mem=cpu_mem,
        num_gpus=gpus,
        gpu_type='v100',
        test_run=test_run,
        job_duration=6,
    )
    print(cmd)
    if not test_run:
        print(out)


Submiting job with:
{
    "job_name": "LM_outputs.flan_v2",
    "nodes": 1,
    "num_cpus": 32,
    "cpu_mem": 512,
    "num_gpus": 6,
    "gpu_type": "v100",
    "test_run": false,
    "queue": "el8",
    "num_jobs": 1
}
torchrun --nnodes=1 --nproc_per_node=6 --rdzv-id=$SLURM_JOB_ID --rdzv-backend=c10d --rdzv-endpoint=$RDZV_ENDPOINT note_llama_embeddings.py --dataset flan_v2 --model_name_or_path ../results/baselines/huggyllama/llama-7b --save_dir /gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outputs/llama-7b --use_dist --shuffle
[{'args': 'sbatch --job-name=LM_outputs.flan_v2 --partition=el8 --nodes=1 --ntasks-per-node=1 --cpus-per-task=32 --mem=512GB --gres=gpu:6 --output=/gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/open-instruct/scripts/%J.out --time=6:00:00 /gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/open-instruct/scripts/tmp1z28kjtb', 'job_id': 994464}]


In [None]:
p = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outputs/llama-7b/lima.pkl'
with open(p, 'rb') as f:
    x = pickle.load(f)
x

In [None]:

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1,2,figsize=(10,5))

ax = axs[0]
ax.plot(np.exp(x['log_probs']), label='probs')
ax.plot(x['el2ns'], label='el2n')
ax.legend()

ax = axs[1]
ax.scatter(np.exp(x['log_probs']), x['el2ns'])
ax.set_xlabel('prob')
ax.set_ylabel('el2n')


In [None]:
import os

processed_dir = '../data/processed'
datasets = []
for dataset in os.listdir(processed_dir) + ['tulu_v1_human_mix', 'tulu_v2_human_mix']:
    dataset_path = os.path.join(processed_dir, dataset)
    save_path = os.path.join(save_dir, f'{dataset}.pkl')
    if 'tulu'==dataset:
        continue
    if 'tulu' not in dataset and not os.path.isdir(dataset_path):
        continue
    if os.path.isfile(save_path):
        continue
    datasets.append(dataset)
    
datasets
    

In [2]:
from rosemary import jpt_parse_args, jpt_setup, jpt_in_notebook; jpt_setup()

if jpt_in_notebook():
    import os
    
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
#     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from collections import defaultdict
from functools import partial
import os
import numpy as np
import time
import random
import pickle
from tqdm import tqdm 

import pyarrow # import before `torch`, `transformers`, `datasets`
import torch
from torch.utils.data import DataLoader

from datasets import load_dataset

from transformers import AutoModelForCausalLM, AutoTokenizer

from open_instruct.finetune_trainer import encode_with_prompt_completion_format, encode_with_messages_format
from note_llama_embeddings import combine_lm_outputs_for_mixes, datasets_shard_chunk_size, compute_losses


[2023-10-05 22:52:37,807] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [11]:
test_run = True
shuffle = False

model_name = 'llama-7b'; model_name_or_path = '../results/baselines/huggyllama/llama-7b'
model_name = 'pythia-1.4b'; model_name_or_path = '../results/baselines/EleutherAI/pythia-1.4b'

save_dir = f"/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/open-instruct/scripts/model_outputs/{model_name}"
os.makedirs(save_dir, exist_ok=True)

In [12]:
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map='cuda:0',
    torch_dtype=torch.float16)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path, use_fast=True)
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

Using pad_token, but it is not set yet.


In [13]:
dataset = 'lima'
# dataset = 'flan_v2'

use_dist = False
shuffle = True


if dataset in ['tulu_v1_human_mix', 'tulu_v2_human_mix']:
    combine_lm_outputs_for_mixes(dataset, save_dir)


if use_dist:
    dist.init_process_group("gloo", timeout=datetime.timedelta(hours=6))
    world_size = dist.get_world_size()
    rank = dist.get_rank() # global rank
    local_rank = int(os.environ["LOCAL_RANK"])
else:
    rank = 0
    local_rank = 0
    world_size = 1

print(f'rank/local_rank/world_size: {rank}/{local_rank}/{world_size}\n')

device = f'cuda:{str(local_rank)}'


rank/local_rank/world_size: 0/0/1



In [14]:
processed_dir = '../data/processed'
if 'flan2022' in dataset:
    train_file = os.path.join(processed_dir, 'flan2022', f'{dataset}_data.jsonl')
else:
    train_file = os.path.join(processed_dir, dataset, f'{dataset}_data.jsonl')
assert(os.path.isfile(train_file))


encode_function = partial(
    encode_with_messages_format, tokenizer=tokenizer, max_seq_length=2048)

if rank == 0:
    raw_datasets = load_dataset("json", data_files={'train': train_file})
    # if test_run:
    #     raw_datasets['train'] = raw_datasets['train'].select(range(100))
    print(f"{dataset} dataset length = {len(raw_datasets['train'])}")
    lm_datasets = raw_datasets.map(
        encode_function, batched=False, num_proc=16,
        desc="Tokenizing and reformatting instruction data")
if use_dist:
    dist.barrier()
if rank!= 0:
    raw_datasets = load_dataset("json", data_files={'train': train_file})
    # if test_run:
    #     raw_datasets['train'] = raw_datasets['train'].select(range(100))
    print(f"{dataset} dataset length = {len(raw_datasets['train'])}")
    lm_datasets = raw_datasets.map(
        encode_function, batched=False, num_proc=16,
        desc="Tokenizing and reformatting instruction data")

    

train_dataset = lm_datasets['train']
train_dataset.set_format(
    type="torch",
    output_all_columns=False,
    columns=['input_ids', 'labels', 'attention_mask'])
if shuffle:
    random.seed(0)
    shuffle_inds = list(range(len(train_dataset)))
    random.shuffle(shuffle_inds)
    reverse_shuffle_inds = [(i, ind) for i, ind in enumerate(shuffle_inds)]
    reverse_shuffle_inds = sorted(reverse_shuffle_inds, key=lambda x: x[1])
    reverse_shuffle_inds = [x[0] for x in reverse_shuffle_inds]
    train_dataset = train_dataset.select(shuffle_inds)
train_dataset_chunk_sizes = [datasets_shard_chunk_size(len(train_dataset), num_shards=world_size, index=i) 
            for i in range(world_size)]
train_dataset = train_dataset.shard(
    num_shards=world_size, 
    index=rank,
    contiguous=True)
loader = DataLoader(train_dataset, shuffle=False, batch_size=1, pin_memory=True) 

Found cached dataset json (/gpfs/u/scratch/PTFM/PTFMqngp/huggingface_cache/datasets/json/default-1ca1bac0eed76345/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

lima dataset length = 1030


Tokenizing and reformatting instruction data (num_proc=16):   0%|          | 0/1030 [00:00<?, ? examples/s]

In [30]:

output = defaultdict(list)
for batch in tqdm(loader, disable=rank!=0, total=len(loader)):
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
#     with torch.inference_mode():
    outputs = model(**batch, output_hidden_states=True)

    # (bsz, seq_len, hidden_size) -> (bsz, hidden_size)
    text_embedding = outputs['hidden_states'][-1].mean(1)
    # average of output token log probs
    log_prob = -outputs['loss']
    # compute EL2N score
    losses = compute_losses(outputs['logits'], batch['labels'])

    output['text_embedding'].append(text_embedding.detach().cpu().to(torch.float32))
    output['log_prob'].append(log_prob.detach().cpu())
    for k in ['el2n_agg=mean', 'el2n_agg=l2n', 'logit_margin']:
        output[k].append(losses[k].detach().cpu())

    break
    
for k, v in output.items():
    output[k] = torch.vstack(v).to(torch.float32).numpy()

  0%|          | 0/1030 [00:00<?, ?it/s]


In [None]:

print(f'[local_rank/global={local_rank}/{rank}] '
      f'output={[(k, v.shape) for k, v in output.items()]}')


In [40]:


# outputs['loss'].backward()
# logits[0, class_idx].backward(retain_graph=True)

# bsz, seq, dim
outputs['logits'][0,0,1000].backward(retain_graph=True)




In [41]:


grads = []

for param_name, param in model.named_parameters():
    if param.requires_grad and param.grad is not None:
        print(f"Gradient for {param_name}: {param.grad.shape}")
        grads.append(param.grad.squeeze())


Gradient for gpt_neox.embed_in.weight: torch.Size([50304, 2048])
Gradient for gpt_neox.layers.0.input_layernorm.weight: torch.Size([2048])
Gradient for gpt_neox.layers.0.input_layernorm.bias: torch.Size([2048])
Gradient for gpt_neox.layers.0.post_attention_layernorm.weight: torch.Size([2048])
Gradient for gpt_neox.layers.0.post_attention_layernorm.bias: torch.Size([2048])
Gradient for gpt_neox.layers.0.attention.query_key_value.weight: torch.Size([6144, 2048])
Gradient for gpt_neox.layers.0.attention.query_key_value.bias: torch.Size([6144])
Gradient for gpt_neox.layers.0.attention.dense.weight: torch.Size([2048, 2048])
Gradient for gpt_neox.layers.0.attention.dense.bias: torch.Size([2048])
Gradient for gpt_neox.layers.0.mlp.dense_h_to_4h.weight: torch.Size([8192, 2048])
Gradient for gpt_neox.layers.0.mlp.dense_h_to_4h.bias: torch.Size([8192])
Gradient for gpt_neox.layers.0.mlp.dense_4h_to_h.weight: torch.Size([2048, 8192])
Gradient for gpt_neox.layers.0.mlp.dense_4h_to_h.bias: torch.Si

In [42]:
grads = [x.reshape(-1,1) for x in grads]
grads = torch.vstack(grads)
grads.shape, torch.sum(grads*grads)

(torch.Size([1414647808, 1]),
 tensor(inf, device='cuda:0', dtype=torch.float16))

In [43]:
grads

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.27 GiB (GPU 0; 31.75 GiB total capacity; 22.86 GiB already allocated; 4.03 GiB free; 26.80 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [24]:

# Compute gradients for all logits at once
class_indices = torch.arange(logits_all_classes.size(1))  # Assuming logits are in the second dimension
logit_gradients = torch.autograd.grad(outputs=logits_all_classes[:, class_indices], inputs=model.parameters(), retain_graph=True)


0.5625

In [None]:
from note_llama_embeddings import compute_losses

labels = batch['labels']
logits = outputs['logits']


compute_losses(logits, labels)

In [None]:

if logits.shape[0]!=1:
    raise ValueError('compute_el2n supports bsz=1 only.')
vocab_size = logits.shape[-1]
device = logits.device
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
# (Bsz*|Seq|, |Vocab|)
shift_logits = shift_logits.view(-1, vocab_size)
shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)
shift_labels = shift_labels.view(-1)
# only compute loss on the output tokens
output_tok_indices = (shift_labels != -100).nonzero().reshape(-1)
shift_labels = shift_labels[output_tok_indices]
shift_probs = shift_probs[output_tok_indices]
shift_logits = shift_logits[output_tok_indices]
# Enable model parallelism
shift_labels = shift_labels.to(device)

losses = {}
# Compute EL2N = || prob - one-hot-label ||_2
shift_probs_minus_onehot_target = shift_probs.clone()
shift_probs_minus_onehot_target[torch.arange(shift_probs.size(0)), shift_labels] -= 1
loss_tokenwise = torch.linalg.norm(shift_probs_minus_onehot_target, dim=-1, ord=2)
losses['el2n_agg=mean'] = loss_tokenwise.mean()
losses['el2n_agg=l2n'] =  torch.linalg.norm(loss_tokenwise, ord=2)
# Classification logit margin
shift_logits_true = torch.gather(shift_logits, 1, shift_labels.view(-1, 1)).squeeze()
shift_logits_other = shift_logits.clone()
shift_logits_other[torch.arange(shift_logits.size(0)), shift_labels] = float('-inf')
shift_logits_other_max, _ = torch.max(shift_logits_other, 1)
losses['logit_margin'] = (shift_logits_true-shift_logits_other_max).mean()


# Compute

In [None]:
i = 0
output = {k: [] for k in output_keys}
for batch in tqdm(loader, disable=rank!=0, total=len(loader)):
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    with torch.inference_mode():
        outputs = model(**batch, output_hidden_states=True)


    labels = batch['labels']
    logits = outputs['logits']



    if logits.shape[0]!=1:
        raise ValueError('compute_el2n supports bsz=1 only.')
    vocab_size = logits.shape[-1]
    device = logits.device
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    # (Bsz*|Seq|, |Vocab|)
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)
    shift_labels = shift_labels.view(-1)
    # only compute loss on the output tokens
    output_tok_indices = (shift_labels != -100).nonzero().reshape(-1)
    shift_labels = shift_labels[output_tok_indices]
    shift_probs = shift_probs[output_tok_indices]
    shift_logits = shift_logits[output_tok_indices]
    # Enable model parallelism
    shift_labels = shift_labels.to(device)


    losses = {}
    # Compute EL2N = || prob - one-hot-label ||_2
    shift_probs_minus_onehot_target = shift_probs.clone()
    shift_probs_minus_onehot_target[torch.arange(shift_probs.size(0)), shift_labels] -= 1
    loss_tokenwise = torch.linalg.norm(shift_probs_minus_onehot_target, dim=-1, ord=2)
    losses['el2n_agg=mean'] = loss_tokenwise.mean()
    losses['el2n_agg=l2n'] =  torch.linalg.norm(loss_tokenwise, ord=2)

    shift_logits_true = torch.gather(shift_logits, 1, shift_labels.view(-1, 1)).squeeze()
    shift_logits_other = shift_logits.clone()
    shift_logits_other[torch.arange(shift_logits.size(0)), shift_labels] = float('-inf')
    shift_logits_other_max, _ = torch.max(shift_logits_other, 1)
    losses['logit_margin'] = (shift_logits_true-shift_logits_other_max).mean()
    
    if i == 3:
        break
    i += 1

In [None]:
shift_logits_other_max.max()

In [None]:
# try KL(prob_true, prob_nottrue) over tokens in sequence. 