In [2]:
import os, sys, json, random, io, pytz, argparse, re
sys.path.append("../scripts/llama/")
os.environ['CUDA_VISIBLE_DEVICES'] = "3"
os.environ['HF_HOME'] = '/data/yingshac/hf_cache'

import numpy as np
from tqdm import tqdm
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
from peft import (
        get_peft_model, 
        prepare_model_for_kbit_training, 
        LoraConfig, 
        PeftModel,
        AutoPeftModelForCausalLM
    )
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from pprint import pprint
from config import FinetuningConfig
config = FinetuningConfig()

  from .autonotebook import tqdm as notebook_tqdm


[2024-03-30 22:13:19,501] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)




In [3]:
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [48]:
prompt_path = "../data/feed_decoder_LM/regular/mode_2nd/q_prompt.txt"
with open(prompt_path, "r") as f: prompt_template = "\n".join(f.readlines()).strip()
def preprocess(prompt_template, input_str, answer=None, eos_token="</s>"):
    prompt = prompt_template.format(input_str)
    response = f"{' ' + str(answer) + '.' + eos_token if answer else ''} "
    text = "### Question: {}\n ### Answer:{}".format(prompt, response) #(" ").join([prompt, response])
    return text

def formatting_func(instance):
    global prompt_template
    output = []
    for d, s in zip(instance["input_str"], instance["answer"]):
        op = preprocess(prompt_template, d, s)
        output.append(op)
    return output

In [6]:
bnb_config = BitsAndBytesConfig(
            load_in_4bit=False,
            load_in_8bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

In [4]:
model_name = "meta-llama/Llama-2-7b-hf" 
my_hf_token = "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"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=my_hf_token)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True # Turn-off an annoying warning msg
tokenizer.add_special_tokens({"pad_token":"<pad>"})
tokenizer.padding_side = 'right'

# model = LlamaForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.float16,
#     device_map='auto', #{"": 0},
#     token=my_hf_token,
#     quantization_config=bnb_config,
# )
# model.resize_token_embeddings(len(tokenizer))


In [49]:
instance = {
    "input_str": "000000010101010",
    "answer": "+"
}
prompt = preprocess(prompt_template, instance["input_str"], instance["answer"])
print(json.dumps(prompt))
tokenized_full_prompt = tokenizer(
    prompt,
    truncation=True,
    max_length=config.max_seq_length,
    padding=False,
    return_tensors=None,
)
print(len(tokenized_full_prompt['input_ids']))

user_prompt = preprocess(prompt_template, instance["input_str"], None)
print(json.dumps(user_prompt))
tokenized_user_prompt = tokenizer(
        user_prompt,
        truncation=True,
        max_length=config.max_seq_length,
        padding=False,
        return_tensors=None,
    )
user_prompt_len = len(tokenized_user_prompt["input_ids"]) - 1
print(user_prompt_len)
#tokenized_full_prompt['input_ids'][user_prompt_len:]

"### Question: What is the second most frequent digit in the string \"000000010101010\"?\n ### Answer: +.</s> "
39
"### Question: What is the second most frequent digit in the string \"000000010101010\"?\n ### Answer: "
35


In [50]:
tokenized_full_prompt["input_ids"][user_prompt_len:]

[718, 29889, 2, 259]

In [51]:
tokenizer.convert_ids_to_tokens(tokenized_full_prompt["input_ids"][user_prompt_len:])

['▁+', '.', '</s>', '▁▁']

In [41]:
tokenizer.decode(tokenized_user_prompt['input_ids'])

'<s> ### Question: What is the second most frequent digit in the string "000000010101010"?\n ### Answer: '

In [5]:
data_dir = "../data/finetune/parity/uniform_split/"
train_data = load_dataset(data_dir, split="train")
val_data = load_dataset(data_dir, split="validation")
i = random.choice(list(range(len(val_data))))
print(preprocess(prompt_template, val_data[i]["input_str"], val_data[i]["answer"]))

