Skip to content

Commit

Permalink
database,web: Split GetUserByRemoteIdentity method
Browse files Browse the repository at this point in the history
Fixes #10 and Fixes #12.
  • Loading branch information
s111 committed Jul 5, 2017
1 parent 6084745 commit 7de1163
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 113 deletions.
5 changes: 4 additions & 1 deletion database/database.go
Expand Up @@ -4,9 +4,12 @@ import "github.com/autograde/aguis/models"

// Database contains methods for manipulating the database.
type Database interface {
NewUserFromRemoteIdentity(provider string, remoteID uint64, accessToken string) (*models.User, error)
AssociateUserWithRemoteIdentity(userID uint64, provider string, remoteID uint64, accessToken string) error

GetUser(uint64) (*models.User, error)
GetUserByRemoteIdentity(provider string, id uint64, accessToken string) (*models.User, error)
GetUsers() (*[]models.User, error)
GetUserByRemoteIdentity(string, uint64, string) (*models.User, error)

CreateCourse(*models.Course) error
GetCourses() (*[]models.Course, error)
Expand Down
66 changes: 39 additions & 27 deletions database/gormdb.go
Expand Up @@ -36,60 +36,72 @@ func (db *GormDB) GetUser(id uint64) (*models.User, error) {
return &user, nil
}

// GetUsers implements the Database interface.
func (db *GormDB) GetUsers() (*[]models.User, error) {
var users []models.User
if err := db.conn.Find(&users).Error; err != nil {
return nil, err
}
return &users, nil
}

// GetUserByRemoteIdentity implements the Database interface.
func (db *GormDB) GetUserByRemoteIdentity(provider string, id uint64, accessToken string) (*models.User, error) {
tx := db.conn.Begin()

// Get the remote identity.
var remoteIdentity models.RemoteIdentity
if err := tx.
Where("provider = ? AND remote_id = ?", provider, id).
First(&remoteIdentity).Error; err == gorm.ErrRecordNotFound {
user := models.User{
RemoteIdentities: []models.RemoteIdentity{{
Provider: provider,
RemoteID: id,
AccessToken: accessToken,
}},
}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Commit().Error; err != nil {
return nil, err
}
return &user, nil
} else if err != nil {
First(&remoteIdentity).Error; err != nil {
tx.Rollback()
return nil, err
}

// Update the access token.
if err := tx.Model(&remoteIdentity).Update("access_token", accessToken).Error; err != nil {
tx.Rollback()
return nil, err
}

// Get the user.
var user models.User
if err := tx.Preload("RemoteIdentities").First(&user, remoteIdentity.UserID).Error; err != nil {
tx.Rollback()
return nil, err
}

if err := tx.Commit().Error; err != nil {
return nil, err
}
return &user, nil
}

// GetUsers implements the Database interface.
func (db *GormDB) GetUsers() (*[]models.User, error) {
var users []models.User
if err := db.conn.Find(&users).Error; err != nil {
return nil, err
}
return &users, nil
}

// NewUserFromRemoteIdentity implements the Database interface.
func (db *GormDB) NewUserFromRemoteIdentity(provider string, remoteID uint64, accessToken string) (*models.User, error) {
user := models.User{
RemoteIdentities: []models.RemoteIdentity{{
Provider: provider,
RemoteID: remoteID,
AccessToken: accessToken,
}},
}
if err := db.conn.Create(&user).Error; err != nil {
return nil, err
}
return &user, nil
}

// AssociateUserWithRemoteIdentity implements the Database interface.
func (db *GormDB) AssociateUserWithRemoteIdentity(userID uint64, provider string, remoteID uint64, accessToken string) error {
remoteIdentity := models.RemoteIdentity{
Provider: provider,
RemoteID: remoteID,
AccessToken: accessToken,
UserID: userID,
}
return db.conn.Create(&remoteIdentity).Error
}

// CreateCourse implements the Database interface.
func (db *GormDB) CreateCourse(course *models.Course) error {
return db.conn.Create(course).Error
Expand Down
76 changes: 49 additions & 27 deletions database/gormdb_test.go
Expand Up @@ -61,40 +61,58 @@ func TestGormDBGetUsers(t *testing.T) {
}
}

func TestGormDBGetUserByRemoteIdentity(t *testing.T) {
func TestGormDBAssociateUserWithRemoteIdentity(t *testing.T) {
const (
initialToken = "123"
newToken = "ABC"
provider = "github"
remoteID = 10
uID = 1
rID1 = 1
rID2 = 2

secret1 = "123"
provider1 = "github"
remoteID1 = 10

secret2 = "ABC"
provider2 = "gitlab"
remoteID2 = 20
)

wantUser1 := &models.User{
ID: 1,
RemoteIdentities: []models.RemoteIdentity{{
ID: 1,
Provider: provider,
RemoteID: remoteID,
AccessToken: initialToken,
UserID: 1,
}},
}
var (
wantUser1 = &models.User{
ID: uID,
RemoteIdentities: []models.RemoteIdentity{{
ID: rID1,
Provider: provider1,
RemoteID: remoteID1,
AccessToken: secret1,
UserID: uID,
}},
}

wantUser2 := &models.User{
ID: 1,
RemoteIdentities: []models.RemoteIdentity{{
ID: 1,
Provider: provider,
RemoteID: remoteID,
AccessToken: newToken,
UserID: 1,
}},
}
wantUser2 = &models.User{
ID: uID,
RemoteIdentities: []models.RemoteIdentity{
{
ID: rID1,
Provider: provider1,
RemoteID: remoteID1,
AccessToken: secret1,
UserID: uID,
},
{
ID: rID2,
Provider: provider2,
RemoteID: remoteID2,
AccessToken: secret2,
UserID: uID,
},
},
}
)

db, cleanup := setup(t)
defer cleanup()

user1, err := db.GetUserByRemoteIdentity(provider, remoteID, initialToken)
user1, err := db.NewUserFromRemoteIdentity(provider1, remoteID1, secret1)
if err != nil {
t.Fatal(err)
}
Expand All @@ -103,7 +121,11 @@ func TestGormDBGetUserByRemoteIdentity(t *testing.T) {
t.Errorf("have user %+v want %+v", user1, wantUser1)
}

user2, err := db.GetUserByRemoteIdentity(provider, remoteID, newToken)
if err := db.AssociateUserWithRemoteIdentity(user1.ID, provider2, remoteID2, secret2); err != nil {
t.Fatal(err)
}

user2, err := db.GetUser(uID)
if err != nil {
t.Fatal(err)
}
Expand Down
99 changes: 45 additions & 54 deletions web/auth/auth.go
Expand Up @@ -2,16 +2,15 @@ package auth

import (
"encoding/gob"
"errors"
"net/http"
"strconv"

"github.com/autograde/aguis/database"
"github.com/autograde/aguis/models"
"github.com/autograde/aguis/scm"
"github.com/jinzhu/gorm"
"github.com/labstack/echo"
"github.com/labstack/echo-contrib/session"
"github.com/markbates/goth"
"github.com/markbates/goth/gothic"
)

Expand All @@ -21,6 +20,7 @@ func init() {

// Frontend URLs.
const (
home = "/app/home"
logout = "/app/logout"
login = "/app/newlogin"
)
Expand Down Expand Up @@ -128,34 +128,19 @@ func OAuth2Login(db database.Database) echo.HandlerFunc {
w := c.Response()
r := c.Request()

externalUser, err := gothic.CompleteUserAuth(w, r)
_, err := gothic.CompleteUserAuth(w, r)
// An error indicates that authentication needs to be performed at the provider.
if err != nil {
url, err := gothic.GetAuthURL(w, r)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
// Redirect to provider to perform authentication.
return c.Redirect(http.StatusTemporaryRedirect, url)
}

user, err := getInteralUser(db, &externalUser)
if err != nil {
return err
}

sess, err := session.Get(SessionKey, c)
if err != nil {
return err
}
if sess.Values[UserKey] == nil {
sess.Values[UserKey] = newUserSession(user.ID)
}
us := sess.Values[UserKey].(*UserSession)
us.enableProvider(c.Param("provider"))
if err := sess.Save(r, w); err != nil {
return err
}

return c.Redirect(http.StatusFound, login)
// The user navigated to /auth/:provider but is already authenticated.
return c.Redirect(http.StatusFound, home)
}
}

Expand All @@ -165,12 +150,13 @@ func OAuth2Callback(db database.Database) echo.HandlerFunc {
w := c.Response()
r := c.Request()

// Complete authentication.
externalUser, err := gothic.CompleteUserAuth(w, r)
if err != nil {
return echo.ErrUnauthorized
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}

user, err := getInteralUser(db, &externalUser)
remoteID, err := strconv.ParseUint(externalUser.UserID, 10, 64)
if err != nil {
return err
}
Expand All @@ -179,11 +165,43 @@ func OAuth2Callback(db database.Database) echo.HandlerFunc {
if err != nil {
return err
}
if sess.Values[UserKey] == nil {
sess.Values[UserKey] = newUserSession(user.ID)

// Try to get already logged in user.
if sess.Values[UserKey] != nil {
i, ok := sess.Values[UserKey]
if !ok {
return OAuth2Logout()(c)
}

us := i.(*UserSession)
// Associate user with remote identity.
if err := db.AssociateUserWithRemoteIdentity(
us.ID, externalUser.Provider, remoteID, externalUser.AccessToken,
); err != nil {
return err
}
return c.Redirect(http.StatusFound, login)
}
us := sess.Values[UserKey].(*UserSession)

// Try to get user from database.
var user *models.User
user, err = db.GetUserByRemoteIdentity(externalUser.Provider, remoteID, externalUser.AccessToken)
if err == gorm.ErrRecordNotFound {
// Create new user.
user, err = db.NewUserFromRemoteIdentity(
externalUser.Provider, remoteID, externalUser.AccessToken,
)
if err != nil {
return err
}
} else if err != nil {
return err
}

// Register user session.
us := newUserSession(user.ID)
us.enableProvider(c.Param("provider"))
sess.Values[UserKey] = us
if err := sess.Save(r, w); err != nil {
return err
}
Expand Down Expand Up @@ -238,30 +256,3 @@ func AccessControl(db database.Database, scms map[string]scm.SCM) echo.Middlewar
}
}
}

func getInteralUser(db database.Database, externalUser *goth.User) (*models.User, error) {
provider, err := goth.GetProvider(externalUser.Provider)

if err != nil {
return nil, err
}

// TODO: Extract each case into a function so that they can be tested.
switch provider.Name() {
case "github", "gitlab":
remoteID, err := strconv.ParseUint(externalUser.UserID, 10, 64)
if err != nil {
return nil, err
}

user, err := db.GetUserByRemoteIdentity(provider.Name(), remoteID, externalUser.AccessToken)
if err != nil {
return nil, err
}
return user, nil
case "faux": // Provider is only registered and reachable from tests.
return &models.User{}, nil
default:
return nil, errors.New("provider not implemented")
}
}

0 comments on commit 7de1163

Please sign in to comment.