Model: Attempt to recreate generator on a fatal error

If a job causes the generator to error, tabby stops working until it is
relaunched. It's better to build in some redundancy and recreate the
generator whenever it fails.

May replace this with an exit signal on a fatal error instead, but that's
undecided for now.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Jul 15, 2024
1 parent 6019c93 commit 9dae461
Showing 1 changed file with 40 additions and 9 deletions.
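
Before the diff itself, here is the idea as a standalone sketch: wrap generator
access so that a fatal error schedules a rebuild of the inner generator in the
background while the failing request still re-raises its exception. This is a
minimal illustration of the pattern only, not tabby's actual code; the names
ResilientGenerator and make_generator are invented for this example.

    import asyncio


    def make_generator():
        """Stand-in factory; a real backend would construct its generator here."""
        state = {"calls": 0}

        async def generate(prompt: str) -> str:
            state["calls"] += 1
            if state["calls"] == 2:  # Simulate a fatal internal error
                raise RuntimeError("generator state corrupted")
            return f"echo: {prompt}"

        return generate


    class ResilientGenerator:
        """Wraps a generator and rebuilds it after a fatal error."""

        def __init__(self, factory):
            self.factory = factory
            self.inner = factory()

        async def recreate(self):
            # Swap in a fresh generator; the broken one is discarded
            self.inner = self.factory()

        async def generate(self, prompt: str) -> str:
            try:
                return await self.inner(prompt)
            except Exception:
                # Fire-and-forget recreation; re-raise so the caller still
                # sees that this particular request failed
                asyncio.ensure_future(self.recreate())
                raise


    async def main():
        gen = ResilientGenerator(make_generator)
        print(await gen.generate("one"))  # Works
        try:
            await gen.generate("two")  # Fails and triggers recreation
        except RuntimeError as ex:
            print(f"request failed: {ex}")
        await asyncio.sleep(0)  # Let the background recreation task run
        print(await gen.generate("three"))  # Works again on the fresh generator


    asyncio.run(main())
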
49 changes: 40 additions & 9 deletions backends/exllamav2/model.py
@@ -488,15 +488,7 @@ async def load_gen(self, progress_callback=None, **kwargs):
             yield value

         # Create async generator
-        self.generator = ExLlamaV2DynamicGeneratorAsync(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=self.max_batch_size,
-            paged=self.paged,
-        )
+        await self.create_generator()

         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
@@ -645,6 +637,34 @@ def progress(loaded_modules: int, total_modules: int)
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

+    async def create_generator(self):
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.model_loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = ExLlamaV2DynamicGeneratorAsync(
+                model=self.model,
+                cache=self.cache,
+                draft_model=self.draft_model,
+                draft_cache=self.draft_cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+                paged=self.paged,
+            )
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.model_loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()

     def get_loras(self):
         """Convenience function to get all loras."""

@@ -1223,3 +1243,14 @@ async def generate_gen(
                     break
         except asyncio.CancelledError:
             await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
+            )
+            asyncio.ensure_future(self.create_generator())
+
+            raise ex

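Two details of the new error path are worth noting. create_generator only
touches the load lock when a model is already loaded; during the initial load,
the load function manages the lock itself, which is what the comment in the
finally block refers to. And the recreation is scheduled with
asyncio.ensure_future rather than awaited, so the request that hit the fatal
error still re-raises immediately while the replacement generator is built in
the background.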