Extend context size without fine-tuning #705

Merged: 2 commits merged into main on Jul 5, 2023
Conversation

@airaria (Contributor) commented Jul 3, 2023

Description

Update: We find that the NTK method mentioned in this Reddit post outperforms Position Interpolation up to a context size of at least 6K. Thus, we replace the implementation of PI with the NTK method.

In addition, we use an empirical formula to set $\alpha$ adaptively based on the input length, so that hyperparameter tuning can be avoided and the method can be applied to different context sizes.
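Concretely, for an input of length $L$ and rotary embedding dimension $d$, the patch below sets

$$\alpha = \frac{L}{1024} - 1, \qquad \text{base}' = \text{base} \cdot \alpha^{\,d/(d-2)}, \qquad \text{inv\_freq}_i = \frac{1}{(\text{base}')^{\,2i/d}},$$

which is just a restatement of the code in this PR, not an additional formula.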

The following is the perplexity of Chinese-LLaMA-Plus-7B on a test set:

| Context size | 512 | 1024 | 2048 | 3072 | 4096 | 5120 | 6144 |
|---|---|---|---|---|---|---|---|
| baseline | 11.4 | 10.98 | 10.98 | 173.5 | - | - | - |
| Position Interpolation | 11.4 | 10.98 | 10.98 | 11.47 | 12.42 | 14.44 | 17.86 |
| Adaptive NTK (this PR) | 11.4 | 10.98 | 10.98 | 11.05 | 11.05 | 11.40 | 12.57 |

Even though Chinese-LLaMA-Plus-7B has been trained with an input_length of 512, its context size can be extended to 5K~6K without significantly increasing the perplexity.

Users only need to add the following lines to the beginning of their Python code:

import torch
import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # Keep dim and base so the forward pass can recompute inv_freq adaptively
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        # Sequence exceeds the cached range: rescale the RoPE base with the NTK method
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        # Empirical formula: alpha grows with the input length
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    # Within the cached range: behavior is unchanged
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
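For context, here is a minimal usage sketch (not part of this PR); the model path is a placeholder, and the patch above must be applied before the model is loaded:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Placeholder path; any LLaMA-family checkpoint loaded through transformers should work
tokenizer = LlamaTokenizer.from_pretrained("path/to/chinese-llama-plus-7b")
model = LlamaForCausalLM.from_pretrained("path/to/chinese-llama-plus-7b").eval()

long_prompt = "..."  # a prompt longer than the 2048-token cached window
inputs = tokenizer(long_prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))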

We keep the old implementation below for others' reference.


Implementation of Position Interpolation (deprecated)

Description

We implement Position Interpolation (proposed in the paper "Extending Context Window of Large Language Models via Position Interpolation" and in the blog post) for using LLaMA with Transformers.

We find that the method can be used out of the box, even without training the model with a long context size.
The following is the perplexity of Chinese-LLaMA-Plus-7B on a test set:

| Context size | 512 | 1024 | 2048 | 3072 | 4096 | 5120 |
|---|---|---|---|---|---|---|
| Perplexity | 11.4 | 11.0 | 11.0 | 11.5 | 12.4 | 15.6 |

Note that even though Chinese-LLaMA-Plus-7B has been trained with an input_length of 512, its context window size can be extended to 4096 without significantly increasing the perplexity.
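For reference, Position Interpolation rescales the position indices so they stay within the trained range: in the patch below, for a sequence of length $L$ larger than the cached length $L_{\text{cache}}$ (2048 here), each position $t$ is mapped to

$$t' = t \cdot \frac{L_{\text{cache}}}{L}$$

before the rotary embeddings are computed.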

Users only need to add the following lines to the beginning of their Python code:

import torch
import transformers

def pi_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:  # seq_len > 2048
        print(f"Perform position interpolation for length {seq_len}")
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        # Rescale positions into the trained range [0, max_seq_len_cached)
        scale = self.max_seq_len_cached / seq_len
        t *= scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    # Within the cached range: behavior is unchanged
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = pi_forward

If seq_len <= 2048, the behavior is unchanged; if seq_len > 2048, Position Interpolation is performed and the context size is extended to seq_len.
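As a rough illustration only (this is not the evaluation script used for the tables above), perplexity at a given context size could be measured along these lines; the model path and test file are placeholders:

import math
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

context_size = 3072  # longer than 2048, so the patched code path is exercised
tokenizer = LlamaTokenizer.from_pretrained("path/to/chinese-llama-plus-7b")
model = LlamaForCausalLM.from_pretrained("path/to/chinese-llama-plus-7b").eval()

text = open("test.txt").read()  # placeholder test set
ids = tokenizer(text, return_tensors="pt").input_ids[:, :context_size]
with torch.no_grad():
    loss = model(ids, labels=ids).loss  # mean cross-entropy over the window
print(f"perplexity @ {context_size}: {math.exp(loss.item()):.2f}")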

@airaria marked this pull request as draft on Jul 3, 2023, 02:56
@airaria changed the title from "Add Position Interpolation for inference scripts (draft)" to "Add Position Interpolation for inference scripts" on Jul 3, 2023
ymcui added a commit that referenced this pull request on Jul 3, 2023
@airaria changed the title from "Add Position Interpolation for inference scripts" to "Extend context size without fine-tuning" on Jul 5, 2023
@airaria marked this pull request as ready for review on Jul 5, 2023, 09:14
@airaria requested a review from ymcui on Jul 5, 2023, 09:19
@ymcui merged commit 6e007e0 into main on Jul 5, 2023
1 check passed
@tkone2018 commented:

@ymcui @airaria May I ask, can this be used directly out of the box?

@ymcui deleted the context_extend branch on Jul 7, 2023, 07:42
@xyfZzz commented Jul 12, 2023

Can this only be used at inference time? Is NTK applied during fine-tuning?

@airaria (Contributor, Author) commented Jul 12, 2023

> Can this only be used at inference time? Is NTK applied during fine-tuning?

NTK was not applied during fine-tuning for the models released so far.

@xyfZzz commented Jul 12, 2023

> Can this only be used at inference time? Is NTK applied during fine-tuning?
>
> NTK was not applied during fine-tuning for the models released so far.

Has it been added to the fine-tuning code yet? That way we could use NTK to fine-tune on our own data.
