diff --git a/src/Expressions/ColumnExpression.php b/src/Expressions/ColumnExpression.php index 7e8f6a6..f0e8153 100644 --- a/src/Expressions/ColumnExpression.php +++ b/src/Expressions/ColumnExpression.php @@ -8,7 +8,7 @@ final class ColumnExpression extends Expression { private string $columnExpression; private string $columnName; - private ?string $tableName; + public ?string $tableName; private ?string $databaseName; private bool $allowFallthrough = false; @@ -90,10 +90,6 @@ public function isWellFormed(): bool { return true; } - public function tableName(): ?string { - return $this->tableName; - } - public function prefixColumnExpression(string $prefix): void { if (!Str\starts_with($this->columnExpression, $prefix)) { $this->columnExpression = $prefix.$this->columnExpression; diff --git a/src/Query/FromClause.php b/src/Query/FromClause.php index 2e1fec4..923e159 100644 --- a/src/Query/FromClause.php +++ b/src/Query/FromClause.php @@ -2,7 +2,7 @@ namespace Slack\SQLFake; -use namespace HH\Lib\{C, Vec}; +use namespace HH\Lib\{C, Dict, Keyset, Vec}; /** * Represents the entire FROM clause of a query, @@ -37,16 +37,21 @@ public function aliasRecentExpression(string $name): void { public function process( AsyncMysqlConnection $conn, string $sql, - ): (dataset, unique_index_refs, index_refs, vec) { + ): (dataset, unique_index_refs, index_refs, vec, dict) { $data = dict[]; $is_first_table = true; $unique_index_refs = dict[]; $index_refs = dict[]; $indexes = vec[]; + $columns = dict[]; foreach ($this->tables as $table) { $schema = null; + $new_unique_index_refs = dict[]; + $new_index_refs = dict[]; + $new_indexes = vec[]; + if (Shapes::keyExists($table, 'subquery')) { $res = $table['subquery']->evaluate(dict[], $conn); invariant($res is KeyedContainer<_, _>, 'evaluated result of SubqueryExpression must be dataset'); @@ -60,14 +65,48 @@ public function process( $name = $table['alias'] ?? $table_name; $schema = QueryContext::getSchema($database, $table_name); if ($schema === null && QueryContext::$strictSchemaMode) { - throw new SQLFakeRuntimeException("Table $table_name not found in schema and strict mode is enabled"); + throw new SQLFakeRuntimeException( + "Table $table_name not found in schema and strict mode is enabled", + ); } - list($res, $unique_index_refs, $index_refs) = + list($res, $new_unique_index_refs, $new_index_refs) = $conn->getServer()->getTableData($database, $table_name) ?: tuple(dict[], dict[], dict[]); + + if (C\count($this->tables) > 1) { + $new_unique_index_refs = Dict\map_keys($new_unique_index_refs, $k ==> $name.'.'.$k); + $new_index_refs = Dict\map_keys($new_index_refs, $k ==> $name.'.'.$k); + } + if ($schema is nonnull) { - $indexes = Vec\concat($indexes, $schema->indexes); + if (C\count($this->tables) > 1) { + $new_indexes = Vec\map( + $schema->indexes, + $index ==> new Index( + $name.'.'.$index->name, + 'INDEX', + Keyset\map($index->fields, $k ==> $name.'.'.$k), + ), + ); + } else { + $new_indexes = $schema->indexes; + } + + $new_columns = dict[]; + + foreach ($schema->fields as $field) { + if (C\count($this->tables) > 1) { + $new_columns[$name.'.'.$field->name] = $field; + } else { + $new_columns[$field->name] = $field; + } + } + + $columns = Dict\merge($columns, $new_columns); } + + $unique_index_refs = Dict\merge($unique_index_refs, $new_unique_index_refs); + $index_refs = Dict\merge($index_refs, $new_index_refs); } $new_dataset = dict[]; @@ -113,26 +152,30 @@ public function process( if ($data || !$is_first_table) { // do the join here. based on join type, pass in $data and $res to filter. and aliases - $data = JoinProcessor::process( + list($data, $unique_index_refs, $index_refs) = JoinProcessor::process( $conn, - $data, - $new_dataset, + tuple($data, $unique_index_refs, $index_refs), + tuple($new_dataset, $new_unique_index_refs, $new_index_refs), $name, $table['join_type'], $table['join_operator'] ?? null, $table['join_expression'] ?? null, $schema, + $indexes, + $new_indexes, ); } else { $data = $new_dataset; } + $indexes = Vec\concat($indexes, $new_indexes); + if ($is_first_table) { Metrics::trackQuery(QueryType::SELECT, $conn->getServer()->name, $name, $sql); $is_first_table = false; } } - return tuple($data, $unique_index_refs, $index_refs, $indexes); + return tuple($data, $unique_index_refs, $index_refs, $indexes, $columns); } } diff --git a/src/Query/JoinProcessor.php b/src/Query/JoinProcessor.php index 9c6a777..027a533 100644 --- a/src/Query/JoinProcessor.php +++ b/src/Query/JoinProcessor.php @@ -14,15 +14,16 @@ public static function process( AsyncMysqlConnection $conn, - dataset $left_dataset, - dataset $right_dataset, + table_data $left_dataset, + table_data $right_dataset, string $right_table_name, JoinType $join_type, ?JoinOperator $_ref_type, ?Expression $ref_clause, ?TableSchema $right_schema, - ): dataset { - + vec $left_indexes, + vec $right_indexes, + ): table_data { // MySQL supports JOIN (inner), LEFT OUTER JOIN, RIGHT OUTER JOIN, and implicitly CROSS JOIN (which uses commas), NATURAL // conditions can be specified with ON or with USING () // does not support FULL OUTER JOIN @@ -36,8 +37,8 @@ public static function process( // instead of evaluating the same expressions over and over again in nested loops, we can optimize this for a more efficient algorithm // this is somewhat experimental and different merge strategies could be applied in more situations in the future if ( - C\count($left_dataset) > 5 && - C\count($right_dataset) > 5 && + C\count($left_dataset[0]) > 5 && + C\count($right_dataset[0]) > 5 && $filter is BinaryOperatorExpression && $filter->left is ColumnExpression && $filter->right is ColumnExpression && @@ -53,19 +54,29 @@ public static function process( $_ref_type, $filter, $right_schema, + $left_indexes, + $right_indexes, ); } + $left_mappings = dict[]; + $right_mappings = dict[]; + switch ($join_type) { case JoinType::JOIN: case JoinType::STRAIGHT: // straight join is just a query planner optimization of INNER JOIN, // and it is actually what we are doing here anyway - foreach ($left_dataset as $row) { - foreach ($right_dataset as $r) { + foreach ($left_dataset[0] as $left_row_id => $row) { + foreach ($right_dataset[0] as $right_row_id => $r) { $candidate_row = Dict\merge($row, $r); if ((bool)$filter->evaluate($candidate_row, $conn)) { $out[] = $candidate_row; + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$right_row_id] ??= keyset[]; + $right_mappings[$right_row_id][] = $insert_id; } } } @@ -81,12 +92,17 @@ public static function process( } } - foreach ($left_dataset as $row) { + foreach ($left_dataset[0] as $left_row_id => $row) { $any_match = false; - foreach ($right_dataset as $r) { + foreach ($right_dataset[0] as $right_row_id => $r) { $candidate_row = Dict\merge($row, $r); if ((bool)$filter->evaluate($candidate_row, $conn)) { $out[] = $candidate_row; + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$right_row_id] ??= keyset[]; + $right_mappings[$right_row_id][] = $insert_id; $any_match = true; } } @@ -114,13 +130,18 @@ public static function process( } } - foreach ($right_dataset as $raw) { + foreach ($right_dataset[0] as $left_row_id => $raw) { $any_match = false; - foreach ($left_dataset as $row) { + foreach ($left_dataset[0] as $right_row_id => $row) { $candidate_row = Dict\merge($row, $raw); if ((bool)$filter->evaluate($candidate_row, $conn)) { $out[] = $candidate_row; $any_match = true; + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$right_row_id] ??= keyset[]; + $right_mappings[$right_row_id][] = $insert_id; } } @@ -131,30 +152,49 @@ public static function process( } break; case JoinType::CROSS: - foreach ($left_dataset as $row) { - foreach ($right_dataset as $r) { + foreach ($left_dataset[0] as $left_row_id => $row) { + foreach ($right_dataset[0] as $right_row_id => $r) { $out[] = Dict\merge($row, $r); + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$right_row_id] ??= keyset[]; + $right_mappings[$right_row_id][] = $insert_id; } } break; case JoinType::NATURAL: // unlike other join filters this one has to be built at runtime, using the list of columns that exists between the two tables // for each column in the target table, see if there is a matching column in the rest of the data set. if so, make a filter that they must be equal. - $filter = self::buildNaturalJoinFilter($left_dataset, $right_dataset); + $filter = self::buildNaturalJoinFilter($left_dataset[0], $right_dataset[0]); // now basically just do a regular join - foreach ($left_dataset as $row) { - foreach ($right_dataset as $r) { + foreach ($left_dataset[0] as $left_row_id => $row) { + foreach ($right_dataset[0] as $right_row_id => $r) { $candidate_row = Dict\merge($row, $r); if ((bool)$filter->evaluate($candidate_row, $conn)) { $out[] = $candidate_row; + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$right_row_id] ??= keyset[]; + $right_mappings[$right_row_id][] = $insert_id; } } } break; } - return dict($out); + $index_refs = self::getIndexRefsFromMappings( + $left_dataset, + $right_dataset, + $left_mappings, + $right_mappings, + $left_indexes, + $right_indexes, + ); + + return tuple(dict($out), dict[], $index_refs); } /** @@ -180,7 +220,9 @@ protected static function buildNaturalJoinFilter(dataset $left_dataset, dataset // MySQL actually doesn't throw if there's no matching columns, but I think we can take the liberty to assume it's not what you meant to do and throw here if ($filter === null) { - throw new SQLFakeParseException('NATURAL join keyword was used with tables that do not share any column names'); + throw new SQLFakeParseException( + 'NATURAL join keyword was used with tables that do not share any column names', + ); } return $filter; @@ -195,10 +237,12 @@ protected static function addJoinFilterExpression( string $right_column, ): BinaryOperatorExpression { - $left = - new ColumnExpression(shape('type' => TokenType::IDENTIFIER, 'value' => $left_column, 'raw' => $left_column)); - $right = - new ColumnExpression(shape('type' => TokenType::IDENTIFIER, 'value' => $right_column, 'raw' => $right_column)); + $left = new ColumnExpression( + shape('type' => TokenType::IDENTIFIER, 'value' => $left_column, 'raw' => $left_column), + ); + $right = new ColumnExpression( + shape('type' => TokenType::IDENTIFIER, 'value' => $right_column, 'raw' => $right_column), + ); // making a binary expression ensuring those two tokens are equal $expr = new BinaryOperatorExpression($left, /* $negated */ false, Operator::EQUALS, $right); @@ -229,17 +273,19 @@ private static function coerceToArrayKey(mixed $value): arraykey { */ private static function processHashJoin( AsyncMysqlConnection $conn, - dataset $left_dataset, - dataset $right_dataset, + table_data $left_dataset, + table_data $right_dataset, string $right_table_name, JoinType $join_type, ?JoinOperator $_ref_type, BinaryOperatorExpression $filter, ?TableSchema $right_schema, - ): dataset { + vec $left_indexes, + vec $right_indexes, + ): table_data { $left = $filter->left as ColumnExpression; $right = $filter->right as ColumnExpression; - if ($left->tableName() === $right_table_name) { + if ($left->tableName === $right_table_name) { // filter order may not match table order // if the left filter is for the right table, swap the filters list($left, $right) = vec[$right, $left]; @@ -249,21 +295,29 @@ private static function processHashJoin( // evaluate the column expression once per row in the right dataset first, building up a temporary table that groups all rows together for each value // multiple rows may have the same value. their ids in the original dataset are stored in a keyset $right_temp_table = dict[]; - foreach ($right_dataset as $k => $r) { + foreach ($right_dataset[0] as $k => $r) { $value = $right->evaluate($r, $conn); $value = self::coerceToArrayKey($value); $right_temp_table[$value] ??= keyset[]; $right_temp_table[$value][] = $k; } + $left_mappings = dict[]; + $right_mappings = dict[]; + switch ($join_type) { case JoinType::JOIN: case JoinType::STRAIGHT: - foreach ($left_dataset as $row) { + foreach ($left_dataset[0] as $left_row_id => $row) { $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); // find all rows matching this value in the right temp table and get their full rows foreach ($right_temp_table[$value] ?? keyset[] as $k) { - $out[] = Dict\merge($row, $right_dataset[$k]); + $out[] = Dict\merge($row, $right_dataset[0][$k]); + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$k] ??= keyset[]; + $right_mappings[$k][] = $insert_id; } } break; @@ -278,12 +332,17 @@ private static function processHashJoin( } } - foreach ($left_dataset as $row) { + foreach ($left_dataset[0] as $left_row_id => $row) { $any_match = false; $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); - foreach ($right_dataset as $r) { + foreach ($right_dataset[0] as $r) { foreach ($right_temp_table[$value] ?? keyset[] as $k) { - $out[] = Dict\merge($row, $right_dataset[$k]); + $out[] = Dict\merge($row, $right_dataset[0][$k]); + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; + $right_mappings[$k] ??= keyset[]; + $right_mappings[$k][] = $insert_id; $any_match = true; } } @@ -297,12 +356,97 @@ private static function processHashJoin( } else { $out[] = $row; } + + $insert_id = C\count($out) - 1; + $left_mappings[$left_row_id] ??= keyset[]; + $left_mappings[$left_row_id][] = $insert_id; } } break; default: invariant_violation('unreachable'); } - return dict($out); + + $index_refs = self::getIndexRefsFromMappings( + $left_dataset, + $right_dataset, + $left_mappings, + $right_mappings, + $left_indexes, + $right_indexes, + ); + + return tuple(dict($out), dict[], $index_refs); + } + + private static function getIndexRefsFromMappings( + table_data $left_dataset, + table_data $right_dataset, + dict> $left_mappings, + dict> $right_mappings, + vec $left_indexes, + vec $right_indexes, + ): index_refs { + $index_refs = dict[]; + + foreach ($left_mappings as $left_row_id => $new_pks) { + foreach ($left_indexes as $left_index) { + if (Str\ends_with($left_index->name, '.PRIMARY')) { + $index_refs[$left_index->name] ??= dict[]; + $index_refs[$left_index->name][$left_row_id] = $new_pks; + } + } + } + + foreach ($right_mappings as $right_row_id => $new_pks) { + foreach ($right_indexes as $right_index) { + if (Str\ends_with($right_index->name, '.PRIMARY')) { + $index_refs[$right_index->name] ??= dict[]; + $index_refs[$right_index->name][$right_row_id] = $new_pks; + } + } + } + + foreach ($left_dataset[1] as $left_index_name => $left_index_refs) { + foreach ($left_index_refs as $left_index_key => $left_index_pk) { + if (isset($left_mappings[$left_index_pk])) { + $index_refs[$left_index_name] ??= dict[]; + $index_refs[$left_index_name][$left_index_key] = $left_mappings[$left_index_pk]; + } + } + } + + foreach ($left_dataset[2] as $left_index_name => $left_index_refs) { + foreach ($left_index_refs as $left_index_key => $left_index_pks) { + foreach ($left_index_pks as $left_index_pk) { + if (isset($left_mappings[$left_index_pk])) { + $index_refs[$left_index_name] ??= dict[]; + $index_refs[$left_index_name][$left_index_key] = $left_mappings[$left_index_pk]; + } + } + } + } + + foreach ($right_dataset[1] as $right_index_name => $right_index_refs) { + foreach ($right_index_refs as $right_index_key => $right_index_pk) { + if (isset($right_mappings[$right_index_pk])) { + $index_refs[$right_index_name] ??= dict[]; + $index_refs[$right_index_name][$right_index_key] = $right_mappings[$right_index_pk]; + } + } + } + + foreach ($right_dataset[2] as $right_index_name => $right_index_refs) { + foreach ($right_index_refs as $right_index_key => $right_index_pks) { + foreach ($right_index_pks as $right_index_pk) { + if (isset($right_mappings[$right_index_pk])) { + $index_refs[$right_index_name] ??= dict[]; + $index_refs[$right_index_name][$right_index_key] = $right_mappings[$right_index_pk]; + } + } + } + } + + return $index_refs; } } diff --git a/src/Query/Query.php b/src/Query/Query.php index 502cdb0..c5bd1cc 100644 --- a/src/Query/Query.php +++ b/src/Query/Query.php @@ -48,7 +48,7 @@ protected function applyOrderBy(AsyncMysqlConnection $_conn, dataset $data): dat // allow all column expressions to fall through to the full row foreach ($order_by as $rule) { $expr = $rule['expression']; - if ($expr is ColumnExpression && $expr->tableName() === null) { + if ($expr is ColumnExpression && $expr->tableName === null) { $expr->allowFallthrough(); } } diff --git a/src/Query/SelectQuery.php b/src/Query/SelectQuery.php index a0e48d3..3e1aba2 100644 --- a/src/Query/SelectQuery.php +++ b/src/Query/SelectQuery.php @@ -76,12 +76,14 @@ public function execute(AsyncMysqlConnection $conn, ?row $_ = null): dataset { * The FROM clause of the query gets processed first, retrieving data from tables, executing subqueries, and handling joins * This is also where we build up the $columns list which is commonly used throughout the entire library to map column references to index_refs in this dataset */ - protected function applyFrom(AsyncMysqlConnection $conn): (dataset, unique_index_refs, index_refs, vec) { + protected function applyFrom( + AsyncMysqlConnection $conn, + ): (dataset, unique_index_refs, index_refs, vec, dict) { $from = $this->fromClause; if ($from === null) { // we put one empty row when there is no FROM so that queries like "SELECT 1" will return a row - return tuple(dict[0 => dict[]], dict[], dict[], vec[]); + return tuple(dict[0 => dict[]], dict[], dict[], vec[], dict[]); } return $from->process($conn, $this->sql); @@ -173,9 +175,9 @@ protected function applySelect(AsyncMysqlConnection $conn, dataset $data): datas } foreach ($row as $col => $val) { $parts = Str\split((string)$col, '.'); - if ($expr->tableName() is nonnull) { + if ($expr->tableName is nonnull) { list($col_table_name, $col_name) = $parts; - if ($col_table_name == $expr->tableName()) { + if ($col_table_name == $expr->tableName) { if (!C\contains_key($formatted_row, $col)) { $formatted_row[$col_name] = $val; }