In [1]:
# !pip install scipy -i https://mirrors.aliyun.com/pypi/simple/
# !pip install transformers==4.33.3 -i https://mirrors.aliyun.com/pypi/simple/

In [1]:
import warnings 
warnings.filterwarnings("ignore")
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter_TTS, Prompter_TTS_Phonemes

In [2]:
# model/data params
# base_model: str = 'decapoda-research/llama-7b-hf' 
base_model: str = '../HuggingFace-Download-Accelerator/hf_local/models--baffo32--decapoda-research-llama-7B-hf/snapshots/aa18b48a1330572a6dd5f5d5619ed19838ca285c/'
# data_path: str = "../HuggingFace-Download-Accelerator/hf_hub/datasets--yahma--alpaca-cleaned/alpaca_data_cleaned.json"
data_path: str = "./datasets/tts_data_train_1.json"
# data_path: str = "./datasets/tts_data_train_1_small.json"
output_dir: str = "./result/lora-tts-1_128_text"
# training hyperparams
# batch_size: int = 128
batch_size: int = 64
micro_batch_size: int = 8
num_epochs: int = 10
learning_rate: float = 3e-4
cutoff_len: int = 512
val_set_size: int = 2000
# val_set_size: int = 300
# lora hyperparams
lora_r: int = 128
lora_alpha: int = 128
lora_dropout: float = 0.05
# lora_target_modules: List[str] = [ "q_proj", "k_proj", "v_proj", "o_proj"]
lora_target_modules: List[str] = ["q_proj", "k_proj", "v_proj", "o_proj"]
# llm hyperparams
train_on_inputs: bool = True  # if False masks out inputs in loss
add_eos_token: bool = False
group_by_length: bool = True  # faster but produces an odd training loss curve
# wandb params
wandb_project: str = ""
wandb_run_name: str = ""
wandb_watch: str = ""  # options: false | gradients | all
wandb_log_model: str = ""  # options: false | true
# resume_from_checkpoint: str = None  # either training checkpoint or final adapter
resume_from_checkpoint: str = "./result/lora-tts-1_128_text/checkpoint-1800/"
prompt_template_name: str = "alpaca_tts_txt_enc"  # The prompt template to use will default to alpaca.
encodec_dim: int = 1024
# encodec_nq: int = int(os.path.basename(data_path).split(".")[0][-1])
encodec_nq: int = 1

def tokenize(prompt, add_eos_token=True):
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        # data_point["text"],
        data_point["text"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        user_prompt = prompter.generate_prompt(
            data_point["text"]
        )
        tokenized_user_prompt = tokenize(
            user_prompt, add_eos_token=add_eos_token
        )
        user_prompt_len = len(tokenized_user_prompt["input_ids"])

        if add_eos_token:
            user_prompt_len -= 1

        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt

In [3]:
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"add_eos_token: {add_eos_token}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        f"prompt template: {prompt_template_name}\n"
    )
assert (
    base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size

prompter = Prompter_TTS_Phonemes(prompt_template_name)

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    gradient_accumulation_steps = gradient_accumulation_steps // world_size

# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
    "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
    os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
    os.environ["WANDB_LOG_MODEL"] = wandb_log_model

Training Alpaca-LoRA model with params:
base_model: ../HuggingFace-Download-Accelerator/hf_local/models--baffo32--decapoda-research-llama-7B-hf/snapshots/aa18b48a1330572a6dd5f5d5619ed19838ca285c/
data_path: ./datasets/tts_data_train_1.json
output_dir: ./result/lora-tts-1_128_text
batch_size: 64
micro_batch_size: 8
num_epochs: 10
learning_rate: 0.0003
cutoff_len: 512
val_set_size: 2000
lora_r: 128
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
train_on_inputs: True
add_eos_token: False
group_by_length: True
wandb_project: 
wandb_run_name: 
wandb_watch: 
wandb_log_model: 
resume_from_checkpoint: ./result/lora-tts-1_128_text/checkpoint-1800/
prompt template: alpaca_tts_txt_enc



In [4]:
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"  # Allow batched inference
special_tokens = {"additional_special_tokens": ["<SPCH>", "</SPCH>"]}
tokenizer.add_special_tokens(special_tokens)
new_tokens = [f"<{t}>" for t in range(encodec_dim*encodec_nq)]
tokenizer.add_tokens(new_tokens, False)
f = open("phoneme_vocab.txt", "r")
phonemes = f.read()
special_tokens = {"additional_special_tokens": ["<PHN>", "</PHN>"]}
tokenizer.add_special_tokens(special_tokens, False)
new_phonemes = [f"<{t.upper()}>" for t in  phonemes.split("\n")]
tokenizer.add_tokens(new_phonemes)

model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
torch.manual_seed(42)
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
model = prepare_model_for_int8_training(model)

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    # modules_to_save= ["base_model.model.model.embed_tokens.weight", "base_model.model.lm_head.weight"],
)
model = get_peft_model(model, config)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████| 33/33 [00:15<00:00,  2.15it/s]


