Skip to content

Commit

Permalink
ML model improvement : Adding "SHOW MODELS and DESCRIBE MODEL"
Browse files Browse the repository at this point in the history
Author:    rajagurunath <gurunathrajagopal@gmail.com>
Date:      Mon May 24 02:37:40 2021 +0530
  • Loading branch information
Gurunath LankupalliVenugopal authored and rajagurunath committed May 24, 2021
1 parent 3e106e3 commit 4a9ee68
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 7 deletions.
2 changes: 2 additions & 0 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def __init__(self):
RelConverter.add_plugin_class(custom.ShowColumnsPlugin, replace=False)
RelConverter.add_plugin_class(custom.ShowSchemasPlugin, replace=False)
RelConverter.add_plugin_class(custom.ShowTablesPlugin, replace=False)
RelConverter.add_plugin_class(custom.ShowModelsPlugin, replace=False)
RelConverter.add_plugin_class(custom.ShowModelParamsPlugin, replace=False)

RexConverter.add_plugin_class(core.RexCallPlugin, replace=False)
RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False)
Expand Down
4 changes: 4 additions & 0 deletions dask_sql/physical/rel/custom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .create_model import CreateModelPlugin
from .create_table import CreateTablePlugin
from .create_table_as import CreateTableAsPlugin
from .describe_model import ShowModelParamsPlugin
from .drop_model import DropModelPlugin
from .drop_table import DropTablePlugin
from .predict import PredictModelPlugin
from .schemas import ShowSchemasPlugin
from .show_models import ShowModelsPlugin
from .tables import ShowTablesPlugin

__all__ = [
Expand All @@ -20,4 +22,6 @@
ShowColumnsPlugin,
ShowSchemasPlugin,
ShowTablesPlugin,
ShowModelsPlugin,
ShowModelParamsPlugin,
]
33 changes: 33 additions & 0 deletions dask_sql/physical/rel/custom/describe_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dask.dataframe as dd
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.utils import get_model_from_compound_identifier


class ShowModelParamsPlugin(BaseRelPlugin):
"""
Show all Params used to train a given model and training columns.
The SQL is:
DESCRIBE MODEL <model_name>
The result is also a table, although it is created on the fly.
"""

class_name = "com.dask.sql.parser.SqlShowModelParams"

def convert(
self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
) -> DataContainer:
components = list(map(str, sql.getTable().names))
model, training_columns = get_model_from_compound_identifier(
context, components
)
model_params = model.get_params()
model_params["training_columns"] = training_columns.tolist()
df = pd.DataFrame.from_dict(model_params, orient="index", columns=["Params"])
cc = ColumnContainer(df.columns)
dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
return dc
28 changes: 28 additions & 0 deletions dask_sql/physical/rel/custom/show_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import dask.dataframe as dd
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin


class ShowModelsPlugin(BaseRelPlugin):
"""
Show all MODELS currently registered/trained.
The SQL is:
SHOW MODELS
The result is also a table, although it is created on the fly.
"""

class_name = "com.dask.sql.parser.SqlShowModels"

def convert(
self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
) -> DataContainer:

df = pd.DataFrame({"Models": list(context.models.keys())})

cc = ColumnContainer(df.columns)
dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
return dc
18 changes: 17 additions & 1 deletion dask_sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from unittest.mock import patch
from uuid import uuid4

Expand Down Expand Up @@ -218,6 +218,22 @@ def get_table_from_compound_identifier(
raise AttributeError(f"Table {tableName} is not defined.")


def get_model_from_compound_identifier(
context: "dask_sql.Context", components: List[str]
) -> Tuple:
"""
Helper function to return the correct model
from the trained models in the context
with the given name
"""
modelName = components[-1]

try:
return context.models[modelName]
except KeyError:
raise AttributeError(f"Model {modelName} is not defined.")


def convert_sql_kwargs(
sql_kwargs: "java.util.HashMap[org.apache.calcite.sql.SqlNode, org.apache.calcite.sql.SqlNode]",
) -> Dict[str, Any]:
Expand Down
43 changes: 40 additions & 3 deletions notebooks/Feature Overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,36 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"SHOW MODELS"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"DESCRIBE MODEL my_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"DESCRIBE TABLE training_data"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -571,13 +601,20 @@
"\"\"\").compute() \n",
"t.set_index([\"target\", \"species\"]).unstack(\"species\").number.plot.bar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "dask-sql",
"language": "python",
"name": "python3"
"name": "dask-sql"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -589,7 +626,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
"version": "3.9.4"
},
"toc-autonumbering": false,
"toc-showcode": false,
Expand Down
4 changes: 4 additions & 0 deletions planner/src/main/codegen/config.fmpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ data: {
"com.dask.sql.parser.SqlShowColumns",
"com.dask.sql.parser.SqlShowSchemas",
"com.dask.sql.parser.SqlShowTables"
"com.dask.sql.parser.SqlShowModels"
]

# List of keywords.
Expand All @@ -49,6 +50,7 @@ data: {
"SCHEMAS"
"STATISTICS"
"TABLES"
"MODELS"
]

# The keywords can only be used in a specific context,
Expand All @@ -73,6 +75,8 @@ data: {
"SqlShowSchemas()"
"SqlShowTables()"
"SqlPredictModel()"
"SqlShowModels()"
"SqlDescribeModel()"
]

createStatementParserMethods: [
Expand Down
30 changes: 27 additions & 3 deletions planner/src/main/codegen/includes/show.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,33 @@ SqlNode SqlShowColumns() :
}
}

// DESCRIBE "table"
// DESCRIBE TABLE "table"
SqlNode SqlDescribeTable() :
{
final Span s;
final SqlIdentifier schemaName;
final SqlIdentifier tableName;
}
{
<DESCRIBE> { s = span(); }
<DESCRIBE> { s = span(); } <TABLE>
tableName = CompoundTableIdentifier()
{
return new SqlShowColumns(s.end(this), tableName);
}git
}
// DESCRIBE MODEL "model_name"
SqlNode SqlDescribeModel() :
{
final Span s;
final SqlIdentifier modelName;
}
{
<DESCRIBE> { s = span(); } <MODEL>
modelName = CompoundTableIdentifier()
{
return new SqlShowModelParams(s.end(this), modelName);
}
}

// ANALYZE TABLE table_identifier COMPUTE STATISTICS [ FOR COLUMNS col [ , ... ] | FOR ALL COLUMNS ]
SqlNode SqlAnalyzeTable() :
{
Expand All @@ -86,3 +98,15 @@ SqlNode SqlAnalyzeTable() :
return new SqlAnalyzeTable(s.end(this), tableName, columnList);
}
}

// SHOW MODELS
SqlNode SqlShowModels() :
{
final Span s;
}
{
<SHOW> { s = span(); } <MODELS>
{
return new SqlShowModels(s.end(this));
}
}
22 changes: 22 additions & 0 deletions planner/src/main/java/com/dask/sql/parser/SqlShowModelParams.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.dask.sql.parser;

import org.apache.calcite.sql.SqlDescribeTable;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.parser.SqlParserPos;

public class SqlShowModelParams extends SqlDescribeTable {
SqlIdentifier modelName;

public SqlShowModelParams(SqlParserPos pos, SqlIdentifier modelName) {
super(pos, modelName,null);
this.modelName = modelName;
}

@Override
public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
writer.keyword("DESCRIBE");
writer.keyword("MODEL");
this.modelName.unparse(writer, leftPrec, rightPrec);
}
}
37 changes: 37 additions & 0 deletions planner/src/main/java/com/dask/sql/parser/SqlShowModels.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.dask.sql.parser;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.SqlNode;

import org.apache.calcite.sql.parser.SqlParserPos;

public class SqlShowModels extends SqlCall {
public SqlIdentifier catalog;
public SqlNode like;


public SqlShowModels(SqlParserPos pos) {
super(pos);
}
public SqlOperator getOperator() {
throw new UnsupportedOperationException();
}

public List<SqlNode> getOperandList() {
ArrayList<SqlNode> operandList = new ArrayList<SqlNode>();
return operandList;
}

@Override
public void unparse(SqlWriter writer,int leftPrec, int rightPrec) {
writer.keyword("SHOW");
writer.keyword("MODELS");

}
}

0 comments on commit 4a9ee68

Please sign in to comment.