In [9]:
# 代码参考了 https://github.com/facebookresearch/MetaICL

In [10]:
# %pip install numpy
# %pip install pickle
# %pip install transformers==4.28.1
# %pip install torch==2.1.2

In [11]:
import json
import os
import numpy as np
import pickle as pkl
import torch
from transformers import AutoTokenizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoModelForCausalLM
from transformers import Adafactor, AdamW, get_linear_schedule_with_warmup

In [12]:
!pwd

/nas/shared/ADLab_Oasim/gaoyufei/ICL


# 数据准备

In [13]:
# 下载数据并解压到同级目录
# https://github.com/gyfffffff/llm-deploy/releases/download/icl_data/icl_data.zip

# 使用Linux 系统的同学也可以直接运行：
# !wget https://github.com/gyfffffff/llm-deploy/releases/download/icl_data/icl_data.zip

# !unzip icl_data.zip -d data/

In [14]:
# 读取数据
train_data_files = ["data/icl_data/dream/dream_16_13_dev.jsonl", "data/icl_data/wiki_qa/wiki_qa_16_13_dev.jsonl"]
train_data = []
for train_data_file in train_data_files:
    with open(train_data_file, "r") as f:
        data = []
        for line in f:
            data.append(json.loads(line))
    train_data.append(data)
train_data

