does opendelta support gradient_checkpointing? #39

Closed
hmzo opened this issue Oct 27, 2022 · 3 comments

Comments

@hmzo

hmzo commented Oct 27, 2022

Thank you for the awesome work.
I ran into a problem when using opendelta with gradient_checkpointing; it throws:
"RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn"
By the way, the code works fine when gradient_checkpointing is disabled.

So, does opendelta support gradient_checkpointing?

@ShengdingHu
Collaborator

opendelta supports bmtrain, which uses gradient checkpointing. Which gradient checkpointing framework are you using? Could you share minimal reproduction code?

@hmzo
Author

hmzo commented Oct 28, 2022

The code is:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-large", cache_dir='tmp')
tokenizer = AutoTokenizer.from_pretrained("t5-large", cache_dir='tmp')
model.to('cuda')
model.train()
inputs = {k: v.to(model.device) for k, v in model.dummy_inputs.items()}
inputs.update({'labels': torch.ones_like(inputs['input_ids'])})

# original
l1 = model(**inputs).loss
l1.backward()
print("l1 bwd success ")

# delta
from opendelta import LoraModel
delta_model = LoraModel(backbone_model=model)
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
l2 = model(**inputs).loss
l2.backward()
print("l2 bwd success ")

# gradient_checkpointing
model.gradient_checkpointing_enable()
l3 = model(**inputs).loss
l3.backward()
print("l3 bwd success ")

and the exception is:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /workspace/OpenDelta/examples/bb3/test_gc.py:27 in <module>                                      │
│                                                                                                  │
│   24 # gradient_checkpointing                                                                    │
│   25 model.gradient_checkpointing_enable()                                                       │
│   26 l3 = model(**inputs).loss                                                                   │
│ ❱ 27 l3.backward()                                                                               │
│   28 print("l3 bwd success ")                                                                    │
│   29                                                                                             │
│   30                                                                                             │
│                                                                                                  │
│ /miniconda/lib/python3.9/site-packages/torch/_tensor.py:396 in backward                          │
│                                                                                                  │
│    393 │   │   │   │   retain_graph=retain_graph,                                                │
│    394 │   │   │   │   create_graph=create_graph,                                                │
│    395 │   │   │   │   inputs=inputs)                                                            │
│ ❱  396 │   │   torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=input  │
│    397 │                                                                                         │
│    398 │   def register_hook(self, hook):                                                        │
│    399 │   │   r"""Registers a backward hook.                                                    │
│                                                                                                  │
│ /miniconda/lib/python3.9/site-packages/torch/autograd/__init__.py:173 in backward                │
│                                                                                                  │
│   170 │   # The reason we repeat same the comment below is that                                  │
│   171 │   # some Python versions print out the first line of a multi-line function               │
│   172 │   # calls in the traceback and some print out the last line                              │
│ ❱ 173 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   │
│   174 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        │
│   175 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   │
│   176                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
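
For reference, the same failure can be reproduced outside of opendelta with the reentrant torch.utils.checkpoint that transformers' gradient checkpointing relies on here: if none of the tensors passed into a checkpointed segment require grad, the segment's output has no grad_fn, even when parameters used inside the segment are trainable. A minimal sketch of that behaviour, independent of the T5/LoRA setup above:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)    # trainable parameters, like the LoRA weights
hidden = torch.randn(2, 4)       # requires_grad is False, like the output of a frozen embedding

out = checkpoint(layer, hidden)  # reentrant checkpointing (the default)
print(out.requires_grad)         # False: only the tensors passed into the checkpointed segment
                                 # are inspected; parameters used inside it are ignored
# out.sum().backward()           # -> RuntimeError: element 0 of tensors does not require grad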

@hmzo
Author

hmzo commented Oct 31, 2022

The problem seems to be that the requires_grad flag is False on the inputs passed to the checkpoint function. My solution is an admittedly inelegant patch to work around it.

import torch
import torch.nn.functional as F
from torch.overrides import has_torch_function_variadic, handle_torch_function

# Copy of torch.nn.functional.embedding with one extra line at the end: the output
# is marked as requiring grad so downstream checkpointed blocks build an autograd graph.
def gcpatch_embedding(input, weight, padding_idx=None, max_norm=None, 
              norm_type=2.0, scale_grad_by_freq=False, sparse=False):
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            gcpatch_embedding, (input, weight),
            input, weight, padding_idx, max_norm, norm_type,
            scale_grad_by_freq, sparse
        )
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings"
            padding_idx = weight.size(0) + padding_idx
    else:
        padding_idx = -1
    if max_norm is not None:
        input = input.contiguous()
        F._no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
    output = torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    
    # patch here: set the `requires_grad` flag on the output
    output.requires_grad_()
    return output

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-large", cache_dir='tmp')
tokenizer = AutoTokenizer.from_pretrained("t5-large", cache_dir='tmp')
model.to('cuda')
model.train()
inputs = {k: v.to(model.device) for k, v in model.dummy_inputs.items()}
inputs.update({'labels': torch.ones_like(inputs['input_ids'])})

# original
l1 = model(**inputs).loss
l1.backward()
model.zero_grad()
print("l1 bwd success ")

# delta
from opendelta import LoraModel
delta_model = LoraModel(backbone_model=model)
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
l2 = model(**inputs).loss
l2.backward()
grad_dict2 = {}
for n,p in model.named_parameters():
    if p.requires_grad:
        grad_dict2[n] = p.grad
model.zero_grad()
print("l2 bwd success ")

# gradient_checkpointing
model.gradient_checkpointing_enable()
# patch here: reroute transformers' embedding lookups (nn.Embedding.forward calls F.embedding)
F.embedding = gcpatch_embedding
l3 = model(**inputs).loss
l3.backward()
grad_dict3 = {}
for n,p in model.named_parameters():
    if p.requires_grad:
        grad_dict3[n] = p.grad
model.zero_grad()
print("l3 bwd success ")

is_equals = []
for n, p in model.named_parameters():
    if p.requires_grad:
        g2 = grad_dict2[n]
        g3 = grad_dict3[n]
        is_equals.append(g2.equal(g3))
print(f"grads are equal: {all(is_equals)}")

@telxt closed this as completed Mar 20, 2023