In [5]:
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
    data = load_dataset("json", data_files=data_path)
else:
    data = load_dataset(data_path)

if resume_from_checkpoint:
    # Check the available weights and load them
    checkpoint_name = os.path.join(
        resume_from_checkpoint, "pytorch_model.bin"
    )  # Full checkpoint
    if not os.path.exists(checkpoint_name):
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "adapter_model.bin"
        )  # only LoRA model - LoRA config above has to fit
        resume = (
            False  # So the trainer won't try loading its state
        )
    # The two files above have a different name depending on how they were saved, but are actually the same.
    
    m = torch.load(resume_from_checkpoint + "embedding_layer.pt")
    model.base_model.model.model.embed_tokens.load_state_dict(m)
    m = torch.load(resume_from_checkpoint + "lm_head.pt")
    model.base_model.model.lm_head.parameters(m)
    
    if os.path.exists(checkpoint_name):
        print(f"Restarting from {checkpoint_name}")
        adapters_weights = torch.load(checkpoint_name)
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")
else:
    resume = False
        
        
for parameter in model.base_model.model.model.embed_tokens.parameters():
    parameter.requires_grad = True
for parameter in model.base_model.model.lm_head.parameters():
    parameter.requires_grad = True
    
model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

Restarting from ./result/lora-tts-1_128_text/checkpoint-1800/adapter_model.bin
trainable params: 405,536,768 || all params: 6,881,808,384 || trainable%: 5.892880844268506


In [6]:
from transformers import TrainerCallback

class CustomSaveCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        if state.global_step % args.save_steps == 0:
            save_dir = output_dir + f"/checkpoint-{state.global_step}/"
            torch.save(model.base_model.model.model.embed_tokens.state_dict(), save_dir + 'embedding_layer.pt')
            torch.save(model.base_model.model.lm_head.state_dict(), save_dir + 'lm_head.pt')

In [7]:
if val_set_size > 0:
    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=True, seed=42
    )
    train_data = (
        train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    )
    val_data = (
        train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    )
else:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None

if not ddp and torch.cuda.device_count() > 1:
    # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
    model.is_parallelizable = True
    model.model_parallel = True

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    callbacks=[CustomSaveCallback()],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=100,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,
        logging_steps=10,
        optim="adamw_torch",
        evaluation_strategy="steps" if val_set_size > 0 else "no",
        save_strategy="steps",
        eval_steps=50 if val_set_size > 0 else None,
        save_steps=50,
        output_dir=output_dir,
        save_total_limit=3,
        load_best_model_at_end=True if val_set_size > 0 else False,
        ddp_find_unused_parameters=False if ddp else None,
        group_by_length=group_by_length,
        report_to="wandb" if use_wandb else None,
        run_name=wandb_run_name if use_wandb else None,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)
model.config.use_cache = False

# old_state_dict = model.state_dict
# model.state_dict = (
#     lambda self, *_, **__: get_peft_model_state_dict(
#         self, old_state_dict()
#     )
# ).__get__(model, type(model))

if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

Map: 100%|██████████| 38365/38365 [01:13<00:00, 524.44 examples/s]
Map: 100%|██████████| 2000/2000 [00:03<00:00, 532.82 examples/s]


In [None]:
trainer.train(resume_from_checkpoint=resume)

model.save_pretrained(output_dir)

print(
    "\n If there's a warning about missing keys above, please disregard :)"
)

Step,Training Loss,Validation Loss
50,2.49,3.073784
100,2.3871,3.074851
150,2.3972,3.08419
200,2.3779,3.057469
250,2.3981,3.08135
300,2.3757,3.061077
350,2.4013,3.047813
400,2.4475,2.993873
450,2.3804,2.990377
500,2.3673,3.001789
