
Commit

Option to only return completion from generate_simple
turboderp committed Apr 7, 2024
1 parent 8923fea commit 6b14f8a
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions exllamav2/generator/base.py
@@ -65,7 +65,8 @@ def generate_simple(self,
                         stop_token: int or None = -1,
                         add_bos: bool = False,
                         abort_event: threading.Event | None = None,
-                        input_embeddings: torch.Tensor | None = None):
+                        input_embeddings: torch.Tensor | None = None,
+                        completion_only: bool = False):
 
         """
         Generate one or more completions.
@@ -114,6 +115,9 @@ def generate_simple(self,
             is not supported when passing input embeddings unless all prompts are the same. Prompt must
             contain the string `{{EMBED_HERE}}` to indicate where embeddings are to be inserted.
 
+        :param completion_only:
+            Only return completion. If False, returned string will include the input prompt.
+
         :return:
             Completion(s) (str or list[str] depending on the type of the input prompt argument)
         """
@@ -172,6 +176,7 @@ def generate_simple(self,
                                                       encode_special_tokens = encode_special_tokens,
                                                       return_offsets = True,
                                                       add_bos = add_bos)
+
         if prompts_identical:
             position_offsets = None
 
@@ -185,6 +190,11 @@ def generate_simple(self,
 
         first_token = max(-overflow, 0)
 
+        # Completion only
+
+        if completion_only:
+            first_token = ids.shape[-1]
+
         # Prepare for healing
 
         unhealed_token = None
@@ -212,6 +222,7 @@ def generate_simple(self,
 
         # Begin filters
 
+        healed_token = []
         id_to_piece = self.tokenizer.get_id_to_piece_list()
         if unhealed_token is not None:
             unhealed_token_list = unhealed_token.flatten().tolist()
@@ -243,6 +254,10 @@ def generate_simple(self,
                                                  self.tokenizer,
                                                  prefix_token = unhealed_token)
 
+            if unhealed_token is not None:
+                unhealed_token_copy = unhealed_token
+                healed_token = token
+
             if stop_token is not None:
                 for b in range(batch_size):
                     if token[b, 0].item() == stop_token:
@@ -280,10 +295,20 @@ def generate_simple(self,
         decode_ids = self.sequence_ids[:, first_token:]
         if input_embeddings is not None:
             decode_ids = torch.stack([decode_ids[i][decode_ids[i] != self.tokenizer.pad_token_id] for i in range(batch_size)])
+
+        if len(healed_token) and completion_only:
+            decode_ids = torch.cat([healed_token, decode_ids], dim = -1)
+
         text = self.tokenizer.decode(decode_ids, decode_special_tokens = decode_special_tokens)
 
-        if isinstance(prompt, str): return text[0]
-        return text
+        if len(healed_token) and completion_only:
+            pre_text = self.tokenizer.decode(unhealed_token_copy, decode_special_tokens = decode_special_tokens)
+            text = [t[len(p):] for t, p in zip(text, pre_text)]
+
+        if isinstance(prompt, str):
+            return text[0]
+        else:
+            return text
 
 
     def _gen_begin_base(self,
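For context, a minimal usage sketch of the new flag follows. It is not part of the commit; the model directory and sampler settings are illustrative assumptions. With completion_only = True, generate_simple returns only the generated continuation. As the diff shows, this stays compatible with token healing: the healed token is re-attached to decode_ids before decoding, and the decoded text of the original last prompt token (pre_text) is then stripped from the front of each result.

# Usage sketch for completion_only (assumed setup; the model path is hypothetical).
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

config = ExLlamaV2Config()
config.model_dir = "/models/my-model"   # hypothetical model directory
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()

prompt = "Once upon a time,"

# Default behavior: the returned string includes the input prompt.
full_text = generator.generate_simple(prompt, settings, 32)

# New behavior: only the completion is returned, even with token healing,
# since the healed token is re-decoded and the prompt's last-token text stripped.
completion = generator.generate_simple(prompt, settings, 32,
                                       completion_only = True,
                                       token_healing = True)
print(completion)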
