Skip to content
Permalink
Browse files

Unit tests, integration testing, and code cleanup for oauth support. …

…Part of T705.
  • Loading branch information
ngerakines committed Dec 23, 2019
1 parent 7a0863f commit bf3b6a5ba01fbe1c3b726676da326b66e2ad4a4e
Showing with 777 additions and 52 deletions.
  1. +6 −1 database.go
  2. +43 −0 database_test.go
  3. +271 −0 db/create.go
  4. +146 −0 db/create_test.go
  5. +26 −0 db/tx.go
  6. +1 −1 go.mod
  7. +0 −1 go.sum
  8. +137 −2 main_test.go
  9. +1 −0 migrations/migrations.go
  10. +46 −0 migrations/v4.go
  11. +70 −39 oauth.go
  12. +20 −8 oauth_test.go
  13. +10 −0 routes.go
@@ -128,6 +128,11 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error)

GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
ValidateOAuthState(ctx context.Context, state string) error
GenerateOAuthState(ctx context.Context) (string, error)

DatabaseInitialized() bool
}

@@ -2489,7 +2494,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
} else {
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id"), localUserID, remoteUserID)
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id") + " user_id = ?", localUserID, remoteUserID, localUserID)
}
if err != nil {
log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err)
@@ -0,0 +1,43 @@
package writefreely

import (
"context"
"database/sql"
"github.com/stretchr/testify/assert"
"testing"
)

func TestOAuthDatastore(t *testing.T) {
if !runMySQLTests() {
t.Skip("skipping mysql tests")
}
withTestDB(t, func(db *sql.DB) {
ctx := context.Background()
ds := &datastore{
DB: db,
driverName: "",
}

state, err := ds.GenerateOAuthState(ctx)
assert.NoError(t, err)
assert.Len(t, state, 24)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state)

err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state)

var localUserID int64 = 99
var remoteUserID int64 = 100
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `users_oauth` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID)

foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID)
assert.NoError(t, err)
assert.Equal(t, localUserID, foundUserID)
})
}
@@ -0,0 +1,271 @@
package db

import (
"fmt"
"strings"
)

type DialectType int
type ColumnType int

type OptionalInt struct {
Set bool
Value int
}

type OptionalString struct {
Set bool
Value string
}

type SQLBuilder interface {
ToSQL() (string, error)
}

type Column struct {
Dialect DialectType
Name string
Nullable bool
Default OptionalString
Type ColumnType
Size OptionalInt
PrimaryKey bool
}

type CreateTableSqlBuilder struct {
Dialect DialectType
Name string
IfNotExists bool
ColumnOrder []string
Columns map[string]*Column
Constraints []string
}

const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)

const (
ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota
ColumnTypeInteger ColumnType = iota
ColumnTypeChar ColumnType = iota
ColumnTypeVarChar ColumnType = iota
ColumnTypeText ColumnType = iota
ColumnTypeDateTime ColumnType = iota
)

var _ SQLBuilder = &CreateTableSqlBuilder{}

var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}

func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}
switch d {
case ColumnTypeSmallInt:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "SMALLINT" + mod, nil
}
case ColumnTypeInteger:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "INT" + mod, nil
}
case ColumnTypeChar:
{
if dialect == DialectSQLite {
return "TEXT", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "CHAR" + mod, nil
}
case ColumnTypeVarChar:
{
if dialect == DialectSQLite {
return "TEXT", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "VARCHAR" + mod, nil
}
case ColumnTypeBool:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
return "TINYINT(1)", nil
}
case ColumnTypeDateTime:
return "DATETIME", nil
case ColumnTypeText:
return "TEXT", nil
}
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}

func (c *Column) SetName(name string) *Column {
c.Name = name
return c
}

func (c *Column) SetNullable(nullable bool) *Column {
c.Nullable = nullable
return c
}

func (c *Column) SetPrimaryKey(pk bool) *Column {
c.PrimaryKey = pk
return c
}

func (c *Column) SetDefault(value string) *Column {
c.Default = OptionalString{Set: true, Value: value}
return c
}

func (c *Column) SetType(t ColumnType) *Column {
c.Type = t
return c
}

func (c *Column) SetSize(size int) *Column {
c.Size = OptionalInt{Set: true, Value: size}
return c
}

func (c *Column) String() (string, error) {
var str strings.Builder

str.WriteString(c.Name)

str.WriteString(" ")
typeStr, err := c.Type.Format(c.Dialect, c.Size)
if err != nil {
return "", err
}

str.WriteString(typeStr)

if !c.Nullable {
str.WriteString(" NOT NULL")
}

if c.Default.Set {
str.WriteString(" DEFAULT ")
str.WriteString(c.Default.Value)
}

if c.PrimaryKey {
str.WriteString(" PRIMARY KEY")
}

return str.String(), nil
}

func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
if b.Columns == nil {
b.Columns = make(map[string]*Column)
}
b.Columns[column.Name] = column
b.ColumnOrder = append(b.ColumnOrder, column.Name)
return b
}

func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
for _, column := range columns {
if _, ok := b.Columns[column]; !ok {
// This fails silently.
return b
}
}
b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
return b
}

func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
b.IfNotExists = ine
return b
}

func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder

str.WriteString("CREATE TABLE ")
if b.IfNotExists {
str.WriteString("IF NOT EXISTS ")
}
str.WriteString(b.Name)

var things []string
for _, columnName := range b.ColumnOrder {
column, ok := b.Columns[columnName]
if !ok {
return "", fmt.Errorf("column not found: %s", columnName)
}
columnStr, err := column.String()
if err != nil {
return "", err
}
things = append(things, columnStr)
}
for _, constraint := range b.Constraints {
things = append(things, constraint)
}

if thingLen := len(things); thingLen > 0 {
str.WriteString(" ( ")
for i, thing := range things {
str.WriteString(thing)
if i < thingLen-1 {
str.WriteString(", ")
}
}
str.WriteString(" )")
}

return str.String(), nil
}

0 comments on commit bf3b6a5

Please sign in to comment.
You can’t perform that action at this time.