In [None]:
import copy
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

from tqdm import tqdm
from typing import List
from einops import rearrange
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from lora_from_scratch import *
# Pick the compute device and matching dtype once: bfloat16 on GPU, float32 on CPU.
if torch.cuda.is_available():
    device, dtype = 'cuda', torch.bfloat16
else:
    device, dtype = 'cpu', torch.float32
print(f'device:{device}\ndtype:{dtype}')

device:cuda
dtype:torch.bfloat16


模型选用 ahxt/LiteLlama-460M-1T，数据集选用 vicgalle/alpaca-gpt4

In [None]:
# Hub identifiers for the instruction dataset and the base model.
# Either value may be replaced with a local path.
data_name_or_path = "vicgalle/alpaca-gpt4"
model_name_or_path = "ahxt/LiteLlama-460M-1T"

In [None]:
# Tokenizer: pad on the left and reuse the EOS id as the padding id.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eos_token_id

# Base causal LM, loaded in the configured dtype and then moved to the target device.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)
model = model.to(device)

In [None]:
# Build the LoRA model. The helper comes from lora_from_scratch and is called
# for its effect on `model` (its return value is not used here).
replace_linear_with_lora(
    model,
    r=8,
    alpha=16,
    dropout_p=0.0,
)
# Make sure any modules added by the helper live on the target device as well.
model = model.to(device)

# Report the trainable-parameter statistics.
print_trainable_parameters(model)
     

