diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index e426f59f..785c42eb 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -45,7 +45,6 @@ def __init__(self, config: Configuration): def insert(self, records: dict, options: InsertOptions = InsertOptions()): interface = InterfaceName.INSERT.value log_info(InfoMessages.INSERT_TRIGGERED.value, interface=interface) - self._checkConfig(interface) jsonBody = getInsertRequestBody(records, options) @@ -56,17 +55,27 @@ def insert(self, records: dict, options: InsertOptions = InsertOptions()): "Authorization": "Bearer " + self.storedToken, "sky-metadata": json.dumps(getMetrics()) } - - response = requests.post(requestURL, data=jsonBody, headers=headers) - processedResponse = processResponse(response) - result, partial = convertResponse(records, processedResponse, options) - if partial: - log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface) - elif 'records' not in result: - log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface) - else: - log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) - return result + max_retries = 3 + # Use for-loop for retry logic, avoid code repetition + for attempt in range(max_retries+1): + try: + # If jsonBody is a dict, use json=, else use data= + response = requests.post(requestURL, data=jsonBody, headers=headers) + processedResponse = processResponse(response) + result, partial = convertResponse(records, processedResponse, options) + if partial: + log_error(SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, interface) + raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS, SkyflowErrorMessages.BATCH_INSERT_PARTIAL_SUCCESS.value, result, interface=interface) + if 'records' not in result: + log_error(SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, interface) + raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, SkyflowErrorMessages.BATCH_INSERT_FAILURE.value, result, interface=interface) + log_info(InfoMessages.INSERT_DATA_SUCCESS.value, interface) + return result + except requests.exceptions.ConnectionError as err: + if attempt < max_retries: + continue + else: + raise SkyflowError(SkyflowErrorCodes.SERVER_ERROR, f"Error occurred: {err}", interface=interface) def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptions()): interface = InterfaceName.DETOKENIZE.value