gradient checkpointing failed in xla_device #5766

Closed

bmeoaountiful opened this issue Nov 3, 2023 · 4 comments
bmeoaountiful commented Nov 3, 2023

❓ Questions and Help

I am trying to fine-tune a large language model from Hugging Face on an xla_device, and the following error is reported:

Traceback (most recent call last):
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/runtime.py", line 85, in wrapper
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 75, in _run_thread_per_device
    replica_results = list(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 68, in _thread_fn
    return fn()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 184, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 196, in _mp_fn
    train()
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 193, in train
    train_loop_fn(train_device_loader, epoch)
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 184, in train_loop_fn
    output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 692, in forward
    outputs = self.model(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 459, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 457, in checkpoint
    next(gen)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 1157, in _checkpoint_without_reentrant_generator
    device_module = _get_device_module(device)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 67, in _get_device_module
    device_module = getattr(torch, device)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/__init__.py", line 1833, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'xla'

The function torch.utils.checkpoint.checkpoint is decorated with _disable_dynamo:
https://github.com/pytorch/pytorch/blob/0d95378341b4eb19849295c7ccab08cc9be328a7/torch/utils/checkpoint.py#L341
Does this mean that torch.utils.checkpoint.checkpoint cannot be used when the model's device is set to xla?
If so, are there any alternative approaches that avoid gradient checkpointing when fine-tuning an LLM?
Any help on this would be greatly appreciated!
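
For context, the traceback suggests the failure comes from the non-reentrant checkpoint path, which resolves a device module via getattr(torch, device_type); for XLA tensors the device type is "xla", and torch has no torch.xla attribute. A minimal sketch that I would expect to hit the same AttributeError (illustrative only, assuming torch 2.1.0 / torch-xla 2.1.0 as below):

import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm

def block(x):
    # Any differentiable function works here; gelu is just a stand-in.
    return torch.nn.functional.gelu(x)

x = torch.randn(2, 8, device=xm.xla_device(), requires_grad=True)

# The non-reentrant path calls _get_device_module(), which does getattr(torch, "xla")
# and raises AttributeError, matching the traceback above.
out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)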

fine-tune.py

import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json

import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments

import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        item = self.preprocessing(self.data[0])
        labels = []
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = []
        labels = []

        for message in example["conversations"]:
            from_ = message["from"]
            value = message["value"]
            value_ids = self.tokenizer.encode(value)

            if from_ == "human":
                input_ids += self.user_tokens + value_ids
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                input_ids += self.assistant_tokens + value_ids
                labels += [self.ignore_index] + value_ids
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])

def train():
    device = xm.xla_device()
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    ).to(device)

    print('model device:', model.device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        use_fast=False,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
#    if training_args.use_lora:
#        from peft import LoraConfig, TaskType, get_peft_model
#
#        peft_config = LoraConfig(
#            task_type=TaskType.CAUSAL_LM,
#            target_modules=["W_pack"],
#            inference_mode=False,
#            r=1,
#            lora_alpha=32,
#            lora_dropout=0.1,
#        )
#        model.enable_input_require_grads()
#        model = get_peft_model(model, peft_config)
#        model.print_trainable_parameters()

    dataset = SupervisedDataset(
        data_args.data_path, tokenizer, training_args.model_max_length
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=training_args.per_device_train_batch_size,
        shuffle=True,
        num_workers=8,
        persistent_workers=False,
        prefetch_factor=16)

    torch.manual_seed(training_args.seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate,
                                  betas=(training_args.adam_beta1, training_args.adam_beta2),
                                  eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay)

    train_device_loader = pl.MpDeviceLoader(
        data_loader,
        device,
        loader_prefetch_size=8,
        device_prefetch_size=4,
        host_to_device_transfer_threads=8)

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, batch in enumerate(loader):
          input_ids = batch["input_ids"]
          attention_mask = batch["attention_mask"]
          labels = batch["labels"]
          with xp.StepTrace('train_baichuan2-13b-chat'):
            with xp.Trace('build_graph'):
              optimizer.zero_grad()
              output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
              loss = output.loss
              loss.backward()
              xm.optimizer_step(optimizer)
              tracker.add(training_args.per_device_train_batch_size)

    for epoch in range(1, int(training_args.num_train_epochs)):
      train_loop_fn(train_device_loader, epoch)

def _mp_fn(index):
    train()


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=None)

fine-tune.py can be run with the following command:

GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA python3 fine-tune.py  \
    --report_to "none" \
    --data_path "Baichuan2-13B-Chat/data/belle_chat_ramdon_10k.json" \
    --model_name_or_path "Baichuan2-13B-Chat/models--baichuan-inc--Baichuan2-13B-Chat/snapshots/8f6e343d545c503b91429582231d1d354dac2740/" \
    --output_dir "output" \
    --model_max_length 512 \
    --num_train_epochs 4 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy epoch \
    --learning_rate 2e-5 \
    --lr_scheduler_type constant \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --logging_steps 1 \
    --gradient_checkpointing False \
    --bf16 False \
    --tf32 False

dataset: https://github.com/baichuan-inc/Baichuan2/tree/main/fine-tune/data
model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/tree/main
torch version: 2.1.0
torch-xla version: 2.1.0

JackCaoG (Collaborator) commented Nov 3, 2023

We use https://github.com/pytorch/xla/blob/master/torch_xla/utils/checkpoint.py, which is pretty much copied from upstream but adds the optimization_barrier that prevents the compiler's CSE from removing the gradient checkpointing. If the upstream gradient checkpoint code has the _disable_dynamo decorator, there must be a reason, and I expect our code to run into the same issue.
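
Roughly, the difference is that the XLA version wraps the tensors crossing the checkpoint boundary in an optimization barrier. A conceptual sketch of that idea (not the actual torch_xla source):

import torch_xla.core.xla_model as xm

def _call_with_barrier(fn, *tensors):
    # optimization_barrier_ mutates the list in place, wrapping each tensor in an
    # xla::OptimizationBarrier op; the compiler will not CSE across it, so the
    # recomputed forward cannot be folded back into the originally saved activations.
    tensors = list(tensors)
    xm.optimization_barrier_(tensors)
    return fn(*tensors)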

bmeoaountiful (Author) commented

It works: converting torch.utils.checkpoint.checkpoint to torch_xla.utils.checkpoint.checkpoint fixes the issue.
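
For clarity, the change in the remote-code model file amounts to a drop-in swap along these lines (the layer call and its arguments are placeholders, not the actual Baichuan2 signature):

import torch_xla.utils.checkpoint

# before (modeling_llm.py, around the failing call):
# layer_outputs = torch.utils.checkpoint.checkpoint(
#     decoder_layer, hidden_states, attention_mask, position_ids)

# after:
layer_outputs = torch_xla.utils.checkpoint.checkpoint(
    decoder_layer, hidden_states, attention_mask, position_ids)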

jeffhataws (Collaborator) commented

@JackCaoG, the work-around here was to switch from torch.utils.checkpoint.checkpoint to torch_xla.utils.checkpoint.checkpoint. However, it would be better to restore the ability to use torch.utils.checkpoint.checkpoint, which was working in 1.13. What are your thoughts?

JackCaoG (Collaborator) commented Mar 7, 2024

IMO torch.utils.checkpoint.checkpoint won't do anything because the XLA compiler's CSE (common-subexpression elimination) pass will undo the gradient checkpointing. We should upstream our changes at some point so users can use torch.utils.checkpoint.checkpoint properly, though we currently don't have the resources for that.
