diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 7d57c5488cb..bd88dc6f219 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -17,10 +17,15 @@ limitations under the License. package mysqlctl import ( + "errors" "fmt" "regexp" "strings" + "sync" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/concurrency" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine" "golang.org/x/net/context" @@ -46,6 +51,26 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error { return mysqld.executeMysqlScript(params, strings.NewReader(sql)) } +func encodeTableName(tableName string) string { + var buf strings.Builder + sqltypes.NewVarChar(tableName).EncodeSQL(&buf) + return buf.String() +} + +// tableListSql returns an IN clause "('t1', 't2'...) for a list of tables." +func tableListSql(tables []string) (string, error) { + if len(tables) == 0 { + return "", errors.New("no tables for tableListSql") + } + + encodedTables := make([]string, len(tables)) + for i, tableName := range tables { + encodedTables[i] = encodeTableName(tableName) + } + + return "(" + strings.Join(encodedTables, ", ") + ")", nil +} + // GetSchema returns the schema for database for tables listed in // tables. If tables is empty, return the schema for all tables. func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) { @@ -62,6 +87,76 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl } sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, "{{.DatabaseName}}", 1) + tds, err := mysqld.collectBasicTableData(ctx, dbName, tables, excludeTables, includeViews) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + allErrors := &concurrency.AllErrorRecorder{} + + // Get per-table schema concurrently. + tableNames := make([]string, 0, len(tds)) + for _, td := range tds { + tableNames = append(tableNames, td.Name) + + wg.Add(1) + go func(td *tabletmanagerdatapb.TableDefinition) { + defer wg.Done() + + fields, columns, schema, err := mysqld.collectSchema(ctx, dbName, td.Name, td.Type) + if err != nil { + allErrors.RecordError(err) + cancel() + return + } + + td.Fields = fields + td.Columns = columns + td.Schema = schema + }(td) + } + + // Get primary columns concurrently. + colMap := map[string][]string{} + if len(tableNames) > 0 { + wg.Add(1) + go func() { + defer wg.Done() + + log.Infof("mysqld GetSchema: GetPrimaryKeyColumns") + var err error + colMap, err = mysqld.getPrimaryKeyColumns(ctx, dbName, tableNames...) + if err != nil { + allErrors.RecordError(err) + cancel() + return + } + log.Infof("mysqld GetSchema: GetPrimaryKeyColumns done") + }() + } + + wg.Wait() + if err := allErrors.AggrError(vterrors.Aggregate); err != nil { + return nil, err + } + + log.Infof("mysqld GetSchema: Collecting all table schemas") + for _, td := range tds { + td.PrimaryKeyColumns = colMap[td.Name] + } + log.Infof("mysqld GetSchema: Collecting all table schemas done") + + sd.TableDefinitions = tds + + tmutils.GenerateSchemaVersion(sd) + return sd, nil +} + +func (mysqld *Mysqld) collectBasicTableData(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) ([]*tabletmanagerdatapb.TableDefinition, error) { // get the list of tables we're interested in sql := "SELECT table_name, table_type, data_length, table_rows FROM information_schema.tables WHERE table_schema = '" + dbName + "'" if !includeViews { @@ -72,14 +167,23 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl return nil, err } if len(qr.Rows) == 0 { - return sd, nil + return nil, nil + } + + filter, err := tmutils.NewTableFilter(tables, excludeTables, includeViews) + if err != nil { + return nil, err } - sd.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows)) + tds := make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows)) for _, row := range qr.Rows { tableName := row[0].ToString() tableType := row[1].ToString() + if !filter.Includes(tableName, tableType) { + continue + } + // compute dataLength var dataLength uint64 if !row[2].IsNull() { @@ -99,49 +203,53 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl } } - qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", backtickDBName, sqlescape.EscapeID(tableName))) - if fetchErr != nil { - return nil, fetchErr - } - if len(qr.Rows) == 0 { - return nil, fmt.Errorf("empty create table statement for %v", tableName) - } - - // Normalize & remove auto_increment because it changes on every insert - // FIXME(alainjobart) find a way to share this with - // vt/tabletserver/table_info.go:162 - norm := qr.Rows[0][1].ToString() - norm = autoIncr.ReplaceAllLiteralString(norm, "") - if tableType == tmutils.TableView { - // Views will have the dbname in there, replace it - // with {{.DatabaseName}} - norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1) - } + tds = append(tds, &tabletmanagerdatapb.TableDefinition{ + Name: tableName, + Type: tableType, + DataLength: dataLength, + RowCount: rowCount, + }) + } - td := &tabletmanagerdatapb.TableDefinition{} - td.Name = tableName - td.Schema = norm + return tds, nil +} - td.Fields, td.Columns, err = mysqld.GetColumns(ctx, dbName, tableName) - if err != nil { - return nil, err - } - td.PrimaryKeyColumns, err = mysqld.GetPrimaryKeyColumns(ctx, dbName, tableName) - if err != nil { - return nil, err - } - td.Type = tableType - td.DataLength = dataLength - td.RowCount = rowCount - sd.TableDefinitions = append(sd.TableDefinitions, td) +func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tableType string) ([]*querypb.Field, []string, string, error) { + fields, columns, err := mysqld.GetColumns(ctx, dbName, tableName) + if err != nil { + return nil, nil, "", err } - sd, err = tmutils.FilterTables(sd, tables, excludeTables, includeViews) + schema, err := mysqld.normalizedSchema(ctx, dbName, tableName, tableType) if err != nil { - return nil, err + return nil, nil, "", err } - tmutils.GenerateSchemaVersion(sd) - return sd, nil + + return fields, columns, schema, nil +} + +func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) { + backtickDBName := sqlescape.EscapeID(dbName) + qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", dbName, sqlescape.EscapeID(tableName))) + if fetchErr != nil { + return "", fetchErr + } + if len(qr.Rows) == 0 { + return "", fmt.Errorf("empty create table statement for %v", tableName) + } + + // Normalize & remove auto_increment because it changes on every insert + // FIXME(alainjobart) find a way to share this with + // vt/tabletserver/table_info.go:162 + norm := qr.Rows[0][1].ToString() + norm = autoIncr.ReplaceAllLiteralString(norm, "") + if tableType == tmutils.TableView { + // Views will have the dbname in there, replace it + // with {{.DatabaseName}} + norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1) + } + + return norm, nil } // ResolveTables returns a list of actual tables+views matching a list @@ -166,11 +274,11 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]* } defer conn.Recycle() - sql := fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)) - qr, err := mysqld.executeFetchContext(ctx, conn, sql, 0, true) + qr, err := conn.ExecuteFetch(fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)), 0, true) if err != nil { return nil, nil, err } + columns := make([]string, len(qr.Fields)) for i, field := range qr.Fields { columns[i] = field.Name @@ -181,55 +289,43 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]* // GetPrimaryKeyColumns returns the primary key columns of table. func (mysqld *Mysqld) GetPrimaryKeyColumns(ctx context.Context, dbName, table string) ([]string, error) { + cs, err := mysqld.getPrimaryKeyColumns(ctx, dbName, table) + if err != nil { + return nil, err + } + + return cs[dbName], nil +} + +func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, tables ...string) (map[string][]string, error) { conn, err := getPoolReconnect(ctx, mysqld.dbaPool) if err != nil { return nil, err } defer conn.Recycle() - sql := fmt.Sprintf("SHOW INDEX FROM %s.%s", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)) - qr, err := mysqld.executeFetchContext(ctx, conn, sql, 100, true) + tableList, err := tableListSql(tables) if err != nil { return nil, err } - keyNameIndex := -1 - seqInIndexIndex := -1 - columnNameIndex := -1 - for i, field := range qr.Fields { - switch field.Name { - case "Key_name": - keyNameIndex = i - case "Seq_in_index": - seqInIndexIndex = i - case "Column_name": - columnNameIndex = i - } - } - if keyNameIndex == -1 || seqInIndexIndex == -1 || columnNameIndex == -1 { - return nil, fmt.Errorf("unknown columns in 'show index' result: %v", qr.Fields) + sql := fmt.Sprintf(` + SELECT table_name, ordinal_position, column_name + FROM information_schema.key_column_usage + WHERE table_schema = '%s' + AND table_name IN %s + AND constraint_name='PRIMARY' + ORDER BY table_name, ordinal_position`, dbName, tableList) + qr, err := conn.ExecuteFetch(sql, len(tables)*100, true) + if err != nil { + return nil, err } - columns := make([]string, 0, 5) - var expectedIndex int64 = 1 + colMap := map[string][]string{} for _, row := range qr.Rows { - // skip non-primary keys - if row[keyNameIndex].ToString() != "PRIMARY" { - continue - } - - // check the Seq_in_index is always increasing - seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex]) - if err != nil { - return nil, err - } - if seqInIndex != expectedIndex { - return nil, fmt.Errorf("unexpected index: %v != %v", seqInIndex, expectedIndex) - } - expectedIndex++ - - columns = append(columns, row[columnNameIndex].ToString()) + tableName := row[0].ToString() + colMap[tableName] = append(colMap[tableName], row[2].ToString()) } - return columns, err + return colMap, err } // PreflightSchemaChange checks the schema changes in "changes" by applying them