From 3e5a6abffc278ad9e4e611c7149457f0a1fa816c Mon Sep 17 00:00:00 2001
From: Alex Al-Saffar
Date: Wed, 17 Apr 2024 14:07:10 +0930
Subject: [PATCH] made training call more robust. Same needs to be done for all other calls of train method

---
 src/vanna/base/base.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 54 insertions(+), 2 deletions(-)

diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py
index 5360f989..4aae648b 100644
--- a/src/vanna/base/base.py
+++ b/src/vanna/base/base.py
@@ -54,7 +54,7 @@
 import sqlite3
 import traceback
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Callable, Optional, TypeVar
 from urllib.parse import urlparse
 
 import pandas as pd
@@ -68,6 +68,49 @@
 from ..utils import validate_config_path
 
 
+# typing.ParamSpec only exists on Python >= 3.10, so the decorated callable's
+# parameters are typed loosely as Callable[..., T].
+T = TypeVar('T')
+
+
+class RepeatUntilNoException:  # alternatively, use: https://github.com/jd/tenacity
+    """Decorator that re-invokes the wrapped callable until it stops raising.
+
+    At most ``retry`` attempts are made. After the i-th failed attempt (1-based)
+    the wrapper sleeps ``sleep * i**2`` seconds before the next one; nothing is
+    slept after the final attempt. If every attempt fails, a RuntimeError
+    chained to the last exception is raised.
+    """
+
+    def __init__(self, retry: int = 3, sleep: float = 1.0, timeout: Optional[float] = None):
+        self.retry = retry
+        self.sleep = sleep
+        # Stored for API compatibility only; the retry loop does not consult it.
+        self.timeout = timeout
+
+    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
+        from functools import wraps
+        import time
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> T:
+            last_ex = None
+            for idx in range(self.retry):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as ex:
+                    last_ex = ex  # `ex` is unbound once the except block ends
+                    remaining = self.retry - idx - 1
+                    if remaining == 0:
+                        break  # final attempt: nothing left to wait for
+                    sleep_time = self.sleep * (idx + 1) ** 2
+                    print(f"""💥 Robust call of `{func}` failed with ```{ex}```.\nretrying {remaining} more times after sleeping for {sleep_time} seconds.""")
+                    time.sleep(sleep_time)
+            raise RuntimeError(f"💥 Robust call failed after {self.retry} retries.") from last_ex
+
+        return wrapper
+
+
 class VannaBase(ABC):
     def __init__(self, config=None):
         self.config = config
@@ -1360,7 +1403,16 @@ def train(
                 if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                     self.add_ddl(item.item_value)
                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
-                    self.add_documentation(item.item_value)
+                    # Retried because training on very large plans (thousands of
+                    # items or more) otherwise fails with:
+                    # HTTPSConnectionPool(host='ask.vanna.ai', port=443): Max retries exceeded with url: /rpc
+                    # (Caused by SSLError(SSLEOFError(8, '[SSL: UNEXPECTED_EOF_WHILE_READING]
+                    # EOF occurred in violation of protocol (_ssl.c:1006)')))
+                    @RepeatUntilNoException(retry=5, sleep=5.0)
+                    def train_robust():
+                        self.add_documentation(item.item_value)
+
+                    train_robust()
                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
                     self.add_question_sql(question=item.item_name, sql=item.item_value)
 