In [1]:
# Automatically reload edited Python modules (e.g. the local `llama` package)
# before executing each cell, so source edits are picked up without a restart.
%reload_ext autoreload
%autoreload 2

In [2]:
import os
from llama import Workflow, Llama

# Single-process rendezvous settings. Workflow.build initialises model-parallel /
# ddp / pipeline groups of size 1 (see the "> initializing ..." log below), which
# presumably read these variables from the environment.
os.environ.update({
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "29502",
})

LLAMA_8B_DIR = '/scratch4/jeisner1/tjbai/llama_8b'
LLAMA_8B_TOKENIZER = '/scratch4/jeisner1/tjbai/llama_8b/tokenizer.model'

# LoRA adapter hyper-parameters passed straight through to Workflow.build.
LORA_KWARGS = dict(
    use_lora=True,
    lora_rank=64,
    lora_alpha=32,
    lora_dropout=0.05,
)

workflow = Workflow.build(
    ckpt_dir=LLAMA_8B_DIR,
    tokenizer_path=LLAMA_8B_TOKENIZER,
    max_seq_len=4 * 8192,
    max_batch_size=1,
    model_parallel_size=1,
    max_nodes=100,
    **LORA_KWARGS,
)

# Plain generation wrapper sharing the workflow's model and tokenizer.
llama = Llama(workflow.model, workflow.tokenizer)



> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Converting to LoRA
Loaded in 13.40 seconds


In [4]:
from llama.util import load_ckpt

# LoRA checkpoint saved by the BSM trainer at step 104.
# NOTE(review): every later cell that runs after this one (baseline generation,
# trainer.step) presumably uses these fine-tuned weights — confirm that is intended.
FT_CKPT_PATH = '/scratch4/jeisner1/tjbai/checkpoints/bsm/lora_step-104.pt'
load_ckpt(workflow, FT_CKPT_PATH)

In [8]:
from llama.workflows.bsm import bsm_cached, load_concepts

# `concepts_list` used to be defined only in a later cell (In [18]), so this cell
# failed with NameError on Restart & Run All. Load the CommonGen validation
# concepts here, before their first use.
concepts_list = load_concepts(
    data_path='/home/tbai4/llama3/data/commongen/commongen.jsonl',
    split='val',
)

# Run the cached BSM workflow (compact mode) on the second concept set.
outputs = bsm_cached(
    workflow=workflow,
    concepts=concepts_list[1],
    compact=True,
)

In [4]:
from llama.workflows.bsm import bsm_baseline

baseline_concepts = concepts_list[0]
baseline_outputs = bsm_baseline(workflow=workflow, concepts=baseline_concepts)

# One training example, in the shape passed to trainer.step(...) further below.
sample = dict(
    inputs=dict(concepts=baseline_concepts),
    outputs=baseline_outputs,
)

In [None]:
# Decode and print each merge-stage token sequence from the baseline run.
decode = workflow.tokenizer.decode
for merge_ids in baseline_outputs['merge_tokens']:
    print(decode(merge_ids))

In [10]:
from llama.workflows.finetune import BsmTrainer

BSM_CKPT_DIR = '/scratch4/jeisner1/tjbai/checkpoints/bsm/'
LEARNING_RATE = 1e-5

# The printed summary reports ~54.5M of 8.1B parameters as trainable,
# presumably just the LoRA adapter weights — verify in BsmTrainer.
trainer = BsmTrainer(
    workflow=workflow,
    learning_rate=LEARNING_RATE,
    output_dir=BSM_CKPT_DIR,
)

Training 54.5M / 8.1B parameters


In [11]:
# One optimisation step on the single `sample`. The displayed result is a
# (total_loss, {'train/branch_loss', 'train/solve_loss', 'train/merge_loss'})
# pair; in the recorded output the total equals the sum of the three stage losses.
trainer.step(sample)

(tensor(2.9656, device='cuda:0', dtype=torch.float32, grad_fn=<AddBackward0>),
 {'train/branch_loss': tensor(0.4616, device='cuda:0', dtype=torch.float32,
         grad_fn=<NllLossBackward0>),
  'train/solve_loss': tensor(0.8012, device='cuda:0', dtype=torch.float32,
         grad_fn=<NllLossBackward0>),
  'train/merge_loss': tensor(1.7028, device='cuda:0', dtype=torch.float32,
         grad_fn=<NllLossBackward0>)})

In [2]:
import json

def clean(ls):
    """Drop everything up to and including the first blank line of each story.

    The text before the first "\\n\\n" is presumably a title/preamble emitted
    by the generator (not verifiable here). Stories that contain no blank line
    are returned unchanged. The previous ``str.index`` version raised
    ValueError on such a story and aborted cleaning of the whole list.

    Args:
        ls: iterable of story strings.

    Returns:
        A list of cleaned stories, in the same order and with the same length as ``ls``.
    """
    cleaned = []
    for story in ls:
        _preamble, sep, body = story.partition('\n\n')
        # `sep` is empty only when no blank line exists; keep the story intact then.
        cleaned.append(body if sep else story)
    return cleaned

PREFT_EVAL_PATH = '/home/tbai4/llama3/dumps/bsm/preft_eval.json'

# Pre-fine-tuning evaluation dump: one story list per workflow variant.
with open(PREFT_EVAL_PATH) as f:
    data = json.load(f)['raw_data']

baseline, cached, cached_compact = (
    clean(data[variant]['stories'])
    for variant in ('baseline', 'cached', 'cached_compact')
)

In [18]:
from llama.workflows.bsm import compare_stories, load_concepts

COMMONGEN_PATH = '/home/tbai4/llama3/data/commongen/commongen.jsonl'
concepts_list = load_concepts(data_path=COMMONGEN_PATH, split='val')

# In each comparison the first story list is presumably "a" in the reported
# a_wins / b_wins counts.
# Baseline vs. cached (non-compact).
_, a_results = compare_stories(baseline, cached, concepts_list)
print(json.dumps(a_results, indent=2))

# Baseline vs. cached (compact).
_, b_results = compare_stories(baseline, cached_compact, concepts_list)
print(json.dumps(b_results, indent=2))

Comparing: 100%|██████████| 50/50 [03:18<00:00,  3.97s/it]

{
  "a_wins": 28,
  "b_wins": 6,
  "ties": 16,
  "errors": 0,
  "total": 50,
  "a_win_percent": 56.00000000000001,
  "b_win_percent": 12.0,
  "tie_percent": 32.0
}





In [19]:
# NOTE: rebinds `data` (previously the pre-fine-tuning raw_data dict) to the
# per-checkpoint evaluation results, keyed by checkpoint name.
CKPT_RESULTS_PATH = '/home/tbai4/llama3/dumps/bsm/checkpoints/all_checkpoint_results.json'
with open(CKPT_RESULTS_PATH) as f:
    data = json.loads(f.read())

In [25]:
# Baseline vs. stories generated with the step-104 LoRA checkpoint.
checkpoint_stories = data['lora_step-104']['raw_data']['stories']
postft = clean(checkpoint_stories)

_, c_results = compare_stories(baseline, postft, concepts_list)
print(json.dumps(c_results, indent=2))

Comparing: 100%|██████████| 50/50 [03:27<00:00,  4.15s/it]

{
  "a_wins": 15,
  "b_wins": 14,
  "ties": 21,
  "errors": 0,
  "total": 50,
  "a_win_percent": 30.0,
  "b_win_percent": 28.000000000000004,
  "tie_percent": 42.0
}



