Skip to content

Commit

Permalink
fix: Use transaction for CreateProgram
Browse files Browse the repository at this point in the history
* Guarantees consistency
* TODO: add tests
  • Loading branch information
mizlan committed Jul 27, 2023
1 parent 836670f commit 6c66184
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 42 deletions.
17 changes: 17 additions & 0 deletions db/alias_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,20 @@ func (d *DB) GetUIDFromWID(ctx context.Context, wid string, path string) (string

return t.Target, err
}

func (d *DB) GetUIDFromWIDTransact(tx *firestore.Transaction, wid string, path string) (string, error) {

// get the document with the mapping
doc, err := tx.Get(d.Collection(path).Doc(wid))
if err != nil {
return "", err
}

t := struct {
Target string `firestore:"target"`
}{}

err = doc.DataTo(&t)

return t.Target, err
}
110 changes: 110 additions & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import (
"google.golang.org/api/option"
)

/// NOTE:
/// The *Xyz*Transact() functions are equivalent in behavior to *Xyz*(),
/// except they operate on a transaction instead of a context and should
/// be called within a RunTransaction() callback

// DB implements the TLADB interface on a Firestore
// database.
type DB struct {
Expand Down Expand Up @@ -68,6 +73,71 @@ func (d *DB) CreateProgram(ctx context.Context, p Program) (Program, error) {
return p, nil
}

func (d *DB) CreateProgramTransact(tx *firestore.Transaction, p Program) (Program, error) {
newProg := d.Collection(programsPath).NewDoc()
p.UID = newProg.ID
if err := tx.Create(newProg, p); err != nil {
return p, err
}

return p, nil
}

// Create a program and associate it with a user and a class.
//
// If wid == "", will not attempt to join a class.
func (d *DB) CreateProgramAndAssociate(ctx context.Context, p Program, uid string, wid string) error {
err := d.RunTransaction(ctx, func(ctx context.Context, tx *firestore.Transaction) error {
// create program
pRef, err := d.CreateProgramTransact(tx, p)
if err != nil {
return err
}

// associate to user, if they exist
u, err := d.LoadUserTransact(tx, uid)
if err != nil {
return err
}

u.Programs = append(u.Programs, pRef.UID)

if err := d.StoreUserTransact(tx, u); err != nil {
return err
}

// associate to class, if they exist
var cid string
var class Class
if wid != "" {
cid, err = d.GetUIDFromWIDTransact(tx, wid, ClassesAliasPath)
if err != nil {
return err
}

class, err = d.LoadClassTransact(tx, cid)
if err != nil {
return err
}

class.Programs = append(class.Programs, pRef.UID)

p.WID = class.WID

err := d.StoreClassTransact(tx, class)
if err != nil {
return err
}
}

p.UID = pRef.UID

return nil
})

return err
}

func (d *DB) RemoveProgram(ctx context.Context, pid string) error {
if _, err := d.Collection(programsPath).Doc(pid).Delete(ctx); err != nil {
return err
Expand All @@ -88,13 +158,33 @@ func (d *DB) LoadClass(ctx context.Context, cid string) (Class, error) {
return c, nil
}

func (d *DB) LoadClassTransact(tx *firestore.Transaction, cid string) (Class, error) {
doc, err := tx.Get(d.Collection(classesPath).Doc(cid))
if err != nil {
return Class{}, err
}

c := Class{}
if err := doc.DataTo(&c); err != nil {
return Class{}, err
}
return c, nil
}

func (d *DB) StoreClass(ctx context.Context, c Class) error {
if _, err := d.Collection(classesPath).Doc(c.CID).Set(ctx, &c); err != nil {
return err
}
return nil
}

func (d *DB) StoreClassTransact(tx *firestore.Transaction, c Class) error {
if err := tx.Set(d.Collection(classesPath).Doc(c.CID), &c); err != nil {
return err
}
return nil
}

func (d *DB) DeleteClass(ctx context.Context, cid string) error {
if _, err := d.Collection(classesPath).Doc(cid).Delete(ctx); err != nil {
return err
Expand All @@ -116,13 +206,33 @@ func (d *DB) LoadUser(ctx context.Context, uid string) (User, error) {
return u, nil
}

func (d *DB) LoadUserTransact(tx *firestore.Transaction, uid string) (User, error) {
doc, err := tx.Get(d.Collection(usersPath).Doc(uid))
if err != nil {
return User{}, err
}

u := User{}
if err := doc.DataTo(&u); err != nil {
return User{}, err
}
return u, nil
}

func (d *DB) StoreUser(ctx context.Context, u User) error {
if _, err := d.Collection(usersPath).Doc(u.UID).Set(ctx, &u); err != nil {
return err
}
return nil
}

func (d *DB) StoreUserTransact(tx *firestore.Transaction, u User) error {
if err := tx.Set(d.Collection(usersPath).Doc(u.UID), &u); err != nil {
return err
}
return nil
}

func (d *DB) DeleteUser(ctx context.Context, uid string) error {
if _, err := d.Collection(usersPath).Doc(uid).Delete(ctx); err != nil {
return err
Expand Down
23 changes: 23 additions & 0 deletions db/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ func (d *MockDB) CreateProgram(_ context.Context, p Program) (Program, error) {
return p, nil
}

func (d *MockDB) CreateProgramAndAssociate(ctx context.Context, p Program, uid string, wid string) error {
// create program
pRef, _ := d.CreateProgram(ctx, p)

// associate to user, if they exist
u, _ := d.LoadUser(ctx, uid)

u.Programs = append(u.Programs, pRef.UID)

if err := d.StoreUser(ctx, u); err != nil {
return err
}

if wid != "" {

Check failure on line 106 in db/mock.go

View workflow job for this annotation

GitHub Actions / lint

SA9003: empty branch (staticcheck)
// do nothing for now
// TODO: associate with class
}

p.UID = pRef.UID

return nil
}

// Temporary stand-ins to allow other refactors to function
func (d *MockDB) MakeAlias(ctx context.Context, uid string, path string) (string, error) {
return "", nil
Expand Down
1 change: 1 addition & 0 deletions db/tladb.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type TLADB interface {

CreateUser(context.Context, User) (User, error)
CreateProgram(context.Context, Program) (Program, error)
CreateProgramAndAssociate(context.Context, Program, string, string) error

MakeAlias(context.Context, string, string) (string, error)
GetUIDFromWID(context.Context, string, string) (string, error)
Expand Down
44 changes: 2 additions & 42 deletions handler/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,48 +70,8 @@ func CreateProgram(cc echo.Context) error {
p.Name = requestBody.Prog.Name
}

wid := requestBody.WID
var cid string
var class db.Class
var err error

if wid != "" {
cid, err = c.GetUIDFromWID(c.Request().Context(), wid, db.ClassesAliasPath)
if err != nil {
return err
}

class, err = c.LoadClass(c.Request().Context(), cid)
if err != nil {
return err
}
}

// create program
pRef, _ := c.CreateProgram(c.Request().Context(), p)

// associate to user, if they exist
u, _ := c.LoadUser(c.Request().Context(), requestBody.UID)

u.Programs = append(u.Programs, pRef.UID)

if err := c.StoreUser(c.Request().Context(), u); err != nil {
return err
}

// associate to class, if they exist
if wid != "" {
classRef, _ := c.LoadClass(c.Request().Context(), cid)
classRef.Programs = append(classRef.Programs, pRef.UID)

p.WID = class.WID
err := c.StoreClass(c.Request().Context(), classRef)
if err != nil {
return err
}
}

p.UID = pRef.UID
// use the composite function to guarantee consistency
err := c.CreateProgramAndAssociate(c.Request().Context(), p, requestBody.UID, requestBody.WID)

if err != nil {
if status.Code(err) == codes.NotFound {
Expand Down

0 comments on commit 6c66184

Please sign in to comment.