In [None]:
# Training dataset definition
class SFTDataset(Dataset):
    """Instruction-tuning dataset using a plain ``Human: ... / Assistant: ...`` template.

    Every raw record must provide ``instruction``, ``input`` and ``output``
    fields. Each processed record holds three equal-length lists of token ids:

    * ``input_ids``      – prompt + response tokens without the final token.
    * ``labels``         – the same tokens shifted left by one (next-token targets).
    * ``attention_mask`` – despite its key name this is a *loss* mask aligned
      with ``labels``: 1 where the target token belongs to the response,
      0 where it belongs to the prompt.

    Args:
        tokenizer: Tokenizer used to encode prompt and response.
        data_path: Hub dataset name, or a local directory when ``load_local``.
        load_local: Load JSON files from ``data_path`` instead of the hub.
        max_len: Maximum prompt + response length in tokens before shifting.
        split_len: Slice of the ``train`` split to load, e.g. ``'1%'``.
    """

    def __init__(self,
        tokenizer: AutoTokenizer,
        data_path: str,
        load_local: bool = False,
        max_len: int = 256,
        split_len: str = '1%',
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len

        split = f'train[:{split_len}]'
        if load_local:
            raw_ds = load_dataset('json', data_dir=data_path, split=split)
        else:
            raw_ds = load_dataset(data_path, split=split)

        eos_token_id = tokenizer.eos_token_id

        def process_func(example):
            # Pull out the three text fields; a missing/None input is treated as empty.
            instruction = example['instruction'].strip()
            extra_input = (example['input'] or '').strip()
            output = example['output'].strip()

            # Build the prompt and response texts.
            instruction_prompt = f"Human: {instruction}\n" + \
                                    (f"{extra_input}\n" if len(extra_input) > 0 else "") + \
                                    "Assistant: "
            output_prompt = f"{output}\n"

            prompt_ids = tokenizer(instruction_prompt, add_special_tokens=False)['input_ids']
            response_ids = tokenizer(output_prompt, add_special_tokens=False)['input_ids']
            # Terminate the response with EOS so the model is trained to stop;
            # without it generation only ends when max_new_tokens is exhausted.
            if eos_token_id is not None:
                response_ids = response_ids + [eos_token_id]

            # Truncate to at most max_len tokens.
            full_ids = (prompt_ids + response_ids)[:max_len]
            loss_mask = ([0] * len(prompt_ids) + [1] * len(response_ids))[:max_len]

            # Shift by one for next-token prediction. Plain lists are returned:
            # the collate function concatenates Python lists when padding.
            return {
                'input_ids': full_ids[:-1],
                'attention_mask': loss_mask[1:],
                'labels': full_ids[1:],
            }

        self.ds = raw_ds.map(
            process_func,
            batched=False,
            remove_columns=raw_ds.column_names,
            desc='Processing dataset',
        )

    def __len__(self):
        """Number of processed records."""
        return len(self.ds)

    def __getitem__(self, index: int):
        """Return the processed record at ``index``."""
        return self.ds[index]
     

In [None]:
# Build the training set from the hub dataset (default slice and max length).
ds = SFTDataset(
    tokenizer=tokenizer,
    data_path=data_name_or_path,
    load_local=False,
)

In [None]:
# Sanity check: the three fields of the first sample should have the same length.
print(*(len(ds[0][key]) for key in ('input_ids', 'attention_mask', 'labels')), sep='\n')

In [None]:
def collate_fn(batch: List, tokenizer):
    """Left-pad every sample to the longest ``input_ids`` in the batch and stack.

    Args:
        batch: Samples, each a dict with list-valued ``input_ids``,
            ``attention_mask`` and ``labels``.
        tokenizer: Object exposing ``eos_token_id``, which is used as the
            padding id for ``input_ids`` and ``labels``.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` as
        ``torch.LongTensor`` of shape (len(batch), longest). Padding positions
        carry 0 in ``attention_mask``.
    """
    target_len = max(len(sample['input_ids']) for sample in batch)

    # Value used to fill the left side of each field.
    pad_id = tokenizer.eos_token_id
    fill_by_key = {'input_ids': pad_id, 'attention_mask': 0, 'labels': pad_id}
    padded_rows = {key: [] for key in fill_by_key}

    for sample in batch:
        # The pad width is derived from input_ids and applied to all three fields.
        n_pad = target_len - len(sample['input_ids'])
        for key, fill in fill_by_key.items():
            padded_rows[key].append([fill] * n_pad + sample[key])

    # Convert the nested lists into tensors.
    return {key: torch.LongTensor(rows) for key, rows in padded_rows.items()}

In [None]:
# Training hyperparameters for the LiteLlama run.
lr = 1e-3             # learning rate handed to the optimizer
bsz = 16              # DataLoader batch size
num_epochs = 10       # passes over the dataset
max_grad_norm = 1.0   # gradient-norm clipping threshold
logging_steps = 5     # print the running loss every N steps

In [None]:
# Shuffled loader; the lambda binds the tokenizer that collate_fn uses for padding.
dataloader = DataLoader(
    ds,
    batch_size=bsz,
    shuffle=True,
    collate_fn=lambda samples: collate_fn(samples, tokenizer),
)

In [None]:
# Peek at the first batch only and print the shape of each field.
for batch in dataloader:
    print(*(batch[key].shape for key in ('input_ids', 'attention_mask', 'labels')), sep='\n')
    break

In [None]:
# Hand AdamW only the parameters that require gradients, so the optimizer never
# tracks frozen weights and fails loudly if nothing is trainable.
# NOTE(review): assumes replace_linear_with_lora freezes the base weights — confirm.
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)

In [None]:
model.train()