[[{'task': 'dream',
   'input': "Based on the conversation, what is the most probable relationship between the speakers? [SEP] Heather: Ron, what are you doing? Ron: Ah, nothing. I'm just looking up some information on the Internet. Heather: Like what? Let me see. Ron: No, no, it's okay. I mean, you know ... Heather: Baldness? What are you looking that up for? [Well, you know ... ] I ... I mean, you're not that bad off. Ron: Ah, there you go. Bringing it up again! Heather: No. I mean it. You look great! Honestly, it's not that bad. Ron: Hey, I get enough of it from friends, and the people at work, and now from you! Heather: Well, maybe you could wear a toupee? I think you'd look great. Ron: Oh no. And have it slip off my head on to my date's dinner plate as I lean over to kiss her? Uh-uh. Heather: Well, have you ever thought about seeking medical advice? There are new advances in medicines that not only retard hair loss, but help regenerate new growth. Ron: Ah, I still don't give much 

In [15]:
# 下载模型 (这里教师模型和学生模型都用GPT-2)

# windows
# %pip install -U "huggingface-hub[cli]"
# !$env:HF_ENDPOINT = "https://hf-mirror.com"
# !huggingface-cli download --resume-download openai-community/gpt2 --local-dir ../../models/GPT-2

# linux
# %pip install -U "huggingface-hub[cli]"
# !export HF_ENDPOINT=https://hf-mirror.com
# !huggingface-cli download --resume-download gpt2 --local-dir ../models/GPT-2

## 数据处理

In [16]:
max_length=1024
max_length_per_example=256
model_path = "../models/GPT-2"
batch_size = 4

tokenizer = AutoTokenizer.from_pretrained(model_path)

# 定义数据处理函数
def prepro_sentence_pair_single(ids1, ids2, max_length,
                                bos_token_id, eos_token_id,
                                allow_truncation=False):

    #if bos_token_id is not None:
    #    ids1 = [bos_token_id] + ids1
    #if eos_token_id is not None:
    #    ids2 = ids2 + [eos_token_id]
    if allow_truncation and len(ids1)+len(ids2) > max_length:
        ids1 = ids1[len(ids1)+len(ids2)-max_length:] # len = max_length-len(ids2)
        assert len(ids1)+len(ids2)==max_length

    n_mask = max_length-len(ids1)-len(ids2)
    assert n_mask>=0, (max_length, len(ids1), len(ids2))
    input_ids = ids1+ids2+[0 for _ in range(n_mask)]
    attention_mask = [1 for _ in ids1+ids2] + [0 for _ in range(n_mask)]
    token_type_ids = [0 for _ in ids1] + [1 for _ in ids2] + [0 for _ in range(n_mask)]
    return input_ids, attention_mask, token_type_ids


def _prepro_each_datapoint(dp, is_first=True, is_training=False, for_demonstrations=False):
    dp = dp.copy()

    no_label = np.all([option=="" for option in dp["options"]])
    no_input = dp["input"]==""
    if not is_first:
        dp["output"] = "\n\n\n" + dp["output"]
        if "options" in dp:
            dp["options"] = ["\n\n\n" + opt for opt in dp["options"]]
    if not no_input:
        if not no_label:
            dp["input"] = "\n" + dp["input"]

    input_tokens = tokenizer(dp["input"])["input_ids"]

    if is_training or for_demonstrations:
        output_tokens = tokenizer(dp["output"])["input_ids"]

        if "task" in dp:
            if len(input_tokens)>=max_length_per_example - 2 - len(output_tokens):
                if dp["task"].startswith("inst:") and len(input_tokens)<len(output_tokens):
                    output_tokens = output_tokens[:max_length_per_example - 2 - len(input_tokens)]
                else:
                    input_tokens = input_tokens[:max_length_per_example - 2 - len(output_tokens)]

        assert len(input_tokens)+len(output_tokens)+2<=max_length_per_example, \
            (dp.get("task", None), len(input_tokens), len(output_tokens), max_length_per_example)

        return output_tokens, input_tokens


    else:
        assert len(dp["options"])>=2, dp
        assert dp["output"] in dp["options"]
        option_tokens = [tokenizer(option)["input_ids"] for option in dp["options"]]
        option_length = np.max([len(option) for option in option_tokens])

        if len(input_tokens)>=max_length_per_example - 2 - option_length:
            input_tokens = input_tokens[:max_length_per_example - 2 - option_length]

        input_tokens = [input_tokens for _ in option_tokens]
        output_tokens = option_tokens
        option_tokens = [dp["options"].index(dp["output"])]

        return output_tokens, input_tokens, option_tokens


def _tensorize_for_training(train_data):
    for dp in train_data:  # train_data： [{"input": str, "output": str}, ...]
        assert type(dp)==dict, ("Each example should be a dictionary", dp)
        assert "input" in dp and "output" in dp, ("Training example should contain input and output", dp)

    # each datapoint: passage, question, options, output
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id

    input_ids, attention_mask, token_type_ids = [], [], []

    for dp in train_data:
        inputs, outputs = _prepro_each_datapoint(
            dp, is_first=True, is_training=True)

        encoded = prepro_sentence_pair_single(
            inputs, outputs, max_length, bos_token_id, eos_token_id)

        input_ids.append(encoded[0])
        attention_mask.append(encoded[1])
        token_type_ids.append(encoded[2])

    return dict(input_ids=torch.LongTensor(input_ids),
                attention_mask=torch.LongTensor(attention_mask),
                token_type_ids=torch.LongTensor(token_type_ids))

In [17]:
# 数据转为tensor

# 数据格式：
# [[{}, {}, ...], [{}, {}, ...], ...]  
# 每个子列表是一个数据集，每个字典是一个数据点


def tensorize_for_training(train_data, is_training=True):
    inputs = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
    
    # 张量化每一条数据
    for in_ in train_data:
        out = _tensorize_for_training(in_)
        for key in ["input_ids", "attention_mask", "token_type_ids"]:
            inputs[key] += out[key].numpy().tolist()

    # 数据打乱
    N = len(inputs["input_ids"])
    indices = np.random.permutation(range(N))
    for k, v in inputs.items():
        inputs[k] = np.array(v)[indices]

    # 保存数据
    with open("data/preprocessed_data.kpl", "wb") as f:
        pkl.dump({k:v for k, v in inputs.items()}, f)
    print("Finish saving preprocessed data ...")

    # 定义dataset，dataloader
    for k, v in inputs.items():
        inputs[k] = torch.LongTensor(v)
    shape = inputs["input_ids"].shape
    for v in inputs.values():
        assert v.shape==shape
    if "labels" in inputs:
        dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"], inputs["labels"])
    else:
        dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
    if is_training:
        sampler=RandomSampler(dataset)
    else:
        sampler=SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    return dataloader
dataloader = tensorize_for_training(train_data)


Finish saving preprocessed data ...


In [18]:
# 加载wikitext数据集
from datasets import load_dataset
tokenizer.pad_token = tokenizer.eos_token 
dataset = load_dataset("./data/wikitext", 'wikitext-103-raw-v1', split='validation')

def encode(examples):
    return tokenizer(examples["text"], truncation=True, max_length=max_length, padding="max_length")

dataset = dataset.map(encode, batched=True)

dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "text"])
text_dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

In [19]:
for batch in text_dataloader:
    print(batch)
    break

{'text': ['', ' = Homarus gammarus = \n', '', ' Homarus gammarus , known as the European lobster or common lobster , is a species of clawed lobster from the eastern Atlantic Ocean , Mediterranean Sea and parts of the Black Sea . It is closely related to the American lobster , H. americanus . It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ) , and bears a conspicuous pair of claws . In life , the lobsters are blue , only becoming " lobster red " on cooking . Mating occurs in the summer , producing eggs which are carried by the females for up to a year before hatching into planktonic larvae . Homarus gammarus is a highly esteemed food , and is widely caught using lobster pots , mostly around the British Isles . \n'], 'input_ids': tensor([[50256, 50256, 50256,  ..., 50256, 50256, 50256],
        [  796,  8074, 20272,  ..., 50256, 50256, 50256],
        [50256, 50256, 50256,  ..., 50256, 50256, 50256],
        [ 8074, 20272,  9106,  ..., 50256, 50256, 50256]]),

# 训练模型

In [20]:
import logging

# 指定日志路径

out_dir = "output"
log_file = f"{out_dir}/log.txt"
handlers = [logging.StreamHandler()]
handlers.append(logging.FileHandler(log_file))
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO,
                    handlers=handlers)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
logger.info(out_dir)

os.makedirs(out_dir, exist_ok=True)


12/21/2024 09:53:10 - INFO - __main__ - output


In [23]:
# 定义不同的损失函数
device = torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(model_path)
model.to(device)
def run_model_meta_icl(input_ids, attention_mask, token_type_ids, labels=None):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits[..., :-1, :].contiguous()

        if labels is None:
            labels = input_ids
        labels = labels[..., 1:].contiguous()
        label_mask = token_type_ids[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) # [batch_size, length]

        losses = losses.view(logits.size(0), logits.size(1)) * label_mask
        return torch.sum(losses, axis=1) / torch.sum(label_mask, axis=1)

def run_model_icl_distill(input_ids, attention_mask, token_type_ids, text_input_ids, text_attention_mask, step, labels=None):
        beta = 0.2
        stu_outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        stu_logits = stu_outputs.logits[..., :-1, :].contiguous()
        prob_stu = torch.nn.functional.softmax(stu_logits, dim=-1)

        tea_outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        tea_logits = tea_outputs.logits[..., :-1, :].contiguous()
        log_tea_prob = torch.nn.functional.log_softmax(tea_logits, dim=-1)

        # soft icl loss: 学生和教师输出的交叉熵
        soft_icl_loss = -torch.sum(prob_stu * log_tea_prob, axis=-1)
        # soft_loss_fct = torch.nn.CrossEntropyLoss()
        # print(stu_logits.view(-1, stu_logits.size(-1)).shape, soft_targets.view(-1).shape)
        # soft_icl_loss = soft_loss_fct(stu_logits.view(-1, stu_logits.size(-1)), soft_targets.view(-1))
        
        # soft_losses = soft_loss_fct(stu_logits.view(-1, stu_logits.size(-1)), soft_targets.view(-1).argmax(-1))

        # soft_losses = soft_losses.view(stu_logits.size(0), stu_logits.size(1))
        # soft_losses = torch.sum(soft_losses, axis=1)

        # soft lm loss: 学生和教师输出的交叉熵
        stu_text_outputs = model(input_ids=text_input_ids, attention_mask=text_attention_mask)
        stu_text_logits = stu_text_outputs.logits[..., :-1, :].contiguous()
        prob_stu_text = torch.nn.functional.softmax(stu_text_logits, dim=-1)
        
        tea_text_outputs = model(input_ids=text_input_ids, attention_mask=text_attention_mask)
        tea_text_logits = tea_text_outputs.logits[..., :-1, :].contiguous()
        log_tea_prob_text = torch.nn.functional.softmax(tea_text_logits, dim=-1)

        soft_lm_loss = -torch.sum(prob_stu_text * log_tea_prob_text, axis=-1)
    
        # soft_text_loss_fct = torch.nn.CrossEntropyLoss()
        # soft_text_losses = soft_text_loss_fct(stu_text_logits.view(-1, stu_text_logits.size(-1)), soft_text_targets.view(-1))

        # soft_text_losses = soft_text_losses.view(stu_text_logits.size(0), stu_text_logits.size(1))

        soft_loss = soft_icl_loss + beta * soft_lm_loss

        # hard icl loss: 学生和ground-truth的交叉熵
        if labels is None:
            labels = input_ids
        labels = labels[..., 1:].contiguous()
        label_mask = token_type_ids[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        losses = loss_fct(stu_logits.view(-1, stu_logits.size(-1)), labels.view(-1)) # [batch_size, length]

        hard_icl_losses = losses.view(stu_logits.size(0), stu_logits.size(1)) * label_mask
        hard_icl_loss = torch.sum(hard_icl_losses, axis=1) / torch.sum(label_mask, axis=1)

        # hard lm loss: 学生和ground-truth的交叉熵
        text_outputs = model(input_ids=text_input_ids, attention_mask=text_attention_mask)
        labels_text = input_ids[..., 1:].contiguous()
        logits_text = text_outputs.logits[..., :-1, :].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        hard_lm_loss = loss_fct(logits_text.view(-1, logits_text.size(-1)), labels_text.view(-1))
        
        hard_loss = hard_icl_loss + beta * hard_lm_loss

        # "We linearly decrease the weight of hard-label loss α(t) and linearly increase the weight of soft-label loss during training."
        alpha = 1-step/len(dataloader)

        loss = alpha * torch.sum(hard_loss) + (1-alpha) * soft_loss

        return loss

In [24]:
# 训练
save_period = 10
log_period = 5
batch_size = 4
num_training_steps = 1000
epoches = 100

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-05, eps=1e-08)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)   

