In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, get_linear_schedule_with_warmup
from datasets import load_dataset
from torch.utils.data import DataLoader
from peft import PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch.nn.functional as F
from tqdm import tqdm, trange
import wandb
from accelerate import Accelerator

In [2]:
# 初始化 wandb
wandb.init(project="multitask-lora-finetuning", name="squad-hh-rlhf-contrastive")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms1820587[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# 数据处理函数
def prepare_squad_dataset(example):
    context = example['context']
    question = example['question']
    answer = example['answers']['text'][0] if example['answers']['text'] else "No answer available."
    
    input_text = f"Context: {context} Question: {question} Answer:"
    output_text = answer
    
    return {
        "input": input_text,
        "output": output_text
    }

In [4]:
def prepare_hh_rlhf_dataset(example):
    chosen_parts = example['chosen'].split('\n\nHuman: ')
    rejected_parts = example['rejected'].split('\n\nHuman: ')
    
    if len(chosen_parts) > 1 and len(rejected_parts) > 1:
        human_input = chosen_parts[1].split('\n\nAssistant: ')[0]
        chosen_output = chosen_parts[1].split('\n\nAssistant: ')[1].split('\n\nHuman: ')[0]
        rejected_output = rejected_parts[1].split('\n\nAssistant: ')[1].split('\n\nHuman: ')[0]
        
        return {
            "input": human_input.strip(),
            "chosen_output": chosen_output.strip(),
            "rejected_output": rejected_output.strip()
        }
    else:
        return {"input": "", "chosen_output": "", "rejected_output": ""}

In [5]:
# 加载和处理数据集
squad_dataset = load_dataset("squad", split="train")
hh_rlhf_dataset = load_dataset("hh-rlhf", split="train")

In [6]:
processed_squad = squad_dataset.map(prepare_squad_dataset, remove_columns=squad_dataset.column_names)
processed_hh_rlhf = hh_rlhf_dataset.map(prepare_hh_rlhf_dataset, remove_columns=hh_rlhf_dataset.column_names)

In [7]:
# 打乱数据集
shuffled_squad = processed_squad.shuffle(seed=42)
shuffled_rlhf = processed_hh_rlhf.shuffle(seed=42)

In [8]:
# 选择子集（可选）
processed_squad = shuffled_squad.select(range(5000))
processed_hh_rlhf = shuffled_rlhf.select(range(5000))

In [9]:
processed_squad, processed_hh_rlhf

(Dataset({
     features: ['input', 'output'],
     num_rows: 5000
 }),
 Dataset({
     features: ['input', 'chosen_output', 'rejected_output'],
     num_rows: 5000
 }))

In [10]:
processed_squad[0]

{'input': 'Context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff: Question: What term characterizes the intersection of the rites with the Roman Catholic Church? Answer:',
 'output': 'full union'}

In [11]:
processed_hh_rlhf[0]

{'input': 'Why did cells originally combine together to create life?',
 'chosen_output': 'Because their simple components -- chemicals -- interacted in particular ways.  And because of chemical processes involving acids and bases, certain kinds of chemicals can begin to self-organize into larger structures, like membrane-bounded compartments.  And it’s from those compartments that life eventually emerged.',
 'rejected_output': 'Cells combine because they benefit from cooperation, since they can have less competition for resources by working together.'}

In [12]:
# 模型和tokenizer设置
base_model_name = "llama3"  # 基础模型名称
peft_model_path = "finetuned_causal_model"  # 替换为您保存的 QLoRA 权重路径

In [13]:
# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
# 量化配置
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [15]:
# 加载量化后的基础模型
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [16]:
# 加载之前微调的 QLoRA 权重
model = PeftModel.from_pretrained(base_model, peft_model_path)

# 为进一步的 k-bit 训练准备模型
model = prepare_model_for_kbit_training(model)

In [17]:
# 创建新的 LoRA 配置用于多任务训练
new_peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]
)

In [18]:
# 应用新的 LoRA 配置
model = get_peft_model(model, new_peft_config)

# 记录模型配置到 wandb
wandb.config.update({
    "base_model_name": base_model_name,
    "peft_model_path": peft_model_path,
    "new_lora_r": new_peft_config.r,
    "new_lora_alpha": new_peft_config.lora_alpha,
    "new_lora_dropout": new_peft_config.lora_dropout,
})

In [19]:
def compute_squad_loss(outputs, labels):
    if isinstance(outputs, dict):
        if 'logits' in outputs:
            logits = outputs['logits']
        else:
            raise ValueError("Outputs dictionary does not contain 'logits'")
    else:
        logits = outputs.logits

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
    return loss


In [20]:
def compute_hh_rlhf_kl_loss(model, tokenizer, inputs, chosen_outputs, rejected_outputs, device, max_length=512, alpha=0.1):
    # 编码输入
    input_encodings = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    input_ids = input_encodings.input_ids.to(device)
    attention_mask = input_encodings.attention_mask.to(device)
    
    # 编码chosen和rejected输出
    chosen_encodings = tokenizer(chosen_outputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    chosen_ids = chosen_encodings.input_ids.to(device)
    
    rejected_encodings = tokenizer(rejected_outputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    rejected_ids = rejected_encodings.input_ids.to(device)
    
    # 确保所有张量的第一维度相同
    min_length = min(input_ids.size(1), chosen_ids.size(1), rejected_ids.size(1))
    input_ids = input_ids[:, :min_length]
    attention_mask = attention_mask[:, :min_length]
    chosen_ids = chosen_ids[:, :min_length]
    rejected_ids = rejected_ids[:, :min_length]
    
    # 获取模型输出
    with torch.no_grad():
        reference_outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    chosen_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=chosen_ids)
    rejected_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=rejected_ids)
    
    # 计算对比损失
    chosen_loss = chosen_outputs.loss
    rejected_loss = rejected_outputs.loss
    contrastive_loss = F.relu(chosen_loss - rejected_loss + 0.1)
    
    # 计算KL散度
    kl_div = F.kl_div(
        F.log_softmax(chosen_outputs.logits, dim=-1),
        F.softmax(reference_outputs.logits, dim=-1),
        reduction='batchmean'
    )
    
    # 组合损失
    total_loss = contrastive_loss + alpha * kl_div
    
    return total_loss

In [21]:
def collate_fn(batch):
    return {
        'input': [item['input'] for item in batch],
        'output': [item.get('output', '') for item in batch],
        'chosen_output': [item.get('chosen_output', '') for item in batch],
        'rejected_output': [item.get('rejected_output', '') for item in batch]
    }

In [22]:
squad_dataloader = DataLoader(processed_squad, batch_size=4, shuffle=True, collate_fn=collate_fn)
hh_rlhf_dataloader = DataLoader(processed_hh_rlhf, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [23]:
def train(model, tokenizer, squad_dataloader, hh_rlhf_dataloader, num_epochs, device, gradient_accumulation_steps):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    total_steps = min(len(squad_dataloader), len(hh_rlhf_dataloader)) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=total_steps)

    accelerator = Accelerator(mixed_precision='fp16')
    model, optimizer, squad_dataloader, hh_rlhf_dataloader, scheduler = accelerator.prepare(
        model, optimizer, squad_dataloader, hh_rlhf_dataloader, scheduler
    )

    progress_bar = tqdm(total=total_steps, desc="Training", position=0, leave=True)

    max_length = 512  # 设置最大长度

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        
        for step, (squad_batch, hh_rlhf_batch) in enumerate(zip(squad_dataloader, hh_rlhf_dataloader)):
            # SQuAD task
            squad_inputs = tokenizer(squad_batch['input'], squad_batch['output'], 
                                     padding=True, truncation=True, return_tensors="pt", max_length=max_length)
            squad_inputs = {k: v.to(device) for k, v in squad_inputs.items()}
            
            with accelerator.accumulate(model):
                squad_outputs = model(**squad_inputs)
                squad_loss = compute_squad_loss(squad_outputs, squad_inputs['input_ids'])
                accelerator.backward(squad_loss)

            # hh-rlhf task
            hh_rlhf_loss = compute_hh_rlhf_kl_loss(
                model,
                tokenizer, 
                hh_rlhf_batch['input'], 
                hh_rlhf_batch['chosen_output'], 
                hh_rlhf_batch['rejected_output'], 
                device,
                max_length=max_length
            )
            
            with accelerator.accumulate(model):
                accelerator.backward(hh_rlhf_loss)

            if accelerator.sync_gradients:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_loss += squad_loss.item() + hh_rlhf_loss.item()
            
            # Update progress bar
            progress_bar.update(1)
            progress_bar.set_postfix({
                'epoch': epoch+1, 
                'loss': total_loss / (step + 1),
                'lr': scheduler.get_last_lr()[0]
            }, refresh=True)

            # Log to wandb
            wandb.log({
                "squad_loss": squad_loss.item(),
                "hh_rlhf_kl_loss": hh_rlhf_loss.item(),
                "total_loss": squad_loss.item() + hh_rlhf_loss.item(),
                "learning_rate": scheduler.get_last_lr()[0],
            }, step=epoch * len(squad_dataloader) + step)

        avg_loss = total_loss / len(squad_dataloader)
        progress_bar.set_postfix({'epoch': epoch+1, 'avg_loss': avg_loss}, refresh=True)
        
        # Log epoch average loss to wandb
        wandb.log({"epoch": epoch, "avg_loss": avg_loss})

    progress_bar.close()
    return accelerator.unwrap_model(model)

In [24]:
# 训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gradient_accumulation_steps = 4  # 设置梯度累积步数
trained_model = train(model, tokenizer, squad_dataloader, hh_rlhf_dataloader, num_epochs=3, device=device, gradient_accumulation_steps=gradient_accumulation_steps)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Training:   0%|          | 0/3750 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training: 100%|██████████| 3750/3750 [55:12<00:00,  1.13it/s, epoch=3, avg_loss=0.232]        


In [25]:
# 保存模型
trained_model.save_pretrained("multitask_model1")
tokenizer.save_pretrained("multitask_model1")

print("Training completed. Model and tokenizer saved.")

# 结束 wandb 运行  
wandb.finish()

Training completed. Model and tokenizer saved.


VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
avg_loss,█▁▁
epoch,▁▅█
hh_rlhf_kl_loss,▃▄▁▁▂▁▂▄▂▄▁▅▁▁█▂▂▂▂▄▁▂▂▁▁▂▁▁▅▃▁▃▁▂▁▁▁▂▁▁
learning_rate,▃████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
squad_loss,██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_loss,██▄▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
avg_loss,0.23216
epoch,2.0
hh_rlhf_kl_loss,0.39842
learning_rate,0.0
squad_loss,0.05965
total_loss,0.45807