# Running totals across ALL epochs; the printed losses are cumulative averages.
total_loss = 0
total_step = 0
for epoch in range(num_epochs):
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        # Despite its key name, the batch's 'attention_mask' is the LOSS mask built
        # by the dataset: 1 on response targets, 0 on prompt targets and padding.
        loss_mask = batch['attention_mask'].to(device)

        # Real attention mask. collate_fn left-pads with tokenizer.eos_token_id, so
        # the leading run of EOS ids in each row is padding and everything from the
        # first non-EOS token onwards is real. Feeding the loss mask to the model
        # instead would hide the whole prompt from the response tokens.
        attention_mask = (input_ids != tokenizer.eos_token_id).long().cumsum(dim=-1).clamp(max=1)

        optimizer.zero_grad()
        # Labels are not passed: they are already shifted and the loss is computed
        # below, so the model's internal loss would be wasted (and wrong) work.
        outputs = model(input_ids, attention_mask=attention_mask)
        # Compute the loss in float32 even when the model runs in bfloat16.
        logits = outputs.logits.float()

        # Per-token loss. No ignore_index: id 0 is a regular vocabulary entry and
        # masking is handled entirely by loss_mask.
        token_loss = F.cross_entropy(
            rearrange(logits, 'bsz seq_len vocab_size -> (bsz seq_len) vocab_size'),
            rearrange(labels, 'bsz seq_len -> (bsz seq_len)'),
            reduction='none',
        )
        flat_loss_mask = rearrange(loss_mask, 'bsz seq_len -> (bsz seq_len)')
        # Guard against a batch whose responses were all truncated away.
        num_target_tokens = flat_loss_mask.sum().clamp(min=1)
        loss = torch.sum(token_loss * flat_loss_mask) / num_target_tokens
        loss.backward()

        # Clip gradients; the returned value is the total norm before clipping.
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        total_loss += loss.item()

        total_step += 1
        if total_step % logging_steps == 0:
            avg_loss = total_loss / total_step
            print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Grad Norm: {total_norm:.4f}", flush=True)


    # Print the cumulative average loss at the end of each epoch.
    print(f"Epoch {epoch+1} finished, Average Loss: {total_loss / total_step:.4f}", flush=True)

In [None]:
def inference(
    model,
    tokenizer,
    text: str,
    max_new_tokens: int = 200,
    do_sample: bool = True,
    top_k: int = 40,
    temperature: float = 0.3,
):
    """Generate a reply for ``text`` using the ``Human:/Assistant:`` template.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        text: User instruction inserted into the prompt template.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample when True, otherwise decode greedily.
        top_k: Top-k cutoff; only forwarded when ``do_sample`` is True.
        temperature: Sampling temperature; only forwarded when ``do_sample`` is True.

    Returns:
        The decoded first sequence returned by ``generate`` (special tokens skipped).
    """
    instruction_prompt = f"Human: {text}\nAssistant: "
    # Follow the model's own device instead of relying on a notebook-level global.
    prompt = tokenizer(instruction_prompt, return_tensors='pt', add_special_tokens=False).to(model.device)

    generate_kwargs = {
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        # Explicit padding id so generate does not have to guess one.
        'pad_token_id': tokenizer.pad_token_id,
    }
    if do_sample:
        # Sampling knobs are meaningless for greedy decoding, so only pass them here.
        generate_kwargs.update(top_k=top_k, temperature=temperature)

    outputs = model.generate(**prompt, **generate_kwargs)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return response

In [None]:
# Leave training mode before generating (the training cell called model.train()).
model.eval()

for test_text in [
    'Give three tips for staying healthy.',
    'What are the three primary colors?',
    'Describe the structure of an atom.',
]:
    print('=' * 80)
    print(inference(model, tokenizer, test_text))

# 2. SFT

模型选用 Qwen/Qwen1.5-0.5B，数据集选用 bio-nlp-umass/bioinstruct

In [None]:
# Hub identifiers for the instruction dataset and the base model.
# Either value may be replaced with a local path.
data_name_or_path = "bio-nlp-umass/bioinstruct"
model_name_or_path = "Qwen/Qwen1.5-0.5B"

In [None]:
# Tokenizer: left padding, with the EOS id doubling as the padding id.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eos_token_id

# Base causal LM in the configured dtype, moved to the target device afterwards.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)
model = model.to(device)

In [None]:
# Build the LoRA model. The helper comes from lora_from_scratch and is called
# for its effect on `model` (its return value is not used here).
replace_linear_with_lora(
    model,
    r=8,
    alpha=16,
    dropout_p=0.0,
)
# Make sure any modules added by the helper live on the target device as well.
model = model.to(device)

# Report the trainable-parameter statistics.
print_trainable_parameters(model)

In [None]:
# Training dataset definition
class SFTDataset(Dataset):
    """Instruction-tuning dataset that formats prompts with the tokenizer's chat template.

    Every raw record must provide ``instruction``, ``input`` and ``output``
    fields. Each processed record holds three equal-length lists of token ids:

    * ``input_ids``      – prompt + response tokens without the final token.
    * ``labels``         – the same tokens shifted left by one (next-token targets).
    * ``attention_mask`` – despite its key name this is a *loss* mask aligned
      with ``labels``: 1 where the target token belongs to the response,
      0 where it belongs to the prompt.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``eos_token``.
        data_path: Hub dataset name, or a local directory when ``load_local``.
        load_local: Load JSON files from ``data_path`` instead of the hub.
        max_len: Maximum prompt + response length in tokens before shifting.
        split_len: Slice of the ``train`` split to load, e.g. ``'1%'``.
    """

    def __init__(self,
        tokenizer: AutoTokenizer,
        data_path: str,
        load_local: bool = False,
        max_len: int = 256,
        split_len: str = '1%',
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len

        split = f'train[:{split_len}]'
        if load_local:
            ds = load_dataset('json', data_dir=data_path, split=split)
        else:
            ds = load_dataset(data_path, split=split)

        def process_func(example):
            # Pull out the three text fields; a missing/None input is treated as empty.
            instruction = example['instruction'].strip()
            extra_input = (example['input'] or '').strip()
            output = example['output'].strip()

            # Build the user turn and let the chat template add the generation prompt.
            instruction_msg = [
                {"role": "user", "content": (instruction + f"\n{extra_input}") if len(extra_input) > 0 else instruction}
            ]
            tokenized_instruction = tokenizer.apply_chat_template(instruction_msg, tokenize=True, add_generation_prompt=True)
            # The response continues the prompt, so the tokenizer must not add
            # special tokens of its own (matches how inference encodes the prompt).
            tokenized_output = tokenizer(
                output + "<|im_end|>" + f"{tokenizer.eos_token}\n",
                add_special_tokens=False,
            )['input_ids']

            # Truncate to at most max_len tokens.
            tokenized_prompt = (tokenized_instruction + tokenized_output)[:max_len]
            loss_mask = ([0] * len(tokenized_instruction) + [1] * len(tokenized_output))[:max_len]

            # Shift by one for next-token prediction.
            return {
                'input_ids': tokenized_prompt[:-1],
                'attention_mask': loss_mask[1:],
                'labels': tokenized_prompt[1:],
            }

        self.ds = ds.map(
            process_func,
            batched=False,
            remove_columns=ds.column_names,
            desc='Processing dataset',
        )

    def __len__(self):
        """Number of processed records."""
        return len(self.ds)

    def __getitem__(self, index: int):
        """Return the processed record at ``index``."""
        return self.ds[index]

In [None]:
# Build the training set from the hub dataset (default slice and max length).
ds = SFTDataset(
    tokenizer=tokenizer,
    data_path=data_name_or_path,
    load_local=False,
)

In [None]:
# Sanity check on the first sample: the three fields should have the same length.
print(*(len(ds[0][key]) for key in ('input_ids', 'attention_mask', 'labels')), sep='\n')

# Decoded model input, the raw mask, and the decoded targets, one per line.
print(
    tokenizer.decode(ds[0]['input_ids']),
    ds[0]['attention_mask'],
    tokenizer.decode(ds[0]['labels']),
    sep='\n',
)

In [None]:
def collate_fn(batch: List, tokenizer):
    """Left-pad a batch of samples to a common length and convert it to tensors.

    Args:
        batch: Samples, each a dict with list-valued ``input_ids``,
            ``attention_mask`` and ``labels``.
        tokenizer: Object exposing ``eos_token_id``; that id fills the padding
            slots of ``input_ids`` and ``labels``.

    Returns:
        Dict of ``torch.LongTensor`` under ``input_ids``, ``attention_mask`` and
        ``labels``, each of shape (len(batch), longest input). Padding slots of
        ``attention_mask`` are 0.
    """
    longest = max(len(example['input_ids']) for example in batch)
    columns = {'input_ids': [], 'attention_mask': [], 'labels': []}

    for example in batch:
        # One pad width per sample, taken from input_ids and reused for every field.
        n_pad = longest - len(example['input_ids'])

        # Left padding.
        columns['input_ids'].append([tokenizer.eos_token_id] * n_pad + example['input_ids'])
        columns['attention_mask'].append([0] * n_pad + example['attention_mask'])
        columns['labels'].append([tokenizer.eos_token_id] * n_pad + example['labels'])

    # Nested lists -> tensors, keeping the key order.
    return {name: torch.LongTensor(rows) for name, rows in columns.items()}

In [None]:
# Training hyperparameters for the Qwen run.
lr = 5e-4             # learning rate handed to the optimizer
bsz = 8               # DataLoader batch size
num_epochs = 3        # passes over the dataset
max_grad_norm = 1.0   # gradient-norm clipping threshold
logging_steps = 5     # print the running loss every N steps

In [None]:
# Shuffled loader; the lambda binds the tokenizer that collate_fn uses for padding.
dataloader = DataLoader(
    ds,
    batch_size=bsz,
    shuffle=True,
    collate_fn=lambda samples: collate_fn(samples, tokenizer),
)

In [None]:
# Hand AdamW only the parameters that require gradients, so the optimizer never
# tracks frozen weights and fails loudly if nothing is trainable.
# NOTE(review): assumes replace_linear_with_lora freezes the base weights — confirm.
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)

