
Merge pull request AUTOMATIC1111#19 from uservar/dev3
Fix edge case for exactly 75 text tokens
uservar committed Nov 25, 2022
2 parents 0b629af + fc5e7c3 commit 427e906
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions modules/sd_hijack.py
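
In short: both tokenize_line and forward now pad with id_fill only, and process_tokens places the end_of_text token (id_eot) immediately after the last text token. Previously, a prompt of exactly 75 text tokens made tokens_to_add zero, so tokenize_line still appended id_eot and the chunk overran its 75-token target by one (assuming get_target_prompt_token_count(75) == 75; a sketch of this off-by-one follows the first hunk below).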
@@ -228,7 +228,7 @@ def tokenize_line(self, line, used_custom_terms, hijack_comments):
         prompt_target_length = get_target_prompt_token_count(token_count)
         tokens_to_add = prompt_target_length - len(remade_tokens)
 
-        remade_tokens = remade_tokens + [self.id_eot] + [self.id_fill] * (tokens_to_add - 1)
+        remade_tokens = remade_tokens + [self.id_fill] * tokens_to_add
         multipliers = multipliers + [1.0] * tokens_to_add
 
         return remade_tokens, fixes, multipliers, token_count
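
A minimal sketch of the edge case this hunk fixes, assuming get_target_prompt_token_count(75) returns 75 (the ids below are illustrative placeholders, not the real tokenizer values):

# Old vs. new padding for a prompt of exactly 75 text tokens.
id_eot, id_fill = 49407, 0             # placeholder end-of-text and padding ids

remade_tokens = list(range(75))        # stand-in for 75 real text tokens
tokens_to_add = 75 - len(remade_tokens)   # 0 in this edge case

# old: id_eot is always appended, and [id_fill] * -1 is just [],
# so the chunk ends up 76 tokens long, one past the 75-token target
old = remade_tokens + [id_eot] + [id_fill] * (tokens_to_add - 1)
assert len(old) == 76

# new: pad with id_fill only; process_tokens places id_eot later
new = remade_tokens + [id_fill] * tokens_to_add
assert len(new) == 75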
@@ -306,18 +306,20 @@ def process_text_old(self, text):
                 i += embedding_length_in_tokens
 
             if len(remade_tokens) > maxlen - 2:
-                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                if hasattr(self.wrapped, "tokenizer"):
+                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                else:
+                    vocab = {v: k for k, v in self.tokenizer.encoder.items()}
                 ovf = remade_tokens[maxlen - 2:]
                 overflowing_words = [vocab.get(int(x), "") for x in ovf]
-                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated")
 
             token_count = len(remade_tokens)
             remade_tokens = [self.id_sot] + remade_tokens[:maxlen - 2] + [self.id_eot]
             remade_tokens += [self.id_fill] * (maxlen - len(remade_tokens))
             cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
 
-            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+            multipliers = [1.0] + multipliers[:maxlen - 2] + [1.0]
             multipliers += [1.0] * (maxlen - len(multipliers))
 
             remade_batch_tokens.append(remade_tokens)
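
The hasattr branch makes the reverse vocab work for wrappers that do not expose a Hugging Face tokenizer; the else path assumes an open_clip-style tokenizer whose token-to-id dict lives in .encoder. The overflow warning also drops the decoded text, since convert_tokens_to_string exists only on the Hugging Face side. A standalone sketch of the lookup (clip here is a hypothetical wrapper object):

# Build an id -> token map for either tokenizer flavour (sketch).
def build_reverse_vocab(clip):
    if hasattr(clip.wrapped, "tokenizer"):
        # Hugging Face CLIPTokenizer: get_vocab() returns token -> id
        return {v: k for k, v in clip.wrapped.tokenizer.get_vocab().items()}
    # assumed open_clip SimpleTokenizer: token -> id dict in .encoder
    return {v: k for k, v in clip.tokenizer.encoder.items()}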
@@ -362,7 +364,7 @@ def forward(self, text):
                     tokens.append(remade_batch_tokens[j][:75])
                     multipliers.append(batch_multipliers[j][:75])
                 else:
-                    tokens.append([self.id_eot] * 75)
+                    tokens.append([self.id_fill] * 75)
                     multipliers.append([1.0] * 75)
 
             z1 = self.process_tokens(tokens, multipliers)
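
Padding an empty batch entry with id_fill instead of id_eot keeps forward consistent with the new convention: process_tokens treats the first id_fill it sees as the end of the text and writes id_eot there, so an all-fill chunk comes out as the canonical empty prompt (start token, end token, padding).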
@@ -379,6 +381,16 @@ def process_tokens(self, remade_batch_tokens, batch_multipliers):
             remade_batch_tokens = [[self.id_sot] + x[:75] + [self.id_fill] for x in remade_batch_tokens]
             batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
 
+            # put the end_of_text token right after the text tokens rather than at the very end
+            for j, tokens in enumerate(remade_batch_tokens):
+                end_index = 76
+                for i, token in enumerate(tokens):
+                    if token == self.id_fill:
+                        end_index = i
+                        break
+                tokens[end_index] = self.id_eot
+                batch_multipliers[j][end_index] = 1.0
+
         if hasattr(self.wrapped, "transformer"):
             tokens = torch.asarray(remade_batch_tokens).to(device)
             outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
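
To see where the new loop lands id_eot, a small sketch with illustrative ids (assuming no real text token shares the padding id):

# Where the placement loop puts id_eot (sketch; ids are placeholders).
id_sot, id_eot, id_fill = 49406, 49407, 0

def place_eot(tokens):
    end_index = len(tokens) - 1
    for i, token in enumerate(tokens):
        if token == id_fill:       # first padding slot marks the end of text
            end_index = i
            break
    tokens[end_index] = id_eot
    return tokens

short = [id_sot] + [320, 1125, 539] + [id_fill] * 73    # 3-token prompt
assert place_eot(short).index(id_eot) == 4              # right after the text

full = [id_sot] + list(range(1, 76)) + [id_fill]        # exactly 75 text tokens
assert place_eot(full).index(id_eot) == 76              # final slot, nothing dropped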