Skip to content

Commit

Permalink
feat: constrained scope to only dropping nullable and add support for…
Browse files Browse the repository at this point in the history
… mysql
  • Loading branch information
vabshere committed Jul 16, 2020
1 parent a2832f1 commit c9d8908
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 60 deletions.
4 changes: 3 additions & 1 deletion dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type Dialect interface {
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(field *StructField) string
// SplitDataTypeOf returns data's sql type and it's additional type
SplitDataTypeOf(field *StructField) (string, string)

// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
Expand All @@ -38,7 +40,7 @@ type Dialect interface {
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
// Nullable sets column's null constraint
Nullable(tableName string, columnName string, colType string, isNull bool) error
DropNullable(tableName string, columnName string, colType string) error

// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
Expand Down
23 changes: 12 additions & 11 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
}

func (s *commonDialect) DataTypeOf(field *StructField) string {
var sqlType, additionalType = s.SplitDataTypeOf(field)

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s *commonDialect) SplitDataTypeOf(field *StructField) (string, string) {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
Expand Down Expand Up @@ -93,10 +102,7 @@ func (s *commonDialect) DataTypeOf(field *StructField) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType, additionalType
}

func (s commonDialect) HasIndex(tableName string, indexName string) bool {
Expand Down Expand Up @@ -139,13 +145,8 @@ func (s commonDialect) ModifyColumn(tableName string, columnName string, typ str
return err
}

func (s commonDialect) Nullable(tableName string, columnName string, colType string, isNull bool) error {
var err error
if isNull {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v NULL", tableName, columnName, colType))
} else {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v NOT NULL", tableName, columnName, colType))
}
func (s commonDialect) DropNullable(tableName string, columnName string, colType string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v NULL", tableName, columnName, colType))
return err
}

Expand Down
14 changes: 10 additions & 4 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ func (mysql) Quote(key string) string {

// Get Data Type for MySQL Dialect
func (s *mysql) DataTypeOf(field *StructField) string {
var sqlType, additionalType = s.SplitDataTypeOf(field)

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s *mysql) SplitDataTypeOf(field *StructField) (string, string) {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

// MySQL allows only one auto increment column per table, and it must
Expand Down Expand Up @@ -129,10 +138,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType, additionalType
}

func (s mysql) RemoveIndex(tableName string, indexName string) error {
Expand Down
47 changes: 23 additions & 24 deletions dialect_oci8.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"crypto/sha1"
"fmt"
ociDriver "github.com/mattn/go-oci8"
"reflect"
"strconv"
"strings"
"reflect"
"time"
"strings"
"unicode/utf8"
)

Expand Down Expand Up @@ -40,11 +40,20 @@ func (*oci8) BindVar(i int) string {
}

func (o *oci8) DataTypeOf(field *StructField) string {
var sqlType, additionalType = o.SplitDataTypeOf(field)

if len(strings.TrimSpace(additionalType)) == 0 {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (o *oci8) SplitDataTypeOf(field *StructField) (string, string) {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, o)

charset, _ := o.GetTagSetting(field,"CHARSET")
charset, _ := o.GetTagSetting(field, "CHARSET")
var strDataType string
if strings.EqualFold(charset,"utf-8"){
if strings.EqualFold(charset, "utf-8") {
strDataType = "NVARCHAR2"
} else {
strDataType = "VARCHAR2"
Expand Down Expand Up @@ -90,7 +99,7 @@ func (o *oci8) DataTypeOf(field *StructField) string {
}
}

} else if isUUID(dataValue){
} else if isUUID(dataValue) {
// In case the user has specified uuid as the type explicitly
sqlType = fmt.Sprintf("%s(36)", strDataType)
}
Expand All @@ -99,10 +108,7 @@ func (o *oci8) DataTypeOf(field *StructField) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for oci8", dataValue.Type().Name(), dataValue.Kind().String()))
}

if len(strings.TrimSpace(additionalType)) == 0 {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType, additionalType
}

func (o *oci8) HasIndex(tableName string, indexName string) bool {
Expand All @@ -129,17 +135,11 @@ func (o *oci8) HasColumn(tableName string, columnName string) bool {
return count > 0
}

func (s *oci8) Nullable(tableName string, columnName string, colType string, isNull bool) error {
var err error
if isNull {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v NULL", tableName, columnName))
} else {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v NOT NULL", tableName, columnName))
}
func (s *oci8) DropNullable(tableName string, columnName string, colType string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v NULL", tableName, columnName, colType))
return err
}


func (*oci8) buildSha(str string) string {
if utf8.RuneCountInString(str) <= 30 {
return str
Expand All @@ -158,13 +158,13 @@ func (*oci8) buildSha(str string) string {

// Returns the primary key via the row ID
// Assumes that the primary key is the ID of the table
func (o *oci8) ResolveRowID(tableName string, rowID uint) uint{
func (o *oci8) ResolveRowID(tableName string, rowID uint) uint {
strRowID := ociDriver.GetLastInsertId(int64(rowID))
var id string
query := fmt.Sprintf(`SELECT id FROM %s WHERE rowid = :2`, o.Quote(tableName))
var err error
if err = o.db.QueryRow(query, strRowID).Scan(&id); err == nil{
if res, err := strconv.ParseUint(id, 10, 64); err == nil{
if err = o.db.QueryRow(query, strRowID).Scan(&id); err == nil {
if res, err := strconv.ParseUint(id, 10, 64); err == nil {
resolvedId := uint(res)
return resolvedId
}
Expand All @@ -173,14 +173,14 @@ func (o *oci8) ResolveRowID(tableName string, rowID uint) uint{
}

// Client statement separator used to terminate the statement
func (*oci8) ClientStatementSeparator() string{
func (*oci8) ClientStatementSeparator() string {
// In case of most DB's, it's a semicolon
return ""
}

func (*oci8) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
// In case both limit and offset are nil, simply return and empty string
if offset == nil && limit == nil{
if offset == nil && limit == nil {
return ""
}

Expand All @@ -203,7 +203,7 @@ func (*oci8) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
}

// Limit clause comes later
if errLimitParse == nil && parsedLimit >= 0 {
if errLimitParse == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" ROWS FETCH NEXT %d ROWS ONLY", parsedLimit)
}
return
Expand All @@ -224,4 +224,3 @@ func (o *oci8) GetTagSetting(field *StructField, key string) (val string, ok boo
func (o *oci8) GetByteLimit() int {
return 30000
}

23 changes: 12 additions & 11 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ func (postgres) BindVar(i int) string {
}

func (s *postgres) DataTypeOf(field *StructField) string {
var sqlType, additionalType = s.SplitDataTypeOf(field)

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s *postgres) SplitDataTypeOf(field *StructField) (string, string) {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
Expand Down Expand Up @@ -85,10 +94,7 @@ func (s *postgres) DataTypeOf(field *StructField) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType, additionalType
}

func (s postgres) HasIndex(tableName string, indexName string) bool {
Expand All @@ -115,13 +121,8 @@ func (s postgres) HasColumn(tableName string, columnName string) bool {
return count > 0
}

func (s postgres) Nullable(tableName string, columnName string, colType string, isNull bool) error {
var err error
if isNull {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v DROP NOT NULL", tableName, columnName))
} else {
_, err = s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v SET NOT NULL", tableName, columnName))
}
func (s postgres) DropNullable(tableName string, columnName string, colType string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v DROP NOT NULL", tableName, columnName))
return err
}

Expand Down
14 changes: 10 additions & 4 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ func (sqlite3) GetName() string {

// Get Data Type for Sqlite Dialect
func (s *sqlite3) DataTypeOf(field *StructField) string {
var sqlType, additionalType = s.SplitDataTypeOf(field)

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s *sqlite3) SplitDataTypeOf(field *StructField) (string,string) {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
Expand Down Expand Up @@ -64,10 +73,7 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType, additionalType
}

func (s sqlite3) HasIndex(tableName string, indexName string) bool {
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ func (s *DB) ModifyColumn(column string, typ string) *DB {
}

// Nullable sets column's null constraint
func (s *DB) Nullable(column string, isNull bool) *DB {
func (s *DB) DropNullable(column string) *DB {
scope := s.NewScope(s.Value)
scope.nullable(column, isNull)
scope.dropNullable(column)
return scope.db
}

Expand Down
6 changes: 3 additions & 3 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -1223,13 +1223,13 @@ func (scope *Scope) modifyColumn(column string, typ string) {
scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ))
}

func (scope *Scope) nullable(column string, isNull bool) {
func (scope *Scope) dropNullable(column string) {
colField, ok := scope.FieldByName(column)
if !ok {
scope.db.AddError(errors.New("No such column found"))
}
colType := scope.Dialect().DataTypeOf(colField.StructField)
scope.db.AddError(scope.Dialect().Nullable(scope.QuotedTableName(), scope.Quote(column), colType, isNull))
colType, _ := scope.Dialect().SplitDataTypeOf(colField.StructField)
scope.db.AddError(scope.Dialect().DropNullable(scope.QuotedTableName(), scope.Quote(column), colType))
}

func (scope *Scope) dropColumn(column string) {
Expand Down

0 comments on commit c9d8908

Please sign in to comment.