### Question: What is the second most frequent digit in the string "1010010000100000110001000110000000001000000010000010000000000010100000000000000000"?
 ### Answer: +.</s> 


In [8]:
len(tokenizer.tokenize(
    "### Question: What is the second most frequent digit in the string \"554506024449650800083800583000208380000080376704044806346308300442808989030870058405084723440788644646898654442003804648230488838\"?"
    ))

144

In [6]:
tokenizer.tokenize("00010093587837472")

['▁',
 '0',
 '0',
 '0',
 '1',
 '0',
 '0',
 '9',
 '3',
 '5',
 '8',
 '7',
 '8',
 '3',
 '7',
 '4',
 '7',
 '2']

In [6]:
lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

print(f"Total parameters: {model.num_parameters()}")
print(f"Trainable parameters: {model.num_parameters(only_trainable=True)}")


Total parameters: 6746812416
Trainable parameters: 8388608


In [7]:
training_args = transformers.TrainingArguments(
            output_dir=os.path.join(config.output_dir, config.date, "ckpts"),
            per_device_train_batch_size=config.per_device_train_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            optim=config.optim,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=config.learning_rate,
            report_to="none",
            logging_steps=config.logging_steps, # Number of update steps between two logs if logging_strategy="steps". Should be an integer or a float in range [0,1). If smaller than 1, will be interpreted as ratio of total training steps.
            #eval_steps=config.eval_steps,
            max_grad_norm=config.max_grad_norm,
            num_train_epochs=config.num_train_epochs,
            warmup_steps=config.warmup_steps,
            group_by_length=True,
            lr_scheduler_type="constant_with_warmup",
            do_eval=True,
            ddp_find_unused_parameters=False,
            eval_accumulation_steps=config.eval_accumulation_steps,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
        )

In [8]:
response_template_with_context = "\n ### Answer:"  # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer(
    response_template_with_context,
    add_special_tokens=False,
).input_ids[2:]
print(f"response template tokens: {tokenizer.convert_ids_to_tokens(response_template_ids)}")
collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

response template tokens: ['▁###', '▁Answer', ':']


In [9]:
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    em, count = 0, 0
    with open(os.path.join(config.output_dir, config.date, "eval_samples", f'{datetime.now(timezone).strftime("%m%d_%H%M%S")}.txt'), "w") as f:
        for pred, label in zip(preds, labels):
            # data collator would assign -100 to tokens who don't require loss calcupation
            p = 0
            while True:
                if label[p] != -100:
                    break
                p += 1
            
            a = [i for i in label[p:] if i != -100]
            o = [i for i in pred[p-1:] if i != -100]

            gth_response = tokenizer.decode(a, skip_special_tokens=True)
            pred_response = tokenizer.decode(o, skip_special_tokens=True)
            
            f.write(json.dumps([gth_response, pred_response])+"\n")
            
            gth_len = int(re.findall(r'(\d+)', gth_response)[0])
            pred_len = int(re.findall(r'(\d+)', pred_response)[0])
            
            em += int(gth_len == pred_len)
            count += 1
    return {'accuracy': em/count}


trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    peft_config=lora_config,
    formatting_func=formatting_func,
    data_collator=collator,
    max_seq_length=config.max_seq_length,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    args=training_args
)

Map: 100%|██████████| 800/800 [00:00<00:00, 33656.07 examples/s]


In [18]:
pprint(list(trainer.model.state_dict().keys()))


['base_model.model.model.embed_tokens.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight',
 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight',
 'base_model.model.model.layers.0.self_attn.k_proj.weight',
 'base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight',
 'base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight',
 'base_model.model.model.layers.0.self_attn.v_proj.weight',
 'base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight',
 'base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight',
 'base_model.model.model.layers.0.self_attn.o_proj.weight',
 'base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight',
 'base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight',
 'base_model.model.model.layers.0.mlp.gate_proj.weight',
 'base_model.model.model.layers.0.mlp.up_proj.weight',
 

In [12]:
pprint(trainer.model.peft_config['default'])

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='meta-llama/Llama-2-7b-hf', revision=None, task_type='CAUSAL_LM', inference_mode=False, r=8, target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], lora_alpha=8, lora_dropout=0.1, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)


In [13]:
from peft.utils import get_peft_model_state_dict


In [37]:
input_text = preprocess(prompt_template, data[790]["input_str"])
input_tokens = tokenizer(input_text, return_tensors="pt")["input_ids"].to("cuda")
with torch.cuda.amp.autocast():
  generation_output = model.generate(
      input_ids=input_tokens,
      max_new_tokens=10,
      do_sample=False,
      #do_sample=True,
      #top_k=None,
      #top_p=0.9,
      #temperature=0,
      #repetition_penalty=1.15,
      num_return_sequences=1,
      eos_token_id=tokenizer.eos_token_id,
    )
op = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(op)

### Question: What is length of the string "0572674832"?
 ### Answer:  
```
12
```

##


In [23]:
for name, module in peft_model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

In [39]:
input_text = preprocess(prompt_template, data[790]["input_str"])
input_tokens = tokenizer(input_text, return_tensors="pt")["input_ids"].to("cuda")
with torch.cuda.amp.autocast():
  generation_output = peft_model.generate(
      input_ids=input_tokens,
      max_new_tokens=5,
      do_sample=False,
      #top_k=10,
      #top_p=0.1,
      #temperature=0.3,
      #repetition_penalty=1.15,
      #num_return_sequences=1,
      eos_token_id=tokenizer.eos_token_id,
    )
op = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(op)



<s> ### Question: What is length of the string "0572674832"?
 ### Answer:  38.</s>


In [37]:
len("070193108739725")

15

### Sanity Check: It's indeed deterministic

In [20]:
import os, json
output_dir = "../data/output_decoder_LM/llama2-7b/len/zeroshot/"
ANS = []
for f in os.listdir(output_dir):
    if "1113_" in f:
        ANS.append([])
        lines = open(os.path.join(output_dir, f), "r").readlines()
        for l in lines:
            ANS[-1].append(json.loads(l)[-1].split("\n ### Answer:")[-1].strip())
print(len(ANS))

3


In [22]:
for i, j in zip(ANS[1], ANS[2]):
    if not i == j:
        print(i)
        print(j)

### Draft

In [7]:
# Concatenate datasets
from datasets import load_dataset, concatenate_datasets
train_data1 = load_dataset("../data/finetune/mode/uniform_split", split="train")
train_data2 = load_dataset("../data/finetune/mode/uniform_hard+_split", split="train")

dataset_cc = concatenate_datasets([train_data1, train_data2])


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 5343.06it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 46.68it/s]
Generating train split: 12300 examples [00:00, 427107.93 examples/s]
Generating validation split: 12300 examples [00:00, 500300.04 examples/s]
Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 6246.17it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 46.47it/s]
Generating train split: 12300 examples [00:00, 482916.21 examples/s]
Generating validation split: 12300 examples [00:00, 523803.59 examples/s]


In [14]:
len(dataset_cc)

24600

### Move ckpts to /data

In [6]:
import os
from tqdm import tqdm

In [7]:
output_dir = "../scripts/llama/output"
for exp_handle in tqdm(os.listdir(output_dir)):
    if "ckpts" not in os.listdir(os.path.join(output_dir, exp_handle)): continue
    ckpt_dir = f"/data/yingshac/llms_do_math/scripts/llama/output/{exp_handle}"
    #os.makedirs(ckpt_dir)
    command = f"mv {os.path.join(output_dir, exp_handle)}/ckpts {ckpt_dir}"
    #print(command)
    os.system(command)

100%|██████████| 19/19 [10:16<00:00, 32.43s/it]
