
Commit

Merge pull request #705 from ymcui/context_extend
Extend context size (8K+) without fine-tuning
ymcui committed Jul 5, 2023
2 parents 8e45406 + 8556b8f commit 6e007e0
Showing 3 changed files with 91 additions and 0 deletions.
30 changes: 30 additions & 0 deletions scripts/inference/gradio_demo.py
@@ -12,6 +12,36 @@
import traceback
import gc

import transformers

# Monkey-patch LLaMA's rotary position embedding with NTK-aware scaling so that
# contexts well beyond the 2048-token training window (8K+) can be handled
# without fine-tuning. The patch must run before the model is instantiated.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # Remember dim and base so forward() can recompute frequencies on demand.
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        # Sequence exceeds the cached tables: enlarge the RoPE base in proportion
        # to the requested length (NTK-aware interpolation) and rebuild cos/sin.
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    # Within the original window: fall back to the tables cached by old_init.
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

# Parse command-line arguments
parser = argparse.ArgumentParser()
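Once this block has executed, every LlamaRotaryEmbedding created afterwards in the same process uses the patched __init__ and forward, so a model loaded later inherits the NTK-scaled positions automatically. A minimal sketch of exercising the patched demo with a long input; the model path, prompt file, and generation settings are illustrative placeholders, not part of this commit:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Assumes the monkey patch above has already been executed in this process.
model_path = "path/to/merged-chinese-alpaca-hf"   # placeholder checkpoint path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
)

long_prompt = open("long_document.txt").read()    # any input well beyond 2048 tokens
inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Positions past max_seq_len_cached take the adaptive_ntk_forward branch above.
    output_ids = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))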
30 changes: 30 additions & 0 deletions scripts/inference/inference_hf.py
@@ -18,6 +18,36 @@
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

import transformers

# Same NTK-aware rotary embedding patch as in gradio_demo.py above.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

generation_config = dict(
    temperature=0.2,
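The key lines in the patch are the base rescaling: alpha = seq_len / 1024 - 1 and base = self.base * alpha ** (dim / (dim - 2)), so the RoPE base grows with the requested length instead of staying at the default 10000. A small standalone illustration of the resulting numbers; dim = 128 corresponds to LLaMA-7B's per-head dimension (4096 hidden size / 32 heads), and the sequence lengths are just examples:

# Standalone arithmetic sketch of the NTK base rescaling used in the patch.
base = 10000.0
dim = 128  # per-head dimension of LLaMA-7B (4096 hidden size / 32 heads)

for seq_len in (4096, 8192, 16384):   # lengths beyond the cached 2048 window
    alpha = seq_len / 1024 - 1
    scaled_base = base * alpha ** (dim / (dim - 2))
    print(f"seq_len={seq_len:>6}  alpha={alpha:5.1f}  scaled base = {scaled_base:,.0f}")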
31 changes: 31 additions & 0 deletions scripts/openai_server_demo/openai_api_server.py
@@ -21,6 +21,37 @@
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel

import transformers

# Same NTK-aware rotary embedding patch as in gradio_demo.py above.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

from openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
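With the same patch applied inside openai_api_server.py, requests whose prompts run past 2048 tokens can be served as well. A hedged client-side sketch, assuming the server follows the OpenAI-style /v1/chat/completions route suggested by the protocol imports above; the host, port, and model name are placeholders, not taken from this diff:

# Illustrative client only; endpoint URL, port, and model name are assumptions.
import requests

resp = requests.post(
    "http://localhost:19327/v1/chat/completions",  # assumed address of the demo server
    json={
        "model": "chinese-alpaca",                 # placeholder model identifier
        "messages": [
            {"role": "user",
             "content": "<paste a document well beyond 2048 tokens here> "
                        "Please summarize the document above."},
        ],
        "max_tokens": 256,
    },
    timeout=600,
)
print(resp.json())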