model.train()

def save(step):
        model_state_dict = {key[7:] if key.startswith("module.") else key: value.cpu()
                        for key, value in model.state_dict().items()}
        torch.save(model_state_dict, os.path.join(out_dir, "model-{}.pt".format(step)))
        logger.info("Saving model parameters at step=%d" % step)

def do_train(num_training_steps, save_period, log_period, gradient_accumulation_steps=1, max_grad_norm=1.0):
        global_step = 0
        train_losses = []
        dataloader2 = list(text_dataloader)
        for epoch in range(epoches):
                for step, batch in enumerate(dataloader):
                        global_step += 1

                        input_ids=batch[0].to(device)
                        attention_mask=batch[1].to(device)
                        token_type_ids=batch[2].to(device)
                        if len(batch)==3:
                                labels=None
                        else:
                                labels=batch[3].to(device)
                        text_input_ids = dataloader2[global_step%len(dataloader2)]["input_ids"].to(device)
                        text_attention_mask = dataloader2[global_step%len(dataloader2)]["attention_mask"].to(device)
                        # loss = run_model_meta_icl(input_ids, attention_mask, token_type_ids=token_type_ids, labels=labels)
                        loss = run_model_icl_distill(input_ids, attention_mask, token_type_ids=token_type_ids, labels=labels, step=global_step, text_input_ids=text_input_ids, text_attention_mask=text_attention_mask)
                        loss = loss.mean()

                        train_losses.append(loss.detach().cpu())

                        loss.backward()

                        if global_step % gradient_accumulation_steps == 0:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                                optimizer.step()    # We have accumulated enought gradients
                                if scheduler is not None:
                                        scheduler.step()
                                        model.zero_grad()

                        if global_step % log_period == 0:
                                logger.info("global step %d\t train loss %.2f" % (global_step, np.mean(train_losses)))
                                train_losses = []

                        if global_step % save_period == 0:
                                save(global_step)

                        if global_step==num_training_steps:
                                break

                if global_step==num_training_steps:
                        break

        logger.info("Finish training")
do_train(data, batch_size, num_training_steps, save_period, log_period)

12/21/2024 10:02:59 - INFO - __main__ - Saving model parameters at step=4
