Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
GetSchema: GetSchema: Batch/parallel access to underlying mysqld for …
Browse files Browse the repository at this point in the history
…lower latency

Signed-off-by: Toliver Jue <toliver@planetscale.com>
  • Loading branch information
Toliver Jue committed Aug 17, 2020
1 parent e16a6c5 commit fce4cfd
Showing 1 changed file with 173 additions and 77 deletions.
250 changes: 173 additions & 77 deletions go/vt/mysqlctl/schema.go
Expand Up @@ -17,10 +17,15 @@ limitations under the License.
package mysqlctl

import (
"errors"
"fmt"
"regexp"
"strings"
"sync"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/concurrency"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"golang.org/x/net/context"
Expand All @@ -46,6 +51,26 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error {
return mysqld.executeMysqlScript(params, strings.NewReader(sql))
}

// encodeTableName wraps tableName in a sqltypes VARCHAR value and returns the
// text produced by that value's EncodeSQL method.
// NOTE(review): presumably this is a quoted, escaped SQL string literal that is
// safe to splice into a statement; confirm against sqltypes.Value.EncodeSQL.
func encodeTableName(tableName string) string {
	encoded := new(strings.Builder)
	value := sqltypes.NewVarChar(tableName)
	value.EncodeSQL(encoded)
	return encoded.String()
}

// tableListSql renders tables as a parenthesized, comma-separated list that
// can follow an SQL IN keyword, e.g. "('t1', 't2', ...)".
//
// Every name is passed through encodeTableName before being written, and the
// entries are separated by ", ". An empty tables slice is rejected with an
// error instead of producing "()".
func tableListSql(tables []string) (string, error) {
	if len(tables) == 0 {
		return "", errors.New("no tables for tableListSql")
	}

	var list strings.Builder
	list.WriteString("(")
	for idx := range tables {
		if idx > 0 {
			list.WriteString(", ")
		}
		list.WriteString(encodeTableName(tables[idx]))
	}
	list.WriteString(")")

	return list.String(), nil
}

// GetSchema returns the schema for database for tables listed in
// tables. If tables is empty, return the schema for all tables.
func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) {
Expand All @@ -62,6 +87,76 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl
}
sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, "{{.DatabaseName}}", 1)

tds, err := mysqld.collectBasicTableData(ctx, dbName, tables, excludeTables, includeViews)
if err != nil {
return nil, err
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var wg sync.WaitGroup
allErrors := &concurrency.AllErrorRecorder{}

// Get per-table schema concurrently.
tableNames := make([]string, 0, len(tds))
for _, td := range tds {
tableNames = append(tableNames, td.Name)

wg.Add(1)
go func(td *tabletmanagerdatapb.TableDefinition) {
defer wg.Done()

fields, columns, schema, err := mysqld.collectSchema(ctx, dbName, td.Name, td.Type)
if err != nil {
allErrors.RecordError(err)
cancel()
return
}

td.Fields = fields
td.Columns = columns
td.Schema = schema
}(td)
}

// Get primary columns concurrently.
colMap := map[string][]string{}
if len(tableNames) > 0 {
wg.Add(1)
go func() {
defer wg.Done()

log.Infof("mysqld GetSchema: GetPrimaryKeyColumns")
var err error
colMap, err = mysqld.getPrimaryKeyColumns(ctx, dbName, tableNames...)
if err != nil {
allErrors.RecordError(err)
cancel()
return
}
log.Infof("mysqld GetSchema: GetPrimaryKeyColumns done")
}()
}

wg.Wait()
if err := allErrors.AggrError(vterrors.Aggregate); err != nil {
return nil, err
}

log.Infof("mysqld GetSchema: Collecting all table schemas")
for _, td := range tds {
td.PrimaryKeyColumns = colMap[td.Name]
}
log.Infof("mysqld GetSchema: Collecting all table schemas done")

sd.TableDefinitions = tds

tmutils.GenerateSchemaVersion(sd)
return sd, nil
}

func (mysqld *Mysqld) collectBasicTableData(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) ([]*tabletmanagerdatapb.TableDefinition, error) {
// get the list of tables we're interested in
sql := "SELECT table_name, table_type, data_length, table_rows FROM information_schema.tables WHERE table_schema = '" + dbName + "'"
if !includeViews {
Expand All @@ -72,14 +167,23 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl
return nil, err
}
if len(qr.Rows) == 0 {
return sd, nil
return nil, nil
}

filter, err := tmutils.NewTableFilter(tables, excludeTables, includeViews)
if err != nil {
return nil, err
}

sd.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows))
tds := make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows))
for _, row := range qr.Rows {
tableName := row[0].ToString()
tableType := row[1].ToString()

if !filter.Includes(tableName, tableType) {
continue
}

// compute dataLength
var dataLength uint64
if !row[2].IsNull() {
Expand All @@ -99,49 +203,53 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl
}
}

qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", backtickDBName, sqlescape.EscapeID(tableName)))
if fetchErr != nil {
return nil, fetchErr
}
if len(qr.Rows) == 0 {
return nil, fmt.Errorf("empty create table statement for %v", tableName)
}

// Normalize & remove auto_increment because it changes on every insert
// FIXME(alainjobart) find a way to share this with
// vt/tabletserver/table_info.go:162
norm := qr.Rows[0][1].ToString()
norm = autoIncr.ReplaceAllLiteralString(norm, "")
if tableType == tmutils.TableView {
// Views will have the dbname in there, replace it
// with {{.DatabaseName}}
norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1)
}
tds = append(tds, &tabletmanagerdatapb.TableDefinition{
Name: tableName,
Type: tableType,
DataLength: dataLength,
RowCount: rowCount,
})
}

td := &tabletmanagerdatapb.TableDefinition{}
td.Name = tableName
td.Schema = norm
return tds, nil
}

td.Fields, td.Columns, err = mysqld.GetColumns(ctx, dbName, tableName)
if err != nil {
return nil, err
}
td.PrimaryKeyColumns, err = mysqld.GetPrimaryKeyColumns(ctx, dbName, tableName)
if err != nil {
return nil, err
}
td.Type = tableType
td.DataLength = dataLength
td.RowCount = rowCount
sd.TableDefinitions = append(sd.TableDefinitions, td)
// collectSchema gathers the per-table details for tableName in dbName: the
// fields and column names reported by GetColumns, and the normalized CREATE
// statement reported by normalizedSchema.
//
// The two lookups run one after the other; the first error aborts the
// collection and is returned unchanged with zero values for every other
// result.
func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tableType string) ([]*querypb.Field, []string, string, error) {
	fields, columns, err := mysqld.GetColumns(ctx, dbName, tableName)
	if err != nil {
		return nil, nil, "", err
	}

	schema, err := mysqld.normalizedSchema(ctx, dbName, tableName, tableType)
	if err != nil {
		return nil, nil, "", err
	}

	return fields, columns, schema, nil
}

// normalizedSchema runs SHOW CREATE TABLE for tableName in dbName and returns
// the statement text in a normalized form.
//
// Normalization removes everything matched by autoIncr and, when tableType is
// tmutils.TableView, replaces every occurrence of the escaped database name
// with the "{{.DatabaseName}}" placeholder.
//
// It returns the query error unchanged, or an error naming the table when the
// query yields no rows.
func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) {
	// Both identifiers must be escaped: a database name that needs quoting
	// would otherwise produce an invalid or misparsed statement.
	backtickDBName := sqlescape.EscapeID(dbName)
	qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", backtickDBName, sqlescape.EscapeID(tableName)))
	if fetchErr != nil {
		return "", fetchErr
	}
	if len(qr.Rows) == 0 {
		return "", fmt.Errorf("empty create table statement for %v", tableName)
	}

	// Normalize & remove auto_increment because it changes on every insert
	// FIXME(alainjobart) find a way to share this with
	// vt/tabletserver/table_info.go:162
	norm := qr.Rows[0][1].ToString()
	norm = autoIncr.ReplaceAllLiteralString(norm, "")
	if tableType == tmutils.TableView {
		// Views will have the dbname in there, replace it
		// with {{.DatabaseName}}
		norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1)
	}

	return norm, nil
}

// ResolveTables returns a list of actual tables+views matching a list
Expand All @@ -166,11 +274,11 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]*
}
defer conn.Recycle()

sql := fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table))
qr, err := mysqld.executeFetchContext(ctx, conn, sql, 0, true)
qr, err := conn.ExecuteFetch(fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)), 0, true)
if err != nil {
return nil, nil, err
}

columns := make([]string, len(qr.Fields))
for i, field := range qr.Fields {
columns[i] = field.Name
Expand All @@ -181,55 +289,43 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]*

// GetPrimaryKeyColumns returns the primary key columns of table in dbName.
//
// It delegates to getPrimaryKeyColumns with a single table and picks that
// table's entry out of the returned map. The result is nil when the map has
// no entry for table. Any error from the lookup is returned unchanged.
func (mysqld *Mysqld) GetPrimaryKeyColumns(ctx context.Context, dbName, table string) ([]string, error) {
	cs, err := mysqld.getPrimaryKeyColumns(ctx, dbName, table)
	if err != nil {
		return nil, err
	}

	// The map is keyed by table name, not by database name.
	return cs[table], nil
}

func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, tables ...string) (map[string][]string, error) {
conn, err := getPoolReconnect(ctx, mysqld.dbaPool)
if err != nil {
return nil, err
}
defer conn.Recycle()

sql := fmt.Sprintf("SHOW INDEX FROM %s.%s", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table))
qr, err := mysqld.executeFetchContext(ctx, conn, sql, 100, true)
tableList, err := tableListSql(tables)
if err != nil {
return nil, err
}
keyNameIndex := -1
seqInIndexIndex := -1
columnNameIndex := -1
for i, field := range qr.Fields {
switch field.Name {
case "Key_name":
keyNameIndex = i
case "Seq_in_index":
seqInIndexIndex = i
case "Column_name":
columnNameIndex = i
}
}
if keyNameIndex == -1 || seqInIndexIndex == -1 || columnNameIndex == -1 {
return nil, fmt.Errorf("unknown columns in 'show index' result: %v", qr.Fields)
sql := fmt.Sprintf(`
SELECT table_name, ordinal_position, column_name
FROM information_schema.key_column_usage
WHERE table_schema = '%s'
AND table_name IN %s
AND constraint_name='PRIMARY'
ORDER BY table_name, ordinal_position`, dbName, tableList)
qr, err := conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
}

columns := make([]string, 0, 5)
var expectedIndex int64 = 1
colMap := map[string][]string{}
for _, row := range qr.Rows {
// skip non-primary keys
if row[keyNameIndex].ToString() != "PRIMARY" {
continue
}

// check the Seq_in_index is always increasing
seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex])
if err != nil {
return nil, err
}
if seqInIndex != expectedIndex {
return nil, fmt.Errorf("unexpected index: %v != %v", seqInIndex, expectedIndex)
}
expectedIndex++

columns = append(columns, row[columnNameIndex].ToString())
tableName := row[0].ToString()
colMap[tableName] = append(colMap[tableName], row[2].ToString())
}
return columns, err
return colMap, err
}

// PreflightSchemaChange checks the schema changes in "changes" by applying them
Expand Down

0 comments on commit fce4cfd

Please sign in to comment.