Diff converter v2 (huggingface#30868)
* current working example!

* commit regex and result file

* update

* nit

* push the conversion file

* oops

* roadmap and nits

* attempt diffs for 3 files

* persimmon

* nit

* add diff file that is the same as the modeling_llama.py

* fix rope nits

* updates

* updates with converted versions

* give some breathing space to the code

* delete

* update

* update

* push the actual result

* update regex patterns

* update regex patterns

* fix some issues

* fix some issues

* fix some issues

* updates

* updates

* updates

* updates

* updates

* revert changes done to llama

* updates

* update gemma

* updates

* oops

* current state

* current state

* update

* yesss!

* nit

* clear diffs

* nit

* fixup

* update

* doc 🚀

* 🔥

* for now use gemma

* deal with comments

* style

* handle functions

* deal with assigns

* todos

* process inheritance

* keep decorators?

* 🤗

* deal with duplicates

* fixup

* correctly remove duplicate code

* run ruff post script

* ruff deals pretty well with imports, let's leave it to him

* ah maybe not lol

* for now remove all imports from child.

* nit

* conversion of llama

* okay

* convert starcoder2

* synch with main

* update llama diff

* updates

* https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, but needs a later version of ruff

* updates

* okay actual state

* non zero exit

* update!

* revert unrelated

* remove other diff files

* updates

* cleanup

* update

* less diff!

* stash

* current updates

* updates

* No need for call

* finished finding deps

* update

* current changes

* current state

* current state

* new status

* nit

* finally

* fixes

* nits

* order is now expected

* use logger info instead of prints

* fixup

* up

* nit

* update

* nits

* update

* correct merge

* update

* update

* update

* add warning

* update caution message

* update

* better merging strategy

* copy class statements :wink:

* fixups

* nits

* update

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* nits

* smaller header

* do cleanup some stuff

* even simpler header?

* fixup

* updates

* ruff

* update examples

* nit

* TODO

* state

* phew!

* current state

* nits

* final state

* add a readme

* fixup

* remove diff llama

* fix

* nit

* dummy not funny

* ruff format tests src utils --check

* even less diffs

* less diffs and fix test

* fixes

* naming nit?

* update converter and add super example

* nits

* updated for function signatures

* update

* update

* add converted dummies

* autoformat

* single target assign fix

* fixup

* fix some imports

* fixes

* don't push them

* `# noqa: F841`

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and vasqu committed Jun 1, 2024
1 parent ba34b39 commit b3261f5
Showing 13 changed files with 1,315 additions and 62 deletions.
20 changes: 20 additions & 0 deletions examples/diff-conversion/README.md
@@ -0,0 +1,20 @@
# Using the `diff_converter` linter

`pip install libcst` is a must!

Run `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs.

The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in Python, converting a modular `diff` file like `diff_gemma.py` into a single-model, single-file equivalent.

Examples of possible usage are available in `examples/diff-conversion`; see `diff_gemma` for a full-model example.

`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`

## How it works
We use the `libcst` parser to produce a concrete syntax tree (CST) representation of the `diff_xxx.py` file. For any import from `transformers.models.modeling_xxxx`, we parse the source code of that module and build a class dependency mapping, which lets us resolve the dependencies the diff relies on.
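
As a rough illustration of that first step, here is a minimal sketch of collecting a class dependency mapping with `libcst` (the visitor below is illustrative, not the converter's actual implementation):

```python
import libcst as cst


class ClassDependencyCollector(cst.CSTVisitor):
    """Collect, for each class in a module, the names of its base classes."""

    def __init__(self):
        super().__init__()
        self.class_bases = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Only record simple `Name` bases, e.g. `class Foo(LlamaModel)`
        bases = [base.value.value for base in node.bases if isinstance(base.value, cst.Name)]
        self.class_bases[node.name.value] = bases


with open("examples/diff-conversion/diff_my_new_model.py") as f:
    tree = cst.parse_module(f.read())

collector = ClassDependencyCollector()
tree.visit(collector)
print(collector.class_bases)  # {'MyNewModelConfig': ['LlamaConfig']}
```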

The code from the `diff` file and the class dependency mapping are "merged" to produce the single-model, single-file output.
We then run `ruff` to automatically remove any duplicate or unused imports.
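
The post-processing step could look like the sketch below (the file name and rule selection are assumptions for illustration, not what `diff_model_converter.py` literally runs):

```python
import subprocess

generated_file = "src/transformers/models/gemma/configuration_gemma.py"  # example target

# F401 (unused import) and F811 (redefined-while-unused) clean up imports that the
# merge step may have duplicated; note F811's autofix requires a recent ruff version.
subprocess.run(["ruff", "check", generated_file, "--fix", "--select", "F401,F811"], check=False)

# Normalize the formatting of the generated file.
subprocess.run(["ruff", "format", generated_file], check=False)
```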

## Why use `libcst` instead of the native AST?
The native `ast` module is powerful, but it does not preserve comments or the original code formatting. Since the converter must regenerate readable source files, we went with `libcst`, which round-trips them exactly.
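
A quick demonstration of the difference (assuming Python 3.9+ for `ast.unparse`):

```python
import ast

import libcst

src = "# setup\nx = 1  # inline comment\n"

# The native AST drops comments and original formatting on round-trip...
print(ast.unparse(ast.parse(src)))  # -> "x = 1"

# ...while libcst round-trips the source exactly, comments included.
print(libcst.parse_module(src).code)  # -> "# setup\nx = 1  # inline comment\n"
```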
10 changes: 10 additions & 0 deletions examples/diff-conversion/convert_examples.sh
@@ -0,0 +1,10 @@
#!/bin/bash

# Iterate over each file in the current directory
for file in examples/diff-conversion/diff_*; do
# Check if it's a regular file
if [ -f "$file" ]; then
# Call the Python script with the file name as an argument
python utils/diff_model_converter.py --files_to_parse "$file"
fi
done
44 changes: 44 additions & 0 deletions examples/diff-conversion/diff_dummy.py
@@ -0,0 +1,44 @@
from math import log
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


def _pre_process_input(input_ids):
    # dummy pre-processing; uses `math.log` so the converter must carry the import over
    print(log(max(input_ids.numel(), 1)))
    return input_ids


# Example where we need some dependencies and some helper functions
class DummyModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
input_ids = _pre_process_input(input_ids)

return super().forward(
None,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
14 changes: 14 additions & 0 deletions examples/diff-conversion/diff_my_new_model.py
@@ -0,0 +1,14 @@
from transformers.models.llama.configuration_llama import LlamaConfig


# Example where we only want to add a new config argument and its arg doc
# since there is no `ARG` docstring entry here, we take the parent's doc
class MyNewModelConfig(LlamaConfig):
r"""
    mlp_bias (`bool`, *optional*, defaults to `True`)
"""

    def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
        self.mlp_bias = mlp_bias
        self.new_param = new_param
        super().__init__(**super_kwargs)
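
For context, the converter should expand this diff into a standalone file roughly like the hand-written sketch below (abbreviated and illustrative, not the converter's literal output; compare the actual generated `configuration_gemma.py` later in this commit):

```python
# This file would be automatically generated from diff_my_new_model.py.
from transformers import PretrainedConfig


class MyNewModelConfig(PretrainedConfig):
    r"""
    [...full LlamaConfig docstring inlined here, plus the new entry:]
    mlp_bias (`bool`, *optional*, defaults to `True`)
    """

    def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
        self.mlp_bias = mlp_bias
        self.new_param = new_param
        super().__init__(**super_kwargs)
```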
31 changes: 31 additions & 0 deletions examples/diff-conversion/diff_my_new_model2.py
@@ -0,0 +1,31 @@
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
from transformers.models.llama.configuration_llama import LlamaConfig


# Example where we only want to modify the docstring
class MyNewModel2Config(LlamaConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`GemmaModel`]
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""


# Example where all the dependencies are fetched just to copy the entire class
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
pass
30 changes: 30 additions & 0 deletions examples/diff-conversion/diff_new_model.py
@@ -0,0 +1,30 @@
# Example where we only want to overwrite the defaults of an init

from transformers.models.gemma.configuration_gemma import GemmaConfig


class NewModelConfig(GemmaConfig):
def __init__(
self,
vocab_size=256030,
hidden_size=64,
intermediate_size=90,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=1500,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
):
        # the overridden defaults above are what this example demonstrates
        super().__init__()
38 changes: 38 additions & 0 deletions examples/diff-conversion/diff_super.py
@@ -0,0 +1,38 @@
from typing import List, Optional, Tuple, Union

import torch

from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel


# Example where we call super().forward and post-process its output
class SuperModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
out = super().forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
out.logits *= 2**4
return out
23 changes: 9 additions & 14 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -1,5 +1,12 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,27 +19,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)
from transformers import PretrainedConfig


class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@@ -83,16 +82,12 @@ class GemmaConfig(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
