Skip to content

Commit

Permalink
Utilize go1.8 context support in database/sql
Browse files Browse the repository at this point in the history
Fixes go-gorm#1231

The related go1.8 release notes: https://golang.org/doc/go1.8#database_sql
  • Loading branch information
remohammadi committed Nov 18, 2017
1 parent 0a51f6c commit fb51756
Show file tree
Hide file tree
Showing 29 changed files with 580 additions and 181 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -20,6 +20,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
* Extendable, write Plugins based on GORM callbacks
* Every feature comes with tests
* Developer Friendly
* Supports context.Context on golang 1.8

## Getting Started

Expand Down
4 changes: 2 additions & 2 deletions callback_create.go
Expand Up @@ -115,7 +115,7 @@ func createCallback(scope *Scope) {

// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if result, err := scope.sqldbExec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()

Expand All @@ -128,7 +128,7 @@ func createCallback(scope *Scope) {
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if err := scope.sqldbQueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
Expand Down
2 changes: 1 addition & 1 deletion callback_query.go
Expand Up @@ -55,7 +55,7 @@ func queryCallback(scope *Scope) {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}

if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if rows, err := scope.sqldbQuery(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()

columns, _ := rows.Columns()
Expand Down
4 changes: 2 additions & 2 deletions callback_row_query.go
Expand Up @@ -22,9 +22,9 @@ func rowQueryCallback(scope *Scope) {
scope.prepareQuerySQL()

if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
rowResult.Row = scope.sqldbQueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
rowsResult.Rows, rowsResult.Error = scope.sqldbQuery(scope.SQL, scope.SQLVars...)
}
}
}
36 changes: 8 additions & 28 deletions dialect_common.go
Expand Up @@ -9,6 +9,14 @@ import (
"time"
)

const (
queryHasIndex = "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?"
queryRemoveIndex = "DROP INDEX %v"
queryHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?"
queryHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?"
queryCurrentDatabase = "SELECT DATABASE()"
)

// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
Expand Down Expand Up @@ -90,38 +98,10 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}

func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}

func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}

func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}

func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}

func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
Expand Down
36 changes: 36 additions & 0 deletions dialect_common_go1.8.go
@@ -0,0 +1,36 @@
// +build go1.8

package gorm

import (
"context"
"fmt"
)

func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}

func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryRemoveIndex, indexName))
return err
}

func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}

func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}

func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryCurrentDatabase).Scan(&name)
return
}
33 changes: 33 additions & 0 deletions dialect_common_go1.8pre.go
@@ -0,0 +1,33 @@
// +build !go1.8

package gorm

import "fmt"

func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(queryHasIndex, s.CurrentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}

func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf(queryRemoveIndex, indexName))
return err
}

func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow(queryHasTable, s.CurrentDatabase(), tableName).Scan(&count)
return count > 0
}

func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(queryHasColumn, s.CurrentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}

func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow(queryCurrentDatabase).Scan(&name)
return
}
22 changes: 6 additions & 16 deletions dialect_mysql.go
Expand Up @@ -11,6 +11,12 @@ import (
"unicode/utf8"
)

const (
queryMySQLRemoveIndex = "DROP INDEX %v ON %v"
queryMySQLHasForeignKey = "SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'"
queryMySQLCurrentDatabase = "SELECT DATABASE()"
)

type mysql struct {
commonDialect
}
Expand Down Expand Up @@ -122,11 +128,6 @@ func (s *mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}

func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
Expand All @@ -142,17 +143,6 @@ func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
return
}

func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
Expand Down
24 changes: 24 additions & 0 deletions dialect_mysql_go1.8.go
@@ -0,0 +1,24 @@
// +build go1.8

package gorm

import (
"context"
"fmt"
)

func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.ExecContext(context.Background(), fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}

func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryMySQLCurrentDatabase).Scan(&name)
return
}
21 changes: 21 additions & 0 deletions dialect_mysql_go1.8pre.go
@@ -0,0 +1,21 @@
// +build !go1.8

package gorm

import "fmt"

func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf(queryMySQLRemoveIndex, indexName, s.Quote(tableName)))
return err
}

func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow(queryMySQLHasForeignKey, s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow(queryMySQLCurrentDatabase).Scan(&name)
return
}
37 changes: 8 additions & 29 deletions dialect_postgres.go
Expand Up @@ -7,6 +7,14 @@ import (
"time"
)

const (
queryPostgresHasIndex = "SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()"
queryPostgresHasForeignKey = "SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'"
queryPostgresHasTable = "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()"
queryPostgresHasColumn = "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()"
queryPostgresCurrentDatabase = "SELECT CURRENT_DATABASE()"
)

type postgres struct {
commonDialect
}
Expand Down Expand Up @@ -85,35 +93,6 @@ func (s *postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}

func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}

func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}

func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}

func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
Expand Down
34 changes: 34 additions & 0 deletions dialect_postgres_go1.8.go
@@ -0,0 +1,34 @@
// +build go1.8

package gorm

import "context"

func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasIndex, tableName, indexName).Scan(&count)
return count > 0
}

func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasTable, tableName).Scan(&count)
return count > 0
}

func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRowContext(context.Background(), queryPostgresHasColumn, tableName, columnName).Scan(&count)
return count > 0
}

func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRowContext(context.Background(), queryPostgresCurrentDatabase).Scan(&name)
return
}
32 changes: 32 additions & 0 deletions dialect_postgres_go1.8pre.go
@@ -0,0 +1,32 @@
// +build !go1.8

package gorm

func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(queryPostgresHasIndex, tableName, indexName).Scan(&count)
return count > 0
}

func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow(queryPostgresHasForeignKey, tableName, foreignKeyName).Scan(&count)
return count > 0
}

func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow(queryPostgresHasTable, tableName).Scan(&count)
return count > 0
}

func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(queryPostgresHasColumn, tableName, columnName).Scan(&count)
return count > 0
}

func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow(queryPostgresCurrentDatabase).Scan(&name)
return
}

0 comments on commit fb51756

Please sign in to comment.