In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorWithPadding
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from peft import PeftModel, PeftConfig
import wandb
import os

In [2]:
# wandb 登录
wandb_api_key = "49b1a1cdc297defea9722d651839cd8621111183"  # 替换为你的 API 密钥
os.environ["WANDB_API_KEY"] = wandb_api_key
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms1820587[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# 初始化 wandb
wandb.init(project="multitask_lora_finetuning", name="experiment_1")

In [4]:
# 设置设备
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# 加载数据集
def load_and_preprocess_datasets(num_samples=10000):
    imdb_dataset = load_dataset("imdb")
    agnews_dataset = load_dataset("ag_news")

    imdb_train_dataset = imdb_dataset["train"].shuffle(seed=42).select(range(num_samples))
    agnews_train_dataset = agnews_dataset["train"].shuffle(seed=42).select(range(num_samples))

    imdb_train_dataset = imdb_train_dataset.rename_column("label", "imdb_label")
    agnews_train_dataset = agnews_train_dataset.rename_column("label", "agnews_label")

    imdb_train_dataset = imdb_train_dataset.add_column("task_type", ["imdb"] * len(imdb_train_dataset))
    agnews_train_dataset = agnews_train_dataset.add_column("task_type", ["agnews"] * len(agnews_train_dataset))

    return imdb_train_dataset, agnews_train_dataset

imdb_train_dataset, agnews_train_dataset = load_and_preprocess_datasets()

In [6]:
# 数据预处理函数
def preprocess_function(examples, tokenizer):
    if 'text' in examples:
        encodings = tokenizer(examples['text'], truncation=True, max_length=512)
    else:
        encodings = {'input_ids': [[]] * len(examples['task_type']), 'attention_mask': [[]] * len(examples['task_type'])}
    
    task_type_map = {'imdb': 0, 'agnews': 1}
    encodings['task_type'] = [task_type_map[task_type] for task_type in examples['task_type']]
    
    encodings.update({
        'imdb_label': examples['imdb_label'] if 'imdb_label' in examples else [None] * len(examples['task_type']),
        'agnews_label': examples['agnews_label'] if 'agnews_label' in examples else [None] * len(examples['task_type']),
    })
    
    return encodings

In [7]:
# 设置模型和tokenizer
model_name = 'finetuned_lora_model'  # 替换为你的预训练模型路径
tokenizer = AutoTokenizer.from_pretrained('llama3')
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
# 设置QLoRA配置
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [9]:
# 加载预训练的 PeftModel
peft_config = PeftConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto"
)
model = PeftModel.from_pretrained(model, model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
# 解冻LoRA参数
for name, param in model.named_parameters():
    if 'lora' in name or 'adapter' in name:
        param.requires_grad = True

In [11]:
# 数据预处理
imdb_train_dataset = imdb_train_dataset.map(
    lambda examples: preprocess_function(examples, tokenizer),
    batched=True, remove_columns=['text']
)
agnews_train_dataset = agnews_train_dataset.map(
    lambda examples: preprocess_function(examples, tokenizer),
    batched=True, remove_columns=['text']
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [12]:
# 合并数据集
train_dataset = concatenate_datasets([imdb_train_dataset, agnews_train_dataset])

# 设置数据格式
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'task_type', 'imdb_label', 'agnews_label'])

# 创建 DataCollator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [13]:
# 创建 DataLoader
batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [14]:
# 定义损失函数和优化器
imdb_criterion = nn.CrossEntropyLoss()
agnews_criterion = nn.CrossEntropyLoss()
learning_rate = 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [16]:
# 设置混合精度训练
scaler = torch.cuda.amp.GradScaler()

# 训练循环
num_epochs = 3
accumulation_steps = 4  # 梯度累积步数
total_steps = num_epochs * len(train_dataloader)

In [17]:
# 配置 wandb
wandb.config.update({
    "learning_rate": learning_rate,
    "epochs": num_epochs,
    "batch_size": batch_size,
    "accumulation_steps": accumulation_steps,
    "model_name": model_name,
})

In [18]:
with tqdm(total=total_steps, desc="Training") as pbar:
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        imdb_epoch_loss = 0.0
        agnews_epoch_loss = 0.0
        
        for i, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            task_type = batch['task_type'].to(model.device)
            
            imdb_indices = [i for i, task in enumerate(task_type) if task == 0]
            agnews_indices = [i for i, task in enumerate(task_type) if task == 1]
            
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):  # 使用bfloat16混合精度
                if len(imdb_indices) > 0:
                    imdb_outputs = model(input_ids[imdb_indices], attention_mask[imdb_indices])
                    imdb_logits = imdb_outputs.logits[:, -1, :].float()  # 确保是float32类型
                    imdb_labels = batch['imdb_label'][imdb_indices].to(model.device).long()  # 确保是long类型
                    imdb_loss = imdb_criterion(imdb_logits, imdb_labels)
                else:
                    imdb_loss = torch.tensor(0.0, device=model.device)

                if len(agnews_indices) > 0:
                    agnews_outputs = model(input_ids[agnews_indices], attention_mask[agnews_indices])
                    agnews_logits = agnews_outputs.logits[:, -1, :].float()  # 确保是float32类型
                    agnews_labels = batch['agnews_label'][agnews_indices].to(model.device).long()  # 确保是long类型
                    agnews_loss = agnews_criterion(agnews_logits, agnews_labels)
                else:
                    agnews_loss = torch.tensor(0.0, device=model.device)

                loss = (imdb_loss + agnews_loss) / accumulation_steps

            scaler.scale(loss).backward()
            
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_dataloader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            epoch_loss += loss.item() * accumulation_steps
            imdb_epoch_loss += imdb_loss.item()
            agnews_epoch_loss += agnews_loss.item()
            
            pbar.update(1)
            
            # Log to wandb
            wandb.log({
                "batch_loss": loss.item(),
                "batch_imdb_loss": imdb_loss.item(),
                "batch_agnews_loss": agnews_loss.item(),
            })
        
        avg_epoch_loss = epoch_loss / len(train_dataloader)
        avg_imdb_loss = imdb_epoch_loss / len(train_dataloader)
        avg_agnews_loss = agnews_epoch_loss / len(train_dataloader)
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_epoch_loss:.4f}, IMDB Loss: {avg_imdb_loss:.4f}, AGNews Loss: {avg_agnews_loss:.4f}")
        
        # Log epoch metrics to wandb
        wandb.log({
            "epoch": epoch + 1,
            "epoch_loss": avg_epoch_loss,
            "epoch_imdb_loss": avg_imdb_loss,
            "epoch_agnews_loss": avg_agnews_loss,
        })

Training:  33%|███▎      | 2500/7500 [48:16<1:41:53,  1.22s/it]

Epoch [1/3], Loss: 2.4884, IMDB Loss: 1.2708, AGNews Loss: 1.2176


Training:  67%|██████▋   | 5000/7500 [1:36:34<52:14,  1.25s/it]  

Epoch [2/3], Loss: 0.3987, IMDB Loss: 0.1247, AGNews Loss: 0.2740


Training: 100%|██████████| 7500/7500 [2:24:51<00:00,  1.16s/it]

Epoch [3/3], Loss: 0.3290, IMDB Loss: 0.1093, AGNews Loss: 0.2196





In [19]:
# 保存模型
model.save_pretrained("trained_model")
print("Model saved to trained_model")

# 结束 wandb 运行
wandb.finish()

Model saved to trained_model


VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch_agnews_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▂
batch_imdb_loss,█▄▅▂▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch_loss,█▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▅█
epoch_agnews_loss,█▁▁
epoch_imdb_loss,█▁▁
epoch_loss,█▁▁

0,1
batch_agnews_loss,0.09395
batch_imdb_loss,0.72328
batch_loss,0.20431
epoch,3.0
epoch_agnews_loss,0.21962
epoch_imdb_loss,0.10933
epoch_loss,0.32895
