In [1]:
# !pip install scipy -i https://mirrors.aliyun.com/pypi/simple/
# !pip install transformers==4.33.3 -i https://mirrors.aliyun.com/pypi/simple/

In [2]:
import warnings 
warnings.filterwarnings("ignore")
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig

from utils.prompter import Prompter_TTS_Phonemes

In [3]:
# ---- Model, data and output locations ---------------------------------------
# Local snapshot of the decapoda-research LLaMA-7B weights (baffo32 mirror);
# the Hub id 'decapoda-research/llama-7b-hf' was used previously.
base_model: str = '../HuggingFace-Download-Accelerator/hf_local/models--baffo32--decapoda-research-llama-7B-hf/snapshots/aa18b48a1330572a6dd5f5d5619ed19838ca285c/'
# Small training subset; the full set is "./datasets/tts_data_train_1.json".
train_path: str = "./datasets/tts_data_train_1_small.json"
test_path: str = "./datasets/tts_data_test_1.json"
output_dir: str = "./result/lora-tts-1_phoenc"
# Either a directory holding a full checkpoint ("pytorch_model.bin") or only a
# LoRA adapter ("adapter_model.bin"); see the resume logic in the model-loading
# cell. Set to None to start from the base model.
resume_from_checkpoint: str = "./result/lora-tts-1_phoenc/checkpoint-5850/"

# ---- Training hyperparameters ------------------------------------------------
batch_size: int = 64  # effective batch size (was 128)
micro_batch_size: int = 8  # per-step batch; batch_size // micro_batch_size = accumulation steps
num_epochs: int = 10
learning_rate: float = 3e-4
cutoff_len: int = 512  # truncation length used by tokenize()
val_set_size: int = 2000  # (was 300 at one point)

# ---- LoRA hyperparameters ----------------------------------------------------
lora_r: int = 16
lora_alpha: int = 16
lora_dropout: float = 0.05
# Attention projections plus the (resized) input embedding and output head,
# presumably so the newly added audio/phoneme tokens get trainable parameters.
lora_target_modules: List[str] = ["q_proj", "k_proj", "v_proj", "o_proj", "embed_tokens", "lm_head"]

# ---- Prompting / tokenization --------------------------------------------------
train_on_inputs: bool = True  # if False, prompt tokens get label -100 (masked out of the loss)
add_eos_token: bool = False  # EOS handling for the prompt-only encoding used when masking
group_by_length: bool = True  # faster but produces an odd training loss curve
prompt_template_name: str = "alpaca_tts_pho_enc"  # template passed to Prompter_TTS_Phonemes

# ---- Weights & Biases (empty string = leave unset) -----------------------------
wandb_project: str = ""
wandb_run_name: str = ""
wandb_watch: str = ""  # options: false | gradients | all
wandb_log_model: str = ""  # options: false | true

# ---- Audio codec vocabulary ------------------------------------------------------
# encodec_dim * encodec_nq audio tokens "<0>", "<1>", ... are added to the tokenizer.
encodec_dim: int = 1024
encodec_nq: int = 1  # previously parsed from the trailing digit of the data file name

def tokenize(prompt, add_eos_token=True):
    """Encode ``prompt`` for causal-LM training.

    Uses the module-level ``tokenizer`` with truncation to ``cutoff_len`` and no
    padding. When requested, the EOS id is appended, provided the sequence does
    not already end with it and is still shorter than ``cutoff_len``. A copy of
    ``input_ids`` is stored as ``labels``.

    Args:
        prompt: text to encode.
        add_eos_token: whether to append ``tokenizer.eos_token_id`` when it fits.

    Returns:
        The tokenizer's result mapping, extended with a ``"labels"`` list.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    token_ids = encoded["input_ids"]
    eos_id = tokenizer.eos_token_id

    # EOS only if not already last, there is room left, and the caller wants it.
    needs_eos = token_ids[-1] != eos_id and len(token_ids) < cutoff_len and add_eos_token
    if needs_eos:
        token_ids.append(eos_id)
        encoded["attention_mask"].append(1)

    # Train on every position: labels mirror the (possibly extended) ids.
    encoded["labels"] = list(token_ids)
    return encoded

def generate_and_tokenize_prompt(data_point):
    """Turn one dataset row into a tokenized training example.

    The full prompt (phonemes + target output) is tokenized with an EOS. When
    ``train_on_inputs`` is False, the label positions covered by the prompt-only
    encoding are replaced with -100 so only the response contributes to the loss.

    Args:
        data_point: a row with at least ``"phonemes"`` and ``"output"`` keys.

    Returns:
        The dict produced by ``tokenize`` (``input_ids``, ``attention_mask``,
        ``labels``), with prompt labels masked when configured.
    """
    tokenized = tokenize(
        prompter.generate_prompt(data_point["phonemes"], data_point["output"])
    )
    if train_on_inputs:
        return tokenized

    prompt_only = tokenize(
        prompter.generate_prompt(data_point["phonemes"]),
        add_eos_token=add_eos_token,
    )
    masked_count = len(prompt_only["input_ids"])
    # The prompt-only encoding is assumed to end with an appended EOS in this
    # case; that position is not part of the prompt prefix, so don't mask it.
    if add_eos_token:
        masked_count -= 1

    tokenized["labels"] = [-100] * masked_count + tokenized["labels"][masked_count:]
    return tokenized

In [4]:
# Echo the effective configuration once (only on local rank 0 when launched
# with torchrun / DDP, where LOCAL_RANK is set). Includes the data paths and
# codec settings, which determine what gets trained on and the vocabulary size.
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"train_path: {train_path}\n"
        f"test_path: {test_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"add_eos_token: {add_eos_token}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        f"prompt template: {prompt_template_name}\n"
        f"encodec_dim: {encodec_dim}\n"
        f"encodec_nq: {encodec_nq}\n"
    )
# This is a notebook, not a CLI: point the user at the config cell.
assert (
    base_model
), "Please set base_model in the config cell, e.g. base_model='huggyllama/llama-7b'"

# Micro-batches accumulated per optimizer step. Clamped to at least 1 so that
# micro_batch_size > batch_size (or a large world_size below) cannot yield 0.
gradient_accumulation_steps = max(1, batch_size // micro_batch_size)

prompter = Prompter_TTS_Phonemes(prompt_template_name)

# Single process: let transformers place the model ("auto"). Under DDP, each
# rank pins the whole model to its local GPU and the accumulation is split.
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    gradient_accumulation_steps = max(1, gradient_accumulation_steps // world_size)

# wandb is used if a project is configured here or already set in the environment.
use_wandb = len(wandb_project) > 0 or bool(os.environ.get("WANDB_PROJECT"))
# Only overwrite the environment for values explicitly configured above.
if wandb_project:
    os.environ["WANDB_PROJECT"] = wandb_project
if wandb_watch:
    os.environ["WANDB_WATCH"] = wandb_watch
if wandb_log_model:
    os.environ["WANDB_LOG_MODEL"] = wandb_log_model

Training Alpaca-LoRA model with params:
base_model: ../HuggingFace-Download-Accelerator/hf_local/models--baffo32--decapoda-research-llama-7B-hf/snapshots/aa18b48a1330572a6dd5f5d5619ed19838ca285c/
output_dir: ./result/lora-tts-1_phoenc
batch_size: 64
micro_batch_size: 8
num_epochs: 10
learning_rate: 0.0003
cutoff_len: 512
val_set_size: 2000
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'embed_tokens', 'lm_head']
train_on_inputs: True
add_eos_token: False
group_by_length: True
wandb_project: 
wandb_run_name: 
wandb_watch: 
wandb_log_model: 
resume_from_checkpoint: ./result/lora-tts-1_phoenc/checkpoint-5850/
prompt template: alpaca_tts_pho_enc



In [5]:
# Tokenizer: LLaMA vocab + audio markers + audio-code tokens + phoneme tokens.
# The embedding matrix is resized to len(tokenizer) below, so the set and order
# of added tokens presumably must match what the resumed checkpoint was trained
# with. Do not change how this vocabulary is built without retraining.
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.pad_token_id = (
    0  # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left"  # Allow batched inference
_ = tokenizer.add_tokens(["[SAUDIO]", "[EAUDIO]"], special_tokens=True)
new_tokens = [f"<{t}>" for t in range(encodec_dim * encodec_nq)]
_ = tokenizer.add_tokens(new_tokens)

# Read the phoneme inventory with a context manager so the file handle is closed.
# NOTE: split("\n") is kept deliberately. If the file ends with a newline this
# yields an empty entry (token "<>"), but changing that would change the
# vocabulary size the checkpoint expects.
with open("phoneme_vocab.txt", "r", encoding="utf-8") as f:
    phonemes = f.read()
_ = tokenizer.add_tokens(["[SPHONE]", "[EPHONE]"], special_tokens=True)
new_phonemes = [f"<{t}>" for t in phonemes.split("\n")]
_ = tokenizer.add_tokens(new_phonemes)

model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
# Seed right before resizing: the new embedding rows are presumably randomly
# initialised, and a fixed seed keeps them (and the dataset shuffles below)
# repeatable across kernel restarts.
seed = 42
torch.manual_seed(seed)
model.resize_token_embeddings(len(tokenizer))
model = prepare_model_for_int8_training(model)

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    # modules_to_save= ["base_model.model.model.embed_tokens.weight", "base_model.model.lm_head.weight"],
)
model = get_peft_model(model, config)
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

train_data = load_dataset("json", data_files=train_path)
test_data = load_dataset("json", data_files=test_path)

if resume_from_checkpoint:
    # Prefer a full checkpoint; otherwise fall back to an adapter-only file.
    checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
    if not os.path.exists(checkpoint_name):
        # Only the LoRA adapter is available - the LoraConfig above has to fit.
        checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
        resume_from_checkpoint = False  # So the trainer won't try loading its state
    # The two files have different names depending on how they were saved,
    # but hold the same adapter weights.
    if os.path.exists(checkpoint_name):
        print(f"Restarting from {checkpoint_name}")
        # Load onto CPU so loading does not depend on the device the weights
        # were saved from; the state-dict load then copies them into the model.
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")

model.eval()
# Seeded shuffles so row indices (e.g. test_data[1] in the next cell) are
# reproducible instead of changing on every run.
train_data = train_data["train"].shuffle(seed=seed).map(generate_and_tokenize_prompt)
test_data = test_data["train"].shuffle(seed=seed).map(generate_and_tokenize_prompt)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████| 33/33 [00:15<00:00,  2.14it/s]
You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 33111. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/

Restarting from ./result/lora-tts-1_phoenc/checkpoint-5850/adapter_model.bin


Map: 100%|██████████| 500/500 [00:01<00:00, 376.45 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 393.16 examples/s]


In [6]:
# ---- Decoding settings for the demo generation ---------------------------------
device = "cuda"
max_new_tokens = 128
num_beams = 4
# NOTE(review): do_sample is never set, so under Hugging Face's default
# (do_sample=False) these sampling knobs presumably have no effect on the beam
# search. Confirm before tuning them.
temperature, top_p, top_k = 1.0, 0.75, 100

# Inference prompt: phonemes only, no target output; show it for reference.
data_point = test_data[1]
prompt = prompter.generate_prompt(data_point["phonemes"])
print(prompt)

inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

generation_config = GenerationConfig(
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    num_beams=num_beams,
)
# Keyword arguments for model.generate(), collected in one place.
generate_params = dict(
    input_ids=input_ids,
    generation_config=generation_config,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=max_new_tokens,
)

### Phonemes:
<PHN><DH><AH0><HH><OW1><L><K><AA2><N><V><ER0><S><EY1><SH><AH0><N><R><AE1><N><AA1><N><DH><AH0><B><R><EH1><K><F><AH0><S><T><W><IH1><CH><W><AH1><N><AH0><N><D><AO1><L><AH0><B><Y><UW1><Z><D><R><AW1><N><D><L><IY0></PHN>

### Response:



In [10]:
with torch.no_grad():
    # Reuse the kwargs assembled in the previous cell instead of repeating them.
    # Pass the attention mask explicitly so generate() never infers it from a
    # pad id: the decoded prompt starts with <unk> (id 0), which is also the pad
    # id configured when loading the tokenizer/model.
    generation_output = model.generate(
        **generate_params,
        attention_mask=inputs["attention_mask"].to(device),
    )
# Decode the whole best sequence (prompt + continuation) for display.
s = generation_output.sequences[0]
output = tokenizer.decode(s)
print(output)

<unk>### Phonemes:
<PHN><DH><AH0><HH><OW1><L><K><AA2><N><V><ER0><S><EY1><SH><AH0><N><R><AE1><N><AA1><N><DH><AH0><B><R><EH1><K><F><AH0><S><T><W><IH1><CH><W><AH1><N><AH0><N><D><AO1><L><AH0><B><Y><UW1><Z><D><R><AW1><N><D><L><IY0></PHN>

### Response:
<SPCH> <121> <408> <408> <491> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <310> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724> <724>