In [None]:
model.train()

# Running totals across ALL epochs; the printed losses are cumulative averages.
total_loss = 0
total_step = 0
for epoch in range(num_epochs):
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        # Despite its key name, the batch's 'attention_mask' is the LOSS mask built
        # by the dataset: 1 on response targets, 0 on prompt targets and padding.
        loss_mask = batch['attention_mask'].to(device)

        # Real attention mask. collate_fn left-pads with tokenizer.eos_token_id, so
        # the leading run of EOS ids in each row is padding and everything from the
        # first non-EOS token onwards is real. Feeding the loss mask to the model
        # instead would hide the whole prompt from the response tokens.
        attention_mask = (input_ids != tokenizer.eos_token_id).long().cumsum(dim=-1).clamp(max=1)

        optimizer.zero_grad()
        # Labels are not passed: they are already shifted and the loss is computed
        # below, so the model's internal loss would be wasted (and wrong) work.
        outputs = model(input_ids, attention_mask=attention_mask)
        # Compute the loss in float32 even when the model runs in bfloat16.
        logits = outputs.logits.float()

        # Per-token loss. No ignore_index: id 0 is a regular vocabulary entry and
        # masking is handled entirely by loss_mask.
        token_loss = F.cross_entropy(
            rearrange(logits, 'bsz seq_len vocab_size -> (bsz seq_len) vocab_size'),
            rearrange(labels, 'bsz seq_len -> (bsz seq_len)'),
            reduction='none',
        )
        flat_loss_mask = rearrange(loss_mask, 'bsz seq_len -> (bsz seq_len)')
        # Guard against a batch whose responses were all truncated away.
        num_target_tokens = flat_loss_mask.sum().clamp(min=1)
        loss = torch.sum(token_loss * flat_loss_mask) / num_target_tokens
        loss.backward()

        # Clip gradients; the returned value is the total norm before clipping.
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        total_loss += loss.item()

        total_step += 1
        if total_step % logging_steps == 0:
            avg_loss = total_loss / total_step
            print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Grad Norm: {total_norm:.4f}", flush=True)


    # Print the cumulative average loss at the end of each epoch.
    print(f"Epoch {epoch+1} finished, Average Loss: {total_loss / total_step:.4f}", flush=True)

