Skip to content

Commit

Permalink
add an Enable variable
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao committed Jan 25, 2017
1 parent 2d7a5bd commit 86c6bfd
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 235 deletions.
6 changes: 1 addition & 5 deletions executor/executor_simple.go
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan/statistics"
"github.com/pingcap/tidb/plan/statscache"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -312,10 +311,7 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {

func (e *SimpleExec) executeFlushTable(s *ast.FlushTableStmt) error {
// TODO: A dummy implement
dom := sessionctx.GetDomain(e.ctx)
handle := dom.PrivilegeHandle()
err := handle.Update()
return errors.Trace(err)
return nil
}

func (e *SimpleExec) executeAnalyzeTable(s *ast.AnalyzeTableStmt) error {
Expand Down
23 changes: 0 additions & 23 deletions plan/logical_plan_builder.go
Expand Up @@ -928,11 +928,6 @@ func (b *planBuilder) buildDataSource(tn *ast.TableName) LogicalPlan {
}
p.self = p
p.initIDAndContext(b.ctx)
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.SelectPriv,
db: schemaName.L,
table: tableInfo.Name.L,
})
// Equal condition contains a column from previous joined table.
schema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...)
for i, col := range tableInfo.Columns {
Expand Down Expand Up @@ -1071,15 +1066,6 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) LogicalPlan {
if b.err != nil {
return nil
}

if ds, ok := p.(*DataSource); ok {
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.UpdatePriv,
db: ds.DBName.L,
table: ds.tableInfo.Name.L,
})
}

_, _ = b.resolveHavingAndOrderBy(sel, p)
if sel.Where != nil {
p = b.buildSelection(p, sel.Where, nil)
Expand Down Expand Up @@ -1148,15 +1134,6 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) LogicalPlan {
if b.err != nil {
return nil
}

if ds, ok := p.(*DataSource); ok {
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.DeletePriv,
db: ds.DBName.L,
table: ds.tableInfo.Name.L,
})
}

_, _ = b.resolveHavingAndOrderBy(sel, p)
if sel.Where != nil {
p = b.buildSelection(p, sel.Where, nil)
Expand Down
135 changes: 0 additions & 135 deletions plan/logical_plan_test.go
Expand Up @@ -15,7 +15,6 @@ package plan

import (
"fmt"
"sort"
"testing"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -1617,137 +1616,3 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
c.Assert(ToString(p), Equals, ca.best, comment)
}
}

func (s *testPlanSuite) TestVisitInfo(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
sql string
ans []visitInfo
}{
{
sql: "insert into t values (1)",
ans: []visitInfo{
{mysql.InsertPriv, "test", "t", ""},
},
},
{
sql: "delete from t where a = 1",
ans: []visitInfo{
{mysql.DeletePriv, "test", "t", ""},
{mysql.SelectPriv, "test", "t", ""},
},
},
{
sql: "update t set a = 7 where a = 1",
ans: []visitInfo{
{mysql.UpdatePriv, "test", "t", ""},
{mysql.SelectPriv, "test", "t", ""},
},
},
{
sql: "select a, sum(e) from t group by a",
ans: []visitInfo{
{mysql.SelectPriv, "test", "t", ""},
},
},
{
sql: "truncate table t",
ans: []visitInfo{
{mysql.DeletePriv, "test", "t", ""},
},
},
{
sql: "drop table t",
ans: []visitInfo{
{mysql.DropPriv, "test", "t", ""},
},
},
{
sql: "create table t (a int)",
ans: []visitInfo{
{mysql.CreatePriv, "test", "t", ""},
},
},
{
sql: "create database test",
ans: []visitInfo{
{mysql.CreatePriv, "test", "", ""},
},
},
{
sql: "drop database test",
ans: []visitInfo{
{mysql.DropPriv, "test", "", ""},
},
},
{
sql: "create index t_1 on t (a)",
ans: []visitInfo{
{mysql.IndexPriv, "test", "t", ""},
},
},
{
sql: "drop index e on t",
ans: []visitInfo{
{mysql.IndexPriv, "test", "t", ""},
},
},
}

for _, ca := range cases {
comment := Commentf("for %s", ca.sql)
stmt, err := s.ParseOneStmt(ca.sql, "", "")
c.Assert(err, IsNil, comment)

is, err := mockResolve(stmt)
c.Assert(err, IsNil)

builder := &planBuilder{
colMapper: make(map[*ast.ColumnNameExpr]int),
allocator: new(idAllocator),
ctx: mockContext(),
is: is,
}
builder.build(stmt)
c.Assert(builder.err, IsNil, comment)

checkVisitInfo(c, builder.visitInfo, ca.ans, comment)
}
}

type visitInfoArray []visitInfo

func (v visitInfoArray) Len() int {
return len(v)
}

func (v visitInfoArray) Less(i, j int) bool {
if v[i].privilege < v[j].privilege {
return true
}
if v[i].db < v[j].db {
return true
}
if v[i].table < v[j].table {
return true
}
if v[i].column < v[j].column {
return true
}

return false
}

func (v visitInfoArray) Swap(i, j int) {
v[i], v[j] = v[j], v[i]
}

func checkVisitInfo(c *C, v1, v2 []visitInfo, comment CommentInterface) {
sort.Sort(visitInfoArray(v1))
sort.Sort(visitInfoArray(v2))

c.Assert(len(v1), Equals, len(v2), comment)
for i := 0; i < len(v1); i++ {
c.Assert(v1[i], Equals, v2[i], comment)
}
}
91 changes: 20 additions & 71 deletions plan/planbuilder.go
Expand Up @@ -84,10 +84,24 @@ func (b *planBuilder) build(node ast.Node) Plan {
switch x := node.(type) {
case *ast.AdminStmt:
return b.buildAdmin(x)
case *ast.AlterTableStmt:
return b.buildDDL(x)
case *ast.CreateDatabaseStmt:
return b.buildDDL(x)
case *ast.CreateIndexStmt:
return b.buildDDL(x)
case *ast.CreateTableStmt:
return b.buildDDL(x)
case *ast.DeallocateStmt:
return &Deallocate{Name: x.Name}
case *ast.DeleteStmt:
return b.buildDelete(x)
case *ast.DropDatabaseStmt:
return b.buildDDL(x)
case *ast.DropIndexStmt:
return b.buildDDL(x)
case *ast.DropTableStmt:
return b.buildDDL(x)
case *ast.ExecuteStmt:
return b.buildExecute(x)
case *ast.ExplainStmt:
Expand All @@ -114,7 +128,9 @@ func (b *planBuilder) build(node ast.Node) Plan {
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
case *ast.TruncateTableStmt:
return b.buildDDL(x)
case *ast.RenameTableStmt:
return b.buildDDL(x)
}
b.err = ErrUnsupportedType.Gen("Unsupported type %T", node)
Expand Down Expand Up @@ -501,13 +517,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
Ignore: insert.Ignore,
baseLogicalPlan: newBaseLogicalPlan(Ins, b.allocator),
}

b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.InsertPriv,
db: tn.DBInfo.Name.L,
table: tableInfo.Name.L,
})

cols := table.Cols()
for _, valuesItem := range insert.Lists {
exprList := make([]expression.Expression, 0, len(valuesItem))
Expand Down Expand Up @@ -605,69 +614,9 @@ func (b *planBuilder) buildLoadData(ld *ast.LoadDataStmt) Plan {
}

func (b *planBuilder) buildDDL(node ast.DDLNode) Plan {
switch v := node.(type) {
case *ast.AlterTableStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.AlterPriv,
db: v.Table.Schema.L,
table: v.Table.Name.L,
})
case *ast.CreateDatabaseStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.CreatePriv,
db: v.Name,
})
case *ast.CreateIndexStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.IndexPriv,
db: v.Table.Schema.L,
table: v.Table.Name.L,
})
case *ast.CreateTableStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.CreatePriv,
db: v.Table.Schema.L,
table: v.Table.Name.L,
})
case *ast.DropDatabaseStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.DropPriv,
db: v.Name,
})
case *ast.DropIndexStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.IndexPriv,
db: v.Table.Schema.L,
table: v.Table.Name.L,
})
case *ast.DropTableStmt:
for _, table := range v.Tables {
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.DropPriv,
db: table.Schema.L,
table: table.Name.L,
})
}
case *ast.TruncateTableStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.DeletePriv,
db: v.Table.Schema.L,
table: v.Table.Name.L,
})
case *ast.RenameTableStmt:
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.AlterPriv,
db: v.OldTable.Schema.L,
table: v.OldTable.Name.L,
})
b.visitInfo = append(b.visitInfo, visitInfo{
privilege: mysql.AlterPriv,
db: v.NewTable.Schema.L,
table: v.NewTable.Name.L,
})
}

return &DDL{Statement: node}
p := &DDL{Statement: node}
p.SetSchema(expression.NewSchema())
return p
}

func (b *planBuilder) buildExplain(explain *ast.ExplainStmt) Plan {
Expand Down
1 change: 0 additions & 1 deletion plan/preprocess.go
Expand Up @@ -25,6 +25,5 @@ func Preprocess(node ast.Node, info infoschema.InfoSchema, ctx context.Context)
if err := ResolveName(node, info, ctx); err != nil {
return errors.Trace(err)
}

return nil
}
1 change: 1 addition & 0 deletions privilege/privileges/cache_test.go
Expand Up @@ -29,6 +29,7 @@ type testCacheSuite struct {
}

func (s *testCacheSuite) SetUpSuite(c *C) {
Enable = true
store, err := tidb.NewStore("memory://mysql")
c.Assert(err, IsNil)
err = tidb.BootstrapSession(store)
Expand Down
7 changes: 7 additions & 0 deletions privilege/privileges/privileges.go
Expand Up @@ -29,6 +29,9 @@ import (
"github.com/pingcap/tidb/util/types"
)

// Enable enables the new privilege check feature.
var Enable bool = false

// privilege error codes.
const (
codeInvalidPrivilegeType terror.ErrCode = 1
Expand Down Expand Up @@ -177,6 +180,10 @@ type UserPrivileges struct {

// RequestVerification implements the Checker interface.
func (p *UserPrivileges) RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool {
if !Enable {
return true
}

if p.User == "" {
return true
}
Expand Down
1 change: 1 addition & 0 deletions privilege/privileges/privileges_test.go
Expand Up @@ -54,6 +54,7 @@ type testPrivilegeSuite struct {
}

func (s *testPrivilegeSuite) SetUpSuit(c *C) {
Enable = true
logLevel := os.Getenv("log_level")
log.SetLevelByString(logLevel)
}
Expand Down

0 comments on commit 86c6bfd

Please sign in to comment.