Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth Provider: Slack #230

Merged
merged 15 commits into from Jan 3, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 19 additions & 12 deletions config/config.go
Expand Up @@ -56,6 +56,20 @@ type (
Port int `ini:"port"`
}

WriteAsOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
AuthLocation string `ini:"auth_location"`
TokenLocation string `ini:"token_location"`
InspectLocation string `ini:"inspect_location"`
}

SlackOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
TeamID string `ini:"team_id"`
}

// AppCfg holds values that affect how the application functions
AppCfg struct {
SiteName string `ini:"site_name"`
Expand Down Expand Up @@ -92,24 +106,17 @@ type (
LocalTimeline bool `ini:"local_timeline"`
UserInvites string `ini:"user_invites"`

// OAuth
EnableOAuth bool `ini:"enable_oauth"`
OAuthProviderAuthLocation string `ini:"oauth_auth_location"`
OAuthProviderTokenLocation string `ini:"oauth_token_location"`
OAuthProviderInspectLocation string `ini:"oauth_inspect_location"`
OAuthClientCallbackLocation string `ini:"oauth_callback_location"`
OAuthClientID string `ini:"oauth_client_id"`
OAuthClientSecret string `ini:"oauth_client_secret"`

// Defaults
DefaultVisibility string `ini:"default_visibility"`
}

// Config holds the complete configuration for running a writefreely instance
Config struct {
Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"`
Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"`
SlackOauth SlackOauthCfg `ini:"oauth.slack"`
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"`
}
)

Expand Down
15 changes: 15 additions & 0 deletions config/funcs.go
Expand Up @@ -11,7 +11,9 @@
package config

import (
"net/http"
"strings"
"time"
)

// FriendlyHost returns the app's Host sans any schema
Expand All @@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool {
}
return int(currentlyUsed) < ac.MaxBlogs
}

// OrDefaultString returns input or a default value if input is empty.
func OrDefaultString(input, defaultValue string) string {
if len(input) == 0 {
return defaultValue
}
return input
}

// DefaultHTTPClient returns a sane default HTTP client.
func DefaultHTTPClient() *http.Client {
return &http.Client{Timeout: 10 * time.Second}
}
60 changes: 38 additions & 22 deletions database.go
Expand Up @@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"fmt"
wf_db "github.com/writeas/writefreely/db"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -125,10 +126,10 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error)

GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
ValidateOAuthState(ctx context.Context, state string) error
GenerateOAuthState(ctx context.Context) (string, error)
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)

DatabaseInitialized() bool
}
Expand All @@ -138,6 +139,8 @@ type datastore struct {
driverName string
}

var _ writestore = &datastore{}

func (db *datastore) now() string {
if db.driverName == driverSQLite {
return "strftime('%Y-%m-%d %H:%M:%S','now')"
Expand Down Expand Up @@ -2459,36 +2462,49 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil
}

func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, used, created_at) VALUES (?, FALSE, NOW())", state)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err)
}
return state, nil
}

func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
var provider string
var clientID string
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
if err != nil {
return err
}

res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return fmt.Errorf("state not found")
}
return nil
})
if err != nil {
return err
return "", "", nil
}
if rowsAffected != 1 {
return fmt.Errorf("state not found")
}
return nil
return provider, clientID, nil
}

func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
var err error
if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken)
} else {
_, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID)
_, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken)
}
if err != nil {
log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err)
Expand All @@ -2497,10 +2513,10 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
}

// GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
var userID int64 = -1
err := db.
QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ?", remoteUserID).
QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID).
Scan(&userID)
// Not finding a record is OK.
if err != nil && err != sql.ErrNoRows {
Expand Down
19 changes: 13 additions & 6 deletions database_test.go
Expand Up @@ -18,25 +18,32 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "",
}

state, err := ds.GenerateOAuthState(ctx)
state, err := ds.GenerateOAuthState(ctx, "test", "development")
assert.NoError(t, err)
assert.Len(t, state, 24)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)

err = ds.ValidateOAuthState(ctx, state)
_, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)

var localUserID int64 = 99
var remoteUserID int64 = 100
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a")
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID)

foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID)
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b")
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`")

foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test")
assert.NoError(t, err)
assert.Equal(t, localUserID, foundUserID)
})
Expand Down
52 changes: 52 additions & 0 deletions db/alter.go
@@ -0,0 +1,52 @@
package db

import (
"fmt"
"strings"
)

type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}

func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}

func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
}
return b
}

func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder {
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")))
return b
}

func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder

str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")

if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}

return str.String(), nil
}
56 changes: 56 additions & 0 deletions db/alter_test.go
@@ -0,0 +1,56 @@
package db

import "testing"

func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},

{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}
29 changes: 1 addition & 28 deletions db/create.go
Expand Up @@ -5,7 +5,6 @@ import (
"strings"
)

type DialectType int
type ColumnType int

type OptionalInt struct {
Expand Down Expand Up @@ -41,11 +40,6 @@ type CreateTableSqlBuilder struct {
Constraints []string
}

const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)

const (
ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota
Expand All @@ -61,28 +55,6 @@ var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}

func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
Expand Down Expand Up @@ -269,3 +241,4 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {

return str.String(), nil
}