Skip to content

Commit 354d6b5

Browse files
authored
[AINode] Quick bug fix patch (#15841)
1 parent ccba945 commit 354d6b5

File tree

6 files changed

+53
-19
lines changed

6 files changed

+53
-19
lines changed

iotdb-core/ainode/ainode/core/manager/model_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
InvalidUriError,
2727
)
2828
from ainode.core.log import Logger
29+
from ainode.core.model.model_info import BuiltInModelType, ModelInfo, ModelStates
2930
from ainode.core.model.model_storage import ModelStorage
3031
from ainode.core.util.status import get_status
3132
from ainode.thrift.ainode.ttypes import (
@@ -140,3 +141,15 @@ def get_ckpt_path(self, model_id: str) -> str:
140141

141142
def show_models(self) -> TShowModelsResp:
142143
return self.model_storage.show_models()
144+
145+
def register_built_in_model(self, model_info: ModelInfo):
146+
self.model_storage.register_built_in_model(model_info)
147+
148+
def update_model_state(self, model_id: str, state: ModelStates):
149+
self.model_storage.update_model_state(model_id, state)
150+
151+
def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
152+
"""
153+
Get the type of the model with the given model_id.
154+
"""
155+
return self.model_storage.get_built_in_model_type(model_id.lower())

iotdb-core/ainode/ainode/core/model/model_storage.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ def delete_model(self, model_id: str) -> None:
207207
with self._lock_pool.get_lock(model_id).write_lock():
208208
if os.path.exists(storage_path):
209209
shutil.rmtree(storage_path)
210+
if model_id in self._model_info_map:
211+
del self._model_info_map[model_id]
212+
logger.info(f"Model {model_id} deleted successfully.")
210213

211214
def _is_built_in(self, model_id: str) -> bool:
212215
"""
@@ -218,9 +221,9 @@ def _is_built_in(self, model_id: str) -> bool:
218221
Returns:
219222
bool: True if the model is built-in, False otherwise.
220223
"""
221-
return (
222-
model_id in self._model_info_map
223-
and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
224+
return model_id in self._model_info_map and (
225+
self._model_info_map[model_id].category == ModelCategory.BUILT_IN
226+
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
224227
)
225228

226229
def load_model(self, model_id: str, acceleration: bool) -> Callable:
@@ -291,3 +294,32 @@ def show_models(self) -> TShowModelsResp:
291294
for model_id, model_info in self._model_info_map.items()
292295
),
293296
)
297+
298+
def register_built_in_model(self, model_info: ModelInfo):
299+
with self._lock_pool.get_lock(model_info.model_id).write_lock():
300+
self._model_info_map[model_info.model_id] = model_info
301+
302+
def update_model_state(self, model_id: str, state: ModelStates):
303+
with self._lock_pool.get_lock(model_id).write_lock():
304+
if model_id in self._model_info_map:
305+
self._model_info_map[model_id].state = state
306+
else:
307+
raise ValueError(f"Model {model_id} does not exist.")
308+
309+
def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
310+
"""
311+
Get the type of the model with the given model_id.
312+
313+
Args:
314+
model_id (str): The ID of the model.
315+
316+
Returns:
317+
str: The type of the model.
318+
"""
319+
with self._lock_pool.get_lock(model_id).read_lock():
320+
if model_id in self._model_info_map:
321+
return get_built_in_model_type(
322+
self._model_info_map[model_id].model_type
323+
)
324+
else:
325+
raise ValueError(f"Model {model_id} does not exist.")

iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,8 @@ public TSStatus createModel(CreateModelPlan plan) {
121121
try {
122122
acquireModelTableWriteLock();
123123
String modelName = plan.getModelName();
124-
if (modelTable.containsModel(modelName)) {
125-
return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
126-
.setMessage(String.format("model [%s] has already been created.", modelName));
127-
} else {
128-
modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
129-
return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
130-
}
124+
modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
125+
return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
131126
} catch (Exception e) {
132127
final String errorMessage =
133128
String.format(

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ private TWindowParams getWindowParams() {
254254
}
255255

256256
private TsBlock preProcess(TsBlock inputTsBlock) {
257-
boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
257+
// boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
258+
boolean notBuiltIn = false;
258259
if (windowType == null || windowType == InferenceWindowType.HEAD) {
259260
if (notBuiltIn
260261
&& totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,13 +481,6 @@ private void checkWindowSize(long windowSize, ModelInformation modelInformation)
481481
if (modelInformation.isBuiltIn()) {
482482
return;
483483
}
484-
485-
if (modelInformation.getInputShape()[0] != windowSize) {
486-
throw new SemanticException(
487-
String.format(
488-
"Window output %d is not equal to input size of model %d",
489-
windowSize, modelInformation.getInputShape()[0]));
490-
}
491484
}
492485

493486
private ISchemaTree analyzeSchema(

iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public ModelInformation(
8484
}
8585

8686
public ModelInformation(String modelName, ModelStatus status) {
87-
this.modelType = ModelType.USER_DEFINED;
87+
this.modelType = ModelType.BUILT_IN_FORECAST;
8888
this.modelName = modelName;
8989
this.inputShape = new int[0];
9090
this.outputShape = new int[0];

0 commit comments

Comments
 (0)