diff --git a/executor/simple.go b/executor/simple.go index 79fd9110f3570..4e4e944052f9d 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" @@ -287,14 +288,22 @@ func userExists(ctx sessionctx.Context, name string, host string) (bool, error) } func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { + var u, h string if s.User == nil { - vars := e.ctx.GetSessionVars() - s.User = vars.User - if s.User == nil { + if e.ctx.GetSessionVars().User == nil { return errors.New("Session error is empty") } + u = e.ctx.GetSessionVars().User.AuthUsername + h = e.ctx.GetSessionVars().User.AuthHostname + } else { + checker := privilege.GetPrivilegeManager(e.ctx) + if checker != nil && !checker.RequestVerification("", "", "", mysql.SuperPriv) { + return ErrDBaccessDenied.GenWithStackByArgs(u, h, "mysql") + } + u = s.User.Username + h = s.User.Hostname } - exists, err := userExists(e.ctx, s.User.Username, s.User.Hostname) + exists, err := userExists(e.ctx, u, h) if err != nil { return errors.Trace(err) } @@ -303,7 +312,7 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { } // update mysql.user - sql := fmt.Sprintf(`UPDATE %s.%s SET password="%s" WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), s.User.Username, s.User.Hostname) + sql := fmt.Sprintf(`UPDATE %s.%s SET password="%s" WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, auth.EncodePassword(s.Password), u, h) _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return errors.Trace(err) diff --git a/executor/simple_test.go b/executor/simple_test.go index ffd819d98d845..af0106c90adf2 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -219,15 +219,16 @@ func (s *testSuite) TestSetPwd(c *C) { tk.Se, err = session.CreateSession4Test(s.store) c.Check(err, IsNil) ctx := tk.Se.(sessionctx.Context) - ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testpwd1", Hostname: "localhost"} + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testpwd1", Hostname: "localhost", AuthUsername: "testpwd1", AuthHostname: "localhost"} // Session user doesn't exist. _, err = tk.Exec(setPwdSQL) c.Check(terror.ErrorEqual(err, executor.ErrPasswordNoMatch), IsTrue, Commentf("err %v", err)) // normal - ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testpwd", Hostname: "localhost"} + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testpwd", Hostname: "localhost", AuthUsername: "testpwd", AuthHostname: "localhost"} tk.MustExec(setPwdSQL) result = tk.MustQuery(`SELECT Password FROM mysql.User WHERE User="testpwd" and Host="localhost"`) result.Check(testkit.Rows(auth.EncodePassword("pwd"))) + } func (s *testSuite) TestKillStmt(c *C) { diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index c76d55e7dc973..939214e965bd6 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1709,9 +1709,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { }, { sql: `set password for 'root'@'%' = 'xxxxx'`, - ans: []visitInfo{ - {mysql.SuperPriv, "", "", ""}, - }, + ans: []visitInfo{}, }, { sql: `show create table test.ttt`, diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 2a9b494132b03..c64de57ec8847 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -961,7 +961,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) Plan { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "") case *ast.GrantStmt: b.visitInfo = collectVisitInfoFromGrantStmt(b.visitInfo, raw) - case *ast.SetPwdStmt, *ast.RevokeStmt: + case *ast.RevokeStmt: b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "") case *ast.KillStmt: // If you have the SUPER privilege, you can kill all threads and statements. diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 48cebbe1bae9f..8b0c7dc1687d0 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -274,6 +274,26 @@ func (s *testPrivilegeSuite) TestDropTablePriv(c *C) { mustExec(c, se, `DROP TABLE todrop;`) } +func (s *testPrivilegeSuite) TestSetPasswdStmt(c *C) { + + se := newSession(c, s.store, s.dbName) + + // high privileged user setting password for other user (passes) + mustExec(c, se, "CREATE USER 'superuser'") + mustExec(c, se, "CREATE USER 'nobodyuser'") + mustExec(c, se, "GRANT ALL ON *.* TO 'superuser'") + mustExec(c, se, "FLUSH PRIVILEGES") + + c.Assert(se.Auth(&auth.UserIdentity{Username: "superuser", Hostname: "localhost", AuthUsername: "superuser", AuthHostname: "%"}, nil, nil), IsTrue) + mustExec(c, se, "SET PASSWORD for 'nobodyuser' = 'newpassword'") + + // low privileged user trying to set password for other user (fails) + c.Assert(se.Auth(&auth.UserIdentity{Username: "nobodyuser", Hostname: "localhost", AuthUsername: "nobodyuser", AuthHostname: "%"}, nil, nil), IsTrue) + _, err := se.Execute(context.Background(), "SET PASSWORD for 'superuser' = 'newpassword'") + c.Assert(err, NotNil) + +} + func (s *testPrivilegeSuite) TestCheckAuthenticate(c *C) { se := newSession(c, s.store, s.dbName)