In [None]:
def _generate_token_ids(model, input_ids, stop_ids, max_new_tokens, do_sample, temperature):
    """Yield newly generated token ids one by one until a stop id or the token budget is hit."""
    for _ in range(max_new_tokens):
        # The whole sequence is fed again on every step (no key/value cache is kept).
        with torch.no_grad():
            outputs = model(input_ids)

        logits = outputs.logits[:, -1, :]

        if do_sample and temperature > 0:
            # Temperature sampling, done in float32.
            probs = F.softmax(logits.float() / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding; also the fallback for temperature <= 0, which
            # would otherwise divide by zero.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        token_id = next_token.item()
        if token_id in stop_ids:
            return
        yield token_id

        # Extend the input with the new token for the next step.
        input_ids = torch.cat([input_ids, next_token], dim=-1)


def _stream_decoded(tokenizer, token_ids):
    """Yield the decoded text of all tokens received so far, once per new token."""
    generated_tokens = []
    for token_id in token_ids:
        generated_tokens.append(token_id)
        yield tokenizer.decode(generated_tokens)


def inference(
    model,
    tokenizer,
    text: str,
    max_new_tokens: int = 160,
    do_sample: bool = True,
    temperature: float = 0.3,
    print_inputs: bool = True,
    streaming: bool = False,
):
    """Generate a reply to ``text`` with a hand-written decoding loop.

    Args:
        model: Causal LM called as ``model(input_ids)``; must expose ``device``.
        tokenizer: Tokenizer providing the chat template, ``encode`` and ``decode``.
        text: Content of the single user message.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Temperature sampling when True, greedy decoding otherwise.
        temperature: Sampling temperature; values <= 0 fall back to greedy.
        print_inputs: Print the rendered prompt (without a trailing newline).
        streaming: Select the return mode, see below.

    Returns:
        With ``streaming=True`` a generator that yields the cumulative decoded
        text after every new token. With ``streaming=False`` the final decoded
        text as a ``str``. Generation stops at EOS or ``<|im_end|>``; the stop
        token is not part of the output.
    """
    # Build the model input from a single user turn.
    prompt_msg = [
        {"role": "user", "content": text}
    ]
    prompt = tokenizer.apply_chat_template(prompt_msg, tokenize=False, add_generation_prompt=True)
    # Follow the model's own device instead of relying on a notebook-level global.
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(model.device)
    im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]

    # Optionally echo the prompt.
    if print_inputs:
        print(prompt, end='')

    token_ids = _generate_token_ids(
        model,
        inputs['input_ids'],
        {tokenizer.eos_token_id, im_end_id},
        max_new_tokens,
        do_sample,
        temperature,
    )

    if streaming:
        return _stream_decoded(tokenizer, token_ids)
    # A function containing `yield` can never hand a plain return value to its
    # caller, so the non-streaming result is produced here instead.
    return tokenizer.decode(list(token_ids))

In [None]:
model.eval()

for test_text in [
    'Describe the process of bacterial conjugation and its significance in the context of antibiotic resistance.',
    'Explain the role of insulin in the body and how insulin resistance affects blood sugar levels.',
    'Provide recommendations for lifestyle changes that can help improve the overall health of a patient with type 2 diabetes.',
]:
    print('=' * 80)
    last_text = ''
    text = ''
    for text in inference(model, tokenizer, test_text, streaming=True):
        # A trailing U+FFFD means the last token ended mid-character; wait for
        # the next token so the finished character is printed exactly once.
        if text.endswith('\ufffd'):
            continue
        # Each yielded text extends the previous one, so print only the new suffix.
        # (str.replace would delete every occurrence of the earlier text, not
        # just the leading one.)
        cur_text = text[len(last_text):]
        print(cur_text, end='', flush=True)
        last_text = text
    # Emit anything that was still being held back when generation ended.
    print(text[len(last_text):], end='', flush=True)
    print('\n')