diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 51aba50c..00000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..2e7a32d9 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,5 @@ +# These are supported funding model platforms + +github: [jinzhu] +patreon: jinzhu +open_collective: gorm diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..e4e81074 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +--- +version: 2 +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + - package-ecosystem: gomod + directory: /tests + schedule: + interval: weekly diff --git a/.github/labels.json b/.github/labels.json new file mode 100644 index 00000000..5c7eb7d1 --- /dev/null +++ b/.github/labels.json @@ -0,0 +1,166 @@ +{ + "labels": { + "critical": { + "name": "type:critical", + "colour": "#E84137", + "description": "critical questions" + }, + "question": { + "name": "type:question", + "colour": "#EDEDED", + "description": "general questions" + }, + "feature": { + "name": "type:feature_request", + "colour": "#43952A", + "description": "feature request" + }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, + "with_playground": { + "name": "type:with reproduction steps", + "colour": "#00ff00", + "description": "with reproduction steps" + }, + "without_playground": { + "name": "type:missing reproduction steps", + "colour": "#CF2E1F", + "description": "missing reproduction steps" + }, + "has_pr": { + "name": "type:has 
pull request", + "colour": "#43952A", + "description": "has pull request" + }, + "not_tested": { + "name": "type:not tested", + "colour": "#CF2E1F", + "description": "not tested" + }, + "tested": { + "name": "type:tested", + "colour": "#00ff00", + "description": "tested" + }, + "breaking_change": { + "name": "type:breaking change", + "colour": "#CF2E1F", + "description": "breaking change" + } + }, + "issue": { + "with_playground": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s" + } + ] + }, + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + "question": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/question/i" + }, + { + "type": "descriptionMatches", + "pattern": "/question/i" + } + ] + }, + "feature": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/feature/i" + }, + { + "type": "descriptionMatches", + "pattern": "/Describe the feature/i" + } + ] + }, + "without_playground": { + "requires": 6, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s" + }, + { + "type": "titleMatches", + "pattern": "/^((?!question).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!question).)*$/is" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!Describe the feature).)*$/is" + }, + { + "type": "titleMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + } + ] + } + }, + "pr": { + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + 
"not_tested": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Tested/" + } + ] + }, + "breaking_change": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Non breaking API changes/" + } + ] + } + } +} diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 00000000..fbebfc12 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,28 @@ +name: "Close invalid questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +permissions: + contents: read + +jobs: + stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v8 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. 
most likely your question has already been answered at https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 30 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..0e8aaa60 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,19 @@ +name: "Issue Labeler" +on: + issues: + types: [opened, edited, reopened] + pull_request: + types: [opened, edited, reopened] + +jobs: + triage: + runs-on: ubuntu-latest + name: Label issues and pull requests + steps: + - name: check out + uses: actions/checkout@v3 + + - name: labeler + uses: jinzhu/super-labeler-action@develop + with: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml new file mode 100644 index 00000000..b23a5bf9 --- /dev/null +++ b/.github/workflows/missing_playground.yml @@ -0,0 +1,27 @@ +name: "Close Missing Playground issues" +on: + schedule: + - cron: "*/10 * * * *" + +permissions: + contents: read + +jobs: + stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v8 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "The issue has been automatically marked as stale as it is missing a playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master; check out [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. It will be closed in 30 days if no further activity occurs. 
if you are asking a question, please use the `Question` template, most likely your question has already been answered at https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 30 + remove-stale-when-updated: true + only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml new file mode 100644 index 00000000..a6542d57 --- /dev/null +++ b/.github/workflows/reviewdog.yml @@ -0,0 +1,22 @@ +name: reviewdog +on: [pull_request] +jobs: + golangci-lint: + name: runner / golangci-lint + runs-on: ubuntu-latest + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + - name: golangci-lint + uses: reviewdog/action-golangci-lint@v2 + + - name: Setup reviewdog + uses: reviewdog/action-setup@v1 + + - name: gofumpt -s with reviewdog + env: + REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + go install mvdan.cc/gofumpt@v0.2.0 + gofumpt -e -d . | \ + reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..c9752883 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,28 @@ +name: "Stale" +on: + schedule: + - cron: "0 2 * * *" + +permissions: + contents: read + +jobs: + stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v8 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been automatically marked as stale because it has been open for 360 days with no activity. 
Remove stale label or comment or this will be closed in 180 days" + days-before-stale: 360 + days-before-close: 180 + stale-issue-label: "status:stale" + exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' + stale-pr-label: 'status:stale' + exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..bf225d42 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,202 @@ +name: tests + +on: + push: + branches-ignore: + - 'gh-pages' + pull_request: + branches-ignore: + - 'gh-pages' + +permissions: + contents: read + +jobs: + # Label of the container job + sqlite: + strategy: + matrix: + go: ['1.19', '1.18'] + platform: [ubuntu-latest] # can not run in windows OS + runs-on: ${{ matrix.platform }} + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh + + mysql: + strategy: + matrix: + dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + go: ['1.19', '1.18'] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mysqladmin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: 
actions/checkout@v3 + + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + postgres: + strategy: + matrix: + dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] + go: ['1.19', '1.18'] + platform: [ubuntu-latest] # can not run in macOS and Windows + runs-on: ${{ matrix.platform }} + + services: + postgres: + image: ${{ matrix.dbversion }} + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + ports: + - 9920:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + + sqlserver: + strategy: + matrix: + go: ['1.19', '1.18'] + platform: [ubuntu-latest] # can not run test in macOS and windows + runs-on: ${{ matrix.platform }} + + services: + mssql: + image: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + ports: + - 9930:1433 + options: >- + --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" + --health-start-period 10s + 
--health-interval 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + + tidb: + strategy: + matrix: + dbversion: [ 'v6.5.0' ] + go: [ '1.19', '1.18' ] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + steps: + - name: Setup TiDB + uses: Icemap/tidb-action@main + with: + port: 9940 + version: ${{matrix.dbversion}} + + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh diff --git a/.gitignore b/.gitignore index 01dc5ce0..72733326 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,7 @@ +TODO* documents +coverage.txt _book +.idea +vendor +.vscode diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..b88bf672 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,20 @@ +linters: + enable: + - cyclop + - exportloopref + - gocritic + - gosec + - ineffassign + - misspell + - prealloc + - unconvert + - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 
100644 index c54d572d..00000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,52 +0,0 @@ -# How to Contribute - -## Bug Report - -- Do a search on GitHub under Issues in case it has already been reported -- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE* - -## Feature Request - -- Feature request with pull request is welcome -- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work - -## Pull Request - -- Prefer single commit pull request, that make the git history can be a bit easier to follow. -- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future - -## Contributing to Documentation - -- You are welcome ;) -- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow. -- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides - -### Executable script template - -```go -package main - -import ( - _ "github.com/mattn/go-sqlite3" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/jinzhu/gorm" -) - -var db gorm.DB - -func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // db, err := gorm.Open("postgres", "user=username dbname=password sslmode=disable") - // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") - if err != nil { - panic(err) - } - db.LogMode(true) -} - -func main() { - // Your code -} -``` diff --git a/README.md b/README.md index 7dba9052..85ad3050 100644 --- a/README.md +++ b/README.md @@ -2,45 +2,43 @@ The fantastic ORM library for Golang, aims to be developer friendly. 
-[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) +[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) +[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) +[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) ## Overview -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions +* Full-Featured ORM +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) +* Hooks (Before/After Create/Save/Update/Delete/Find) +* Eager loading with `Preload`, `Joins` +* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point +* Context, Prepared Statement Mode, DryRun Mode +* Batch Insert, FindInBatches, Find To Map +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key -* SQL Builder * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every 
feature comes with tests * Developer Friendly ## Getting Started -* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm) +* GORM Guides [https://gorm.io](https://gorm.io) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) -## Upgrading To V1.0 +## Contributing -* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html) +[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) -# Author +## Contributors -**jinzhu** - -* -* -* - -# Contributors - -https://github.com/jinzhu/gorm/graphs/contributors +[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! ## License -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). +© Jinzhu, 2013~time.Now + +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) diff --git a/association.go b/association.go index 0f94683d..7c93ebea 100644 --- a/association.go +++ b/association.go @@ -1,371 +1,579 @@ package gorm import ( - "errors" "fmt" "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. 
type Association struct { - Error error - scope *Scope - column string - field *Field + DB *DB + Relationship *schema.Relationship + Unscope bool + Error error } -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) +func (db *DB) Association(column string) *Association { + association := &Association{DB: db} + table := db.Statement.Table + + if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table + association.Relationship = db.Statement.Schema.Relationships.Relations[column] + + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) + } + + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + } else { + association.Error = err + } + + return association } -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, } +} - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) +func (association *Association) Find(out interface{}, conds ...interface{}) error { + if association.Error == nil { + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.saveAssociations(values...) 
+ return association.Error } -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association +func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation( /*clear*/ false, values...) + } } - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) + return association.Error +} - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil +func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} } - 
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + + // save associations + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error } - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for _, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + // set old associations's foreign key to null + switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } + case reflect.Struct: + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } - } else { - // If other relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil } - } - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) + association.Error = association.DB.UpdateColumns(updateMap).Error + } + if 
association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + case schema.HasOne, schema.HasMany: + var ( + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + tx.Not(clause.IN{Column: column, Values: values}) + } } - } - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) 
+ if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } else { + return ErrPrimaryKeyRequired } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) 
- var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { + tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + association.Error = tx.Delete(modelValue).Error } } - return association + return association.Error } -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } +func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression + ) - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } - if len(values) == 0 { - return association - } + switch 
rel.Type { + case schema.BelongsTo: + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := 
schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } } - } + case schema.HasOne, schema.HasMany: + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired } - } - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) 
- - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + 
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } + } + + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) } else { - association.setErr(results.Error) + return ErrPrimaryKeyRequired } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } - } - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var 
isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break + if association.Error == nil { + // clean up deleted values's foreign key + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) + + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) + primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) + + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) + case reflect.Struct: + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } + + if rel.JoinTable == nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey || ref.PrimaryValue != "" { + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else { + association.Error = 
ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } } } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } } - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) } } } - return association + return association.Error } -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { +func (association *Association) Clear() error { return association.Replace() } -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - if relationship.Kind == "many_to_many" { - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v 
IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - scope.TableName(), - ) +func (association *Association) Count() (count int64) { + if association.Error == nil { + association.Error = association.buildCondition().Count(&count).Error } + return +} - query.Model(fieldValue).Count(&count) - return count +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value } -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { +func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( - scope = association.scope - field = association.field - relationship = field.Relationship + reflectValue = association.DB.Statement.ReflectValue + assignBacks []assignBack // assign association values back to arguments after save ) - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) + } + } + case reflect.Struct: + association.Error = 
association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) + } + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value + if clear { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) + } + + if elemType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + } + case reflect.Struct: + appendToFieldValues(rv.Addr()) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) + } } + } - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + 
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { + columnName = name } + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name } - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) } } + } - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) + } + } - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) 
+ associationDB := association.DB.Session(&Session{}).Model(nil) + if !association.DB.FullSaveAssociations { + associationDB.Select(selectedSaveColumns) + } + if len(omitColumns) > 0 { + associationDB.Omit(omitColumns...) + } + associationDB = associationDB.Session(&Session{}) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + // clear old data + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } + + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } + } + } + } + } + break } + + association.Error = ErrInvalidValueOfLength + return } - } - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + // TODO support save slice data, sql with case? 
+ association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error + } + case reflect.Struct: + // clear old data + if clear && len(values) == 0 { + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + + if association.Relationship.JoinTable == nil && association.Error == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } } + } + + for idx, value := range values { + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) + } + + if len(values) > 0 { + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error + } + } + + for _, assignBack := range assignBacks { + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { - association.setErr(errors.New("invalid value type")) + reflect.Indirect(assignBack.Dest).Set(fieldValue) } } - return association } -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, 
Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + if len(joinStmt.SQL.String()) > 0 { + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + } + + tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) } - return association + + return tx } diff --git a/association_test.go b/association_test.go deleted file mode 100644 index ad56d84e..00000000 --- a/association_test.go +++ /dev/null @@ -1,874 +0,0 @@ -package gorm_test - -import ( - "fmt" - "os" - "reflect" - "sort" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestBelongsTo(t *testing.T) { - post := Post{ - Title: "post belongs to", - Body: "body belongs to", - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - if post.Category.ID == 0 || post.MainCategory.ID == 0 { - t.Errorf("Category's primary key should be updated") - } - - if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { - t.Errorf("post's foreign key should be updated") - } - - // Query - var category1 Category - DB.Model(&post).Association("Category").Find(&category1) - if category1.Name != "Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var mainCategory1 Category - DB.Model(&post).Association("MainCategory").Find(&mainCategory1) - if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var 
category11 Category - DB.Model(&post).Related(&category11) - if category11.Name != "Category 1" { - t.Errorf("Query belongs to relations with Related") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - if DB.Model(&post).Association("MainCategory").Count() != 1 { - t.Errorf("Post's main category count should be 1") - } - - // Append - var category2 = Category{ - Name: "Category 2", - } - DB.Model(&post).Association("Category").Append(&category2) - - if category2.ID == 0 { - t.Errorf("Category should has ID when created with Append") - } - - var category21 Category - DB.Model(&post).Related(&category21) - - if category21.Name != "Category 2" { - t.Errorf("Category should be updated with Append") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Replace - var category3 = Category{ - Name: "Category 3", - } - DB.Model(&post).Association("Category").Replace(&category3) - - if category3.ID == 0 { - t.Errorf("Category should has ID when created with Replace") - } - - var category31 Category - DB.Model(&post).Related(&category31) - if category31.Name != "Category 3" { - t.Errorf("Category should be updated with Replace") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Delete - DB.Model(&post).Association("Category").Delete(&category2) - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not delete any category when Delete a unrelated Category") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should not be reseted when Delete a unrelated Category") - } - - DB.Model(&post).Association("Category").Delete(&category3) - - if post.Category.Name != "" { - t.Errorf("Post's category should be reseted after Delete") - } - - var category41 Category - DB.Model(&post).Related(&category41) - if category41.Name != "" { - 
t.Errorf("Category should be deleted with Delete") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Delete, but got %v", count) - } - - // Clear - DB.Model(&post).Association("Category").Append(&Category{ - Name: "Category 2", - }) - - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should find category after append") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should has value after Append") - } - - DB.Model(&post).Association("Category").Clear() - - if post.Category.Name != "" { - t.Errorf("Post's category should be cleared after Clear") - } - - if !DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not find any category after Clear") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Clear, but got %v", count) - } - - // Check Association mode with soft delete - category6 := Category{ - Name: "Category 6", - } - DB.Model(&post).Association("Category").Append(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 after Append, but got %v", count) - } - - DB.Delete(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) - } - - if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { - t.Errorf("Post's category is not findable after Delete") - } - - if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { - t.Errorf("Post's category should be findable when query 
with Unscoped, got %v", err) - } -} - -func TestBelongsToOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` - ProfileRefer int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestBelongsToOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Refer string - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` - ProfileID int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOne(t *testing.T) { - user := User{ - Name: "has one", - CreditCard: CreditCard{Number: "411111111111"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Error("Got errors when save user", err.Error()) - } - - if user.CreditCard.UserId.Int64 == 0 { - t.Errorf("CreditCard's foreign key should be updated") - } - - // Query - var creditCard1 CreditCard - DB.Model(&user).Related(&creditCard1) - - if creditCard1.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - var creditCard11 CreditCard - DB.Model(&user).Association("CreditCard").Find(&creditCard11) - - if creditCard11.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - if 
DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Append - var creditcard2 = CreditCard{ - Number: "411111111112", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard2) - - if creditcard2.ID == 0 { - t.Errorf("Creditcard should has ID when created with Append") - } - - var creditcard21 CreditCard - DB.Model(&user).Related(&creditcard21) - if creditcard21.Number != "411111111112" { - t.Errorf("CreditCard should be updated with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Replace - var creditcard3 = CreditCard{ - Number: "411111111113", - } - DB.Model(&user).Association("CreditCard").Replace(&creditcard3) - - if creditcard3.ID == 0 { - t.Errorf("Creditcard should has ID when created with Replace") - } - - var creditcard31 CreditCard - DB.Model(&user).Related(&creditcard31) - if creditcard31.Number != "411111111113" { - t.Errorf("CreditCard should be updated with Replace") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Delete - DB.Model(&user).Association("CreditCard").Delete(&creditcard2) - var creditcard4 CreditCard - DB.Model(&user).Related(&creditcard4) - if creditcard4.Number != "411111111113" { - t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Delete(&creditcard3) - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should delete credit card with Delete") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Delete") - } - - // Clear - var creditcard5 = CreditCard{ - Number: "411111111115", - } - 
DB.Model(&user).Association("CreditCard").Append(&creditcard5) - - if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should added credit card with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Clear() - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Credit card should be deleted with Clear") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Clear") - } - - // Check Association mode with soft delete - var creditcard6 = CreditCard{ - Number: "411111111116", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 after Append, but got %v", count) - } - - DB.Delete(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { - t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) - } - - if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { - t.Errorf("User's creditcard is not findable after Delete") - } - - if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { - t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) - } -} - -func TestHasOneOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { 
- if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOneOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasMany(t *testing.T) { - post := Post{ - Title: "post has many", - Body: "body has many", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - for _, comment := range post.Comments { - if comment.PostId == 0 { - t.Errorf("comment's PostID should be updated") - } - } - - var compareComments = func(comments []Comment, contents []string) bool { - var commentContents []string - for _, comment := range comments { - commentContents = append(commentContents, comment.Content) - } - sort.Strings(commentContents) - sort.Strings(contents) - return reflect.DeepEqual(commentContents, contents) - } - - // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - - var comments1 []Comment - DB.Model(&post).Association("Comments").Find(&comments1) - if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Association") - } - - var 
comments11 []Comment - DB.Model(&post).Related(&comments11) - if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Related") - } - - if DB.Model(&post).Association("Comments").Count() != 2 { - t.Errorf("Post's comments count should be 2") - } - - // Append - DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) - - var comments2 []Comment - DB.Model(&post).Related(&comments2) - if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { - t.Errorf("Append new record to has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 3 { - t.Errorf("Post's comments count should be 3 after Append") - } - - // Delete - DB.Model(&post).Association("Comments").Delete(comments11) - - var comments3 []Comment - DB.Model(&post).Related(&comments3) - if !compareComments(comments3, []string{"Comment 3"}) { - t.Errorf("Delete an existing resource for has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 1 { - t.Errorf("Post's comments count should be 1 after Delete 2") - } - - // Replace - DB.Model(&Post{Id: 999}).Association("Comments").Replace() - - var comments4 []Comment - DB.Model(&post).Related(&comments4) - if len(comments4) == 0 { - t.Errorf("Replace for other resource should not clear all comments") - } - - DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) - - var comments41 []Comment - DB.Model(&post).Related(&comments41) - if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { - t.Errorf("Replace has many relations") - } - - // Clear - DB.Model(&Post{Id: 999}).Association("Comments").Clear() - - var comments5 []Comment - DB.Model(&post).Related(&comments5) - if len(comments5) == 0 { - t.Errorf("Clear should not clear all comments") - } - - DB.Model(&post).Association("Comments").Clear() - - var comments51 []Comment - 
DB.Model(&post).Related(&comments51) - if len(comments51) != 0 { - t.Errorf("Clear has many relations") - } - - // Check Association mode with soft delete - var comment6 = Comment{ - Content: "comment 6", - } - DB.Model(&post).Association("Comments").Append(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 after Append, but got %v", count) - } - - DB.Delete(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 0 { - t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) - } - - var comments6 []Comment - if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { - t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) - } - - if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) - } - - var comments61 []Comment - if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) - } -} - -func TestHasManyOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile []Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasManyOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - 
gorm.Model - Refer string - Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestManyToMany(t *testing.T) { - DB.Raw("delete from languages") - var languages = []Language{{Name: "ZH"}, {Name: "EN"}} - user := User{Name: "Many2Many", Languages: languages} - DB.Save(&user) - - // Query - var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Query many to many relations") - } - - DB.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Should be able to find many to many relations") - } - - if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { - t.Errorf("Count should return correct result") - } - - // Append - DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) - if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { - t.Errorf("New record should be saved when append") - } - - languageA := Language{Name: "AA"} - DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) - - languageC := Language{Name: "CC"} - DB.Save(&languageC) - DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) - - totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { - t.Errorf("All appended languages 
should be saved") - } - - // Delete - user.Languages = []Language{} - DB.Model(&user).Association("Languages").Find(&user.Languages) - - var language Language - DB.Where("name = ?", "EE").First(&language) - DB.Model(&user).Association("Languages").Delete(language, &language) - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { - t.Errorf("Relations should be deleted with Delete") - } - if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { - t.Errorf("Language EE should not be deleted") - } - - DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) - - user2 := User{Name: "Many2Many_User2", Languages: languages} - DB.Save(&user2) - - DB.Model(&user).Association("Languages").Delete(languages, &languages) - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { - t.Errorf("Relations should be deleted with Delete") - } - - if DB.Model(&user2).Association("Languages").Count() == 0 { - t.Errorf("Other user's relations should not be deleted") - } - - // Replace - var languageB Language - DB.Where("name = ?", "BB").First(&languageB) - DB.Model(&user).Association("Languages").Replace(languageB) - if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { - t.Errorf("Relations should be replaced") - } - - DB.Model(&user).Association("Languages").Replace() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be replaced with empty") - } - - DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) - if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { - t.Errorf("Relations should be replaced") - } - - // Clear - DB.Model(&user).Association("Languages").Clear() - if len(user.Languages) != 0 || 
DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be cleared") - } - - // Check Association mode with soft delete - var language6 = Language{ - Name: "language 6", - } - DB.Model(&user).Association("Languages").Append(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 after Append, but got %v", count) - } - - DB.Delete(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 0 { - t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) - } - - var languages6 []Language - if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { - t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) - } - - if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count) - } - - var languages61 []Language - if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) - } -} - -func TestRelated(t *testing.T) { - user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - CreditCard: CreditCard{Number: "1234567890"}, - Company: Company{Name: "company1"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Errorf("No error should happen when saving user") - } - - if user.CreditCard.ID == 0 { - t.Errorf("After user save, credit card should have id") - } - - if user.BillingAddress.ID == 0 { - t.Errorf("After user save, billing address should have id") - } - - if 
user.Emails[0].Id == 0 { - t.Errorf("After user save, billing address should have id") - } - - var emails []Email - DB.Model(&user).Related(&emails) - if len(emails) != 2 { - t.Errorf("Should have two emails") - } - - var emails2 []Email - DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) - if len(emails2) != 1 { - t.Errorf("Should have two emails") - } - - var emails3 []*Email - DB.Model(&user).Related(&emails3) - if len(emails3) != 2 { - t.Errorf("Should have two emails") - } - - var user1 User - DB.Model(&user).Related(&user1.Emails) - if len(user1.Emails) != 2 { - t.Errorf("Should have only one email match related condition") - } - - var address1 Address - DB.Model(&user).Related(&address1, "BillingAddressId") - if address1.Address1 != "Billing Address - Address 1" { - t.Errorf("Should get billing address from user correctly") - } - - user1 = User{} - DB.Model(&address1).Related(&user1, "BillingAddressId") - if DB.NewRecord(user1) { - t.Errorf("Should get user from address correctly") - } - - var user2 User - DB.Model(&emails[0]).Related(&user2) - if user2.Id != user.Id || user2.Name != user.Name { - t.Errorf("Should get user from email correctly") - } - - var creditcard CreditCard - var user3 User - DB.First(&creditcard, "number = ?", "1234567890") - DB.Model(&creditcard).Related(&user3) - if user3.Id != user.Id || user3.Name != user.Name { - t.Errorf("Should get user from credit card correctly") - } - - if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { - t.Errorf("RecordNotFound for Related") - } - - var company Company - if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { - t.Errorf("RecordNotFound for Related") - } -} - -func TestForeignKey(t *testing.T) { - for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { - if structField.Name == foreignKey && 
!structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Email{}).GetStructFields() { - for _, foreignKey := range []string{"UserId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Post{}).GetStructFields() { - for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { - for _, foreignKey := range []string{"PostId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } -} - -func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - // sqlite does not support ADD CONSTRAINT in ALTER TABLE - return - } - targetScope := DB.NewScope(target) - targetTableName := targetScope.TableName() - modelScope := DB.NewScope(source) - modelField, ok := modelScope.FieldByName(sourceFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) - } - targetField, ok := targetScope.FieldByName(targetFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) - } - dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error - if err != nil { - t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) - } -} - -func TestLongForeignKey(t *testing.T) { - testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", 
&ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") -} - -func TestLongForeignKeyWithShortDest(t *testing.T) { - testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") -} diff --git a/callback.go b/callback.go deleted file mode 100644 index 93198a71..00000000 --- a/callback.go +++ /dev/null @@ -1,237 +0,0 @@ -package gorm - -import ( - "fmt" -) - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} - -// Callback is a struct that contains all CURD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone() *Callback { - return &Callback{ - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - 
processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... 
-// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - 
cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor - } - } - return nil -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) 
- } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CURD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - 
c.rowQueries = sortProcessors(rowQueries) -} diff --git a/callback_create.go b/callback_create.go deleted file mode 100644 index 14b82047..00000000 --- a/callback_create.go +++ /dev/null @@ -1,148 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := NowFunc() - scope.SetColumn("CreatedAt", now) - scope.SetColumn("UpdatedAt", now) - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if 
field.IsNormal { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", - quotedTableName, - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", - scope.QuotedTableName(), - strings.Join(columns, ","), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - 
scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - } else { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - scope.db.RowsAffected = 1 - } - } - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/callback_delete.go b/callback_delete.go deleted file mode 100644 index c8ffcc82..00000000 --- a/callback_delete.go +++ /dev/null @@ -1,53 +0,0 @@ -package gorm - -import "fmt" - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke 
`BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET deleted_at=%v%v%v", - scope.QuotedTableName(), - scope.AddToVars(NowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/callback_query.go b/callback_query.go deleted file mode 100644 index 93782b1d..00000000 --- a/callback_query.go +++ /dev/null @@ -1,93 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", 
scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = reflect.Indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/callback_query_preload.go b/callback_query_preload.go deleted file mode 100644 index d9ec8bdd..00000000 --- a/callback_query_preload.go +++ /dev/null @@ -1,344 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - 
fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, 
preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := 
scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, scope.TableName()) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, results...)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if 
len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value).Select("*") - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - 
// preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, 
foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { - continue - } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } - - } -} diff --git a/callback_save.go b/callback_save.go deleted file mode 100644 index 5ffe53b9..00000000 --- a/callback_save.go +++ /dev/null @@ -1,92 +0,0 @@ -package gorm - -import "reflect" - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - 
value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) - } - - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) - } - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/callback_system_test.go b/callback_system_test.go deleted file mode 100644 index 13ca3f42..00000000 --- a/callback_system_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package gorm - -import ( - "reflect" - "runtime" - "strings" - "testing" -) - -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, 
fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) -} - -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} - -func TestRegisterCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } -} - -func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } -} - -func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{} - - 
callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) - - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } -} - -func replaceCreate(s *Scope) {} - -func TestReplaceCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } -} - -func TestRemoveCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } -} diff --git a/callback_update.go b/callback_update.go deleted 
file mode 100644 index aa27b5fb..00000000 --- a/callback_update.go +++ /dev/null @@ -1,104 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) - } -} - -// updateCallback the callback used to update 
data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 00000000..195d1720 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,341 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "reflect" + "sort" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func initializeCallbacks(db *DB) *callbacks 
{ + return &callbacks{ + processors: map[string]*processor{ + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor +} + +type processor struct { + db *DB + Clauses []string + fns []func(*DB) + callbacks []*callback +} + +type callback struct { + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + processor *processor +} + +func (cs *callbacks) Create() *processor { + return cs.processors["create"] +} + +func (cs *callbacks) Query() *processor { + return cs.processors["query"] +} + +func (cs *callbacks) Update() *processor { + return cs.processors["update"] +} + +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] +} + +func (cs *callbacks) Row() *processor { + return cs.processors["row"] +} + +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] +} + +func (p *processor) Execute(db *DB) *DB { + // call scopes + for len(db.Statement.scopes) > 0 { + db = db.executeScopes() + } + + var ( + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool + ) + + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + + if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } + + // assign model values + if stmt.Model == nil { + stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model + } + + // parse model values + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: 
db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } + } + } + + // assign stmt.ReflectValue + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { + stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) + } + + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(ErrInvalidValue) + } + } + + for _, f := range p.fns { + f(db) + } + + if stmt.SQL.Len() > 0 { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + sql, vars := stmt.SQL.String(), stmt.Vars + if filter, ok := db.Logger.(ParamsFilter); ok { + sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...) + } + return db.Dialector.Explain(sql, vars...), db.RowsAffected + }, db.Error) + } + + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + } + + if resetBuildClauses { + stmt.BuildClauses = nil + } + + return db +} + +func (p *processor) Get(name string) func(*DB) { + for i := len(p.callbacks) - 1; i >= 0; i-- { + if v := p.callbacks[i]; v.name == name && !v.remove { + return v.handler + } + } + return nil +} + +func (p *processor) Before(name string) *callback { + return &callback{before: name, processor: p} +} + +func (p *processor) After(name string) *callback { + return &callback{after: name, processor: p} +} + +func (p *processor) Match(fc func(*DB) bool) *callback { + return &callback{match: fc, processor: p} +} + +func (p *processor) Register(name string, fn func(*DB)) error { + return (&callback{processor: p}).Register(name, fn) +} + +func (p *processor) Remove(name string) error { + return (&callback{processor: p}).Remove(name) +} + +func (p *processor) Replace(name string, fn func(*DB)) error { + return (&callback{processor: p}).Replace(name, fn) +} + +func (p *processor) compile() (err error) { + var callbacks []*callback + for _, callback := 
range p.callbacks { + if callback.match == nil || callback.match(p.db) { + callbacks = append(callbacks, callback) + } + } + p.callbacks = callbacks + + if p.fns, err = sortCallbacks(p.callbacks); err != nil { + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) + } + return +} + +func (c *callback) Before(name string) *callback { + c.before = name + return c +} + +func (c *callback) After(name string) *callback { + c.after = name + return c +} + +func (c *callback) Register(name string, fn func(*DB)) error { + c.name = name + c.handler = fn + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Remove(name string) error { + c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) + c.name = name + c.remove = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Replace(name string, fn func(*DB)) error { + c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) + c.name = name + c.handler = fn + c.replace = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { + var ( + names, sorted []string + sortCallback func(*callback) error + ) + sort.SliceStable(cs, func(i, j int) bool { + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false + }) + + for _, c := range cs { + // show warning message the callback name already exists + if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove 
&& !cs[idx].remove { + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) + } + names = append(names, c.name) + } + + sortCallback = func(c *callback) error { + if c.before != "" { // if defined before callback + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) + } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name + } + } + + if c.after != "" { // if defined after callback + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if after callback sorted, append current callback to last + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = 
append(sorted, c.name) + } + + return nil + } + + for _, c := range cs { + if err = sortCallback(c); err != nil { + return + } + } + + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return +} diff --git a/callbacks/associations.go b/callbacks/associations.go new file mode 100644 index 00000000..f3cd464a --- /dev/null +++ b/callbacks/associations.go @@ -0,0 +1,453 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + rValLen = db.Statement.ReflectValue.Len() + objs = make([]reflect.Value, 0, rValLen) + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} + for i := 0; i < rValLen; 
i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() != reflect.Struct { + break + } + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } + objs = append(objs, obj) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) + } + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) + } + } + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } + } + } + } + } + } +} + +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } 
+ + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) 
+ } + + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) + } + } + } + + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) + } + } + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues...) 
+ if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) + } else { + fv, _ := 
ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) + } + } + joins = reflect.Append(joins, joinValue) + } + + identityMap := map[string]bool{} + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + if !isPtr { + elem = elem.Addr() + } + objs = append(objs, v) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, elem) + } + + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + // optimize elems of reflect value length + if elemLen := elems.Len(); elemLen > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) + } + + for i := 0; i < elemLen; i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + 
}).Create(joins.Interface()).Error) + } + } + } + } +} + +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { + if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { + onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) + for _, dbName := range s.PrimaryFieldDBNames { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) + } + + onConflict.UpdateAll = stmt.DB.FullSaveAssociations + if !onConflict.UpdateAll { + onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) + } + } else { + onConflict.DoNothing = true + } + + return +} + +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) + refName = rel.Name + "." 
+ values = rValues.Interface() + ) + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }) + + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if tx.Statement.FullSaveAssociations { + tx = tx.Set("gorm:update_track_time", true) + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } else if restricted && len(omits) == 0 { + tx = tx.Omit(clause.Associations) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) + } + + return db.AddError(tx.Create(values).Error) +} + +// check association values has been saved +// if values kind is Struct, check it has been saved +// if values kind is Slice/Array, check all items have been saved +var visitMapStoreKey = "gorm:saved_association_map" + +func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(*visitMap); ok { + if loadOrStoreVisitMap(v, values) { + return true + } + } + } else { + vistMap := make(visitMap) + loadOrStoreVisitMap(&vistMap, values) + db.Set(visitMapStoreKey, &vistMap) + } + + return false +} diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go new file mode 100644 index 00000000..17953e7b --- /dev/null +++ b/callbacks/callbacks.go @@ -0,0 +1,85 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +var ( + createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + updateClauses = []string{"UPDATE", "SET", 
"WHERE"} + deleteClauses = []string{"DELETE", "FROM", "WHERE"} +) + +type Config struct { + // LastInsertIDReversed 在某些情况下,MySQL 返回的自增 ID 可能会被反转,即高位和低位互换。 + // 例如,当使用某些 MySQL 存储引擎(如 MyISAM)时,可能会发生自增 ID 反转的情况。 + LastInsertIDReversed bool + CreateClauses []string + QueryClauses []string + UpdateClauses []string + DeleteClauses []string +} + +func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { + enableTransaction := func(db *gorm.DB) bool { + return !db.SkipDefaultTransaction + } + + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + + createCallback := db.Callback().Create() + createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + createCallback.Register("gorm:before_create", BeforeCreate) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) + createCallback.Register("gorm:create", Create(config)) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) + createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + createCallback.Clauses = config.CreateClauses + + queryCallback := db.Callback().Query() + queryCallback.Register("gorm:query", Query) + queryCallback.Register("gorm:preload", Preload) + queryCallback.Register("gorm:after_query", AfterQuery) + queryCallback.Clauses = config.QueryClauses + + deleteCallback := db.Callback().Delete() + deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", 
DeleteBeforeAssociations) + deleteCallback.Register("gorm:delete", Delete(config)) + deleteCallback.Register("gorm:after_delete", AfterDelete) + deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + deleteCallback.Clauses = config.DeleteClauses + + updateCallback := db.Callback().Update() + updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) + updateCallback.Register("gorm:update", Update(config)) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + updateCallback.Clauses = config.UpdateClauses + + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses + + rawCallback := db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses +} diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go new file mode 100644 index 00000000..fb900037 --- /dev/null +++ b/callbacks/callmethod.go @@ -0,0 +1,32 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{NewDB: true}) + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() 
{ + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } + db.Statement.CurDestIndex++ + } + case reflect.Struct: + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } + } + } +} diff --git a/callbacks/create.go b/callbacks/create.go new file mode 100644 index 00000000..f0b78139 --- /dev/null +++ b/callbacks/create.go @@ -0,0 +1,345 @@ +package callbacks + +import ( + "fmt" + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// BeforeCreate before create hooks +func BeforeCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(BeforeCreateInterface); ok { + called = true + db.AddError(i.BeforeCreate(tx)) + } + } + return called + }) + } +} + +// Create create hook +func Create(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.CreateClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + if !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + if _, ok := db.Statement.Clauses["RETURNING"]; !ok { + fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + 
db.Statement.AddClause(clause.Returning{Columns: fromColumns}) + } + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } + + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing + } + } + + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + defer func() { + db.AddError(rows.Close()) + }() + gorm.Scan(rows, db, mode) + } + + return + } + + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } + + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + if isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) + insertID -= 
db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + if isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + } + } + } + } +} + +// AfterCreate after create hooks +func AfterCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterCreate { + if i, ok := value.(AfterCreateInterface); ok { + called = true + db.AddError(i.AfterCreate(tx)) + } + } + + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + return called + }) + } +} + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + curTime := stmt.DB.NowFunc() + + switch value := stmt.Dest.(type) { + case map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) + case []map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, 
*value) + default: + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + _, updateTrackTime = stmt.Get("gorm:update_track_time") + isZero bool + ) + stmt.Settings.Delete("gorm:update_track_time") + + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} + + for _, db := range stmt.Schema.DBNames { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { + if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + rValLen := stmt.ReflectValue.Len() + if rValLen == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) + } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + 
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) + } + defaultValueFieldsHavingValue[field][i] = rvOfvalue + } + } + } + } + + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } + } + case reflect.Struct: + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) + } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { + values.Columns = 
append(values.Columns, clause.Column{Name: field.DBName}) + values.Values[0] = append(values.Values[0], rvOfvalue) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) >= 1 { + selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) + + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { + if field.AutoUpdateTime > 0 { + assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} + switch field.AutoUpdateTime { + case schema.UnixNanosecond: + assignment.Value = curTime.UnixNano() + case schema.UnixMillisecond: + assignment.Value = curTime.UnixNano() / 1e6 + case schema.UnixSecond: + assignment.Value = curTime.Unix() + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) + } else { + columns = append(columns, column.Name) + } + } + } + } + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) 
+ if len(onConflict.DoUpdates) == 0 { + onConflict.DoNothing = true + } + + // use primary fields as default OnConflict columns + if len(onConflict.Columns) == 0 { + for _, field := range stmt.Schema.PrimaryFields { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) + } + } + stmt.AddClause(onConflict) + } + } + } + + return values +} diff --git a/callbacks/delete.go b/callbacks/delete.go new file mode 100644 index 00000000..84f446a3 --- /dev/null +++ b/callbacks/delete.go @@ -0,0 +1,185 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func BeforeDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true + } + + return false + }) + } +} + +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } + + for column, v := range selectColumns { + if !v { + continue + } + + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } + + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } + + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + 
"."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + + } +} + +func Delete(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + 
db.Statement.SQL.Grow(100) + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + checkMissingWhereConditions(db) + + if !db.DryRun && db.Error == nil { + ok, mode := hasReturning(db, supportReturning) + if !ok { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
+ if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } + + return + } + + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + db.AddError(rows.Close()) + } + } + } +} + +func AfterDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true + } + return false + }) + } +} diff --git a/callbacks/helper.go b/callbacks/helper.go new file mode 100644 index 00000000..ae9fd8c5 --- /dev/null +++ b/callbacks/helper.go @@ -0,0 +1,152 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + values.Columns = make([]clause.Column, 0, len(mapValue)) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) + + keys := make([]string, 0, len(mapValue)) + for k := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + value := mapValue[k] + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + + values.Values[0] = append(values.Values[0], value) + } + } + return +} + +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + columns := 
make([]string, 0, len(mapValues)) + + // when the length of mapValues is zero,return directly here + // no need to call stmt.SelectAndOmitColumns method + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + var ( + result = make(map[string][]interface{}, len(mapValues)) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) + for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + + for i, v := range result[column] { + if len(values.Values[i]) == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + + values.Values[i][idx] = v + } + } + return +} + +func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { + if supportReturning { + if c, ok := tx.Statement.Clauses["RETURNING"]; ok { + returning, _ := c.Expression.(clause.Returning) + if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { + return true, 0 + } + return true, gorm.ScanUpdate + } + } + return false, 0 +} + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + 
db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(visitMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*visitMap)[p]; ok { + return true + } + (*visitMap)[p] = true + } + } + + return +} diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/preload.go b/callbacks/preload.go new file mode 100644 index 00000000..15669c84 --- /dev/null +++ b/callbacks/preload.go @@ -0,0 +1,266 @@ +package callbacks + +import ( + "fmt" + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// parsePreloadMap extracts nested preloads. e.g. 
+// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) 
+ } + return names +} + +func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { + if relationships == nil { + return nil + } + preloadMap := parsePreloadMap(s, preloads) + for name := range preloadMap { + if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { + if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { + return err + } + } else { + return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + } + } + return nil +} + +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { + var ( + reflectValue = tx.Statement.ReflectValue + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + inlineConds []interface{} + ) + + if rel.JoinTable != nil { + var ( + joinForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinForeignKeys = make([]string, 0, len(rel.References)) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } 
+ + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return nil + } + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } + + // convert join identity map to relation identity map + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + joinIndexValue := joinResults.Index(i) + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) + } + + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + joinKey := utils.ToStringKey(joinFieldValues...) + identityMap[joinKey] = append(identityMap[joinKey], results...) 
+ } + } + + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) + if len(foreignValues) == 0 { + return nil + } + } + + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) + + if len(values) != 0 { + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } + } + + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } + } + + fieldValues := make([]interface{}, len(relForeignFields)) + + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + switch rel.Type { + case schema.HasMany, schema.Many2Many: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) + default: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) + } + case reflect.Slice, reflect.Array: + for i := 0; i < 
reflectValue.Len(); i++ { + switch rel.Type { + case schema.HasMany, schema.Many2Many: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) + default: + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) + } + } + } + + for i := 0; i < reflectResults.Len(); i++ { + elem := reflectResults.Index(i) + for idx, field := range relForeignFields { + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) + } + + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) + } + + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) + } else { + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) + } + } + } + } + + return tx.Error +} diff --git a/callbacks/query.go b/callbacks/query.go new file mode 100644 index 00000000..e89dd199 --- /dev/null +++ b/callbacks/query.go @@ -0,0 +1,316 @@ +package callbacks + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func Query(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun && db.Error 
== nil { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer func() { + db.AddError(rows.Close()) + }() + gorm.Scan(rows, db, 0) + } + } +} + +func BuildQuerySQL(db *gorm.DB) { + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: 
db.Statement.Table, Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + queryFields := db.QueryFields + if !queryFields { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } + } + + if queryFields { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + } + } + + // inline joins + fromClause := clause.From{} + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v + } + + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + + specifiedRelationsName := make(map[string]interface{}) + for _, join := range db.Statement.Joins { + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := 
make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } + } + } + + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } + + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: 
ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.AddClause(fromClause) + db.Statement.Joins = nil + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + 
db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build(db.Statement.BuildClauses...) + } +} + +func Preload(db *gorm.DB) { + if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + + preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) + } + sort.Strings(preloadNames) + + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + preloadDB.Statement.Unscoped = db.Statement.Unscoped + + for _, name := range preloadNames { + if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { + db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) + } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) + } else { + db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) + } + } + } +} + +func AfterQuery(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && 
db.Statement.Schema.AfterFind && db.RowsAffected > 0 { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true + } + return false + }) + } +} diff --git a/callbacks/raw.go b/callbacks/raw.go new file mode 100644 index 00000000..013e638c --- /dev/null +++ b/callbacks/raw.go @@ -0,0 +1,17 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RawExec(db *gorm.DB) { + if db.Error == nil && !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + + db.RowsAffected, _ = result.RowsAffected() + } +} diff --git a/callbacks/row.go b/callbacks/row.go new file mode 100644 index 00000000..beaa189e --- /dev/null +++ b/callbacks/row.go @@ -0,0 +1,23 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RowQuery(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + if db.DryRun || db.Error != nil { + return + } + + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
+ } + + db.RowsAffected = -1 + } +} diff --git a/callbacks/transaction.go b/callbacks/transaction.go new file mode 100644 index 00000000..50887ccc --- /dev/null +++ b/callbacks/transaction.go @@ -0,0 +1,32 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func BeginTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction && db.Error == nil { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + db.InstanceSet("gorm:started_transaction", true) + } else if tx.Error == gorm.ErrInvalidTransaction { + tx.Error = nil + } else { + db.Error = tx.Error + } + } +} + +func CommitOrRollbackTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error != nil { + db.Rollback() + } else { + db.Commit() + } + + db.Statement.ConnPool = db.ConnPool + } + } +} diff --git a/callbacks/update.go b/callbacks/update.go new file mode 100644 index 00000000..4eb75788 --- /dev/null +++ b/callbacks/update.go @@ -0,0 +1,303 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) + } + } + } + } + } +} + +// BeforeUpdate before update hooks +func BeforeUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && 
!db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(BeforeUpdateInterface); ok { + called = true + db.AddError(i.BeforeUpdate(tx)) + } + } + + return called + }) + } +} + +// Update update hook +func Update(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + } + + db.Statement.Build(db.Statement.BuildClauses...) + } + + checkMissingWhereConditions(db) + + if !db.DryRun && db.Error == nil { + if ok, mode := hasReturning(db, supportReturning); ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + dest := db.Statement.Dest + db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() + gorm.Scan(rows, db, mode) + db.Statement.Dest = dest + db.AddError(rows.Close()) + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
+ + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } + } + } + } +} + +// AfterUpdate after update hooks +func AfterUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + return called + }) + } +} + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) + assignValue func(field *schema.Field, value interface{}) + ) + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue, value) + } + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } + + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if size := stmt.ReflectValue.Len(); size > 0 { + var isZero bool + for i := 0; i < size; i++ { + for _, field := range stmt.Schema.PrimaryFields { + _, isZero = field.ValueOf(stmt.Context, 
stmt.ReflectValue.Index(i)) + if !isZero { + break + } + } + } + + if !isZero { + _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + + switch value := updatingValue.Interface().(type) { + case map[string]interface{}: + set = make([]clause.Assignment, 0, len(value)) + + keys := make([]string, 0, len(value)) + for k := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { + assignValue(field, value[k]) + } + continue + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) + } + } + + if !stmt.SkipHooks && stmt.Schema != nil { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { + now := stmt.DB.NowFunc() + assignValue(field, now) + + 
if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + } else if field.AutoUpdateTime == schema.UnixSecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } + } + } + } + } + default: + updatingSchema := stmt.Schema + var isDiffSchema bool + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + isDiffSchema = true + } + } + + switch updatingValue.Kind() { + case reflect.Struct: + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, dbName := range stmt.Schema.DBNames { + if field := updatingSchema.LookUpField(dbName); field != nil { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(stmt.Context, updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.AutoUpdateTime == schema.UnixSecond { + value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() + } + isZero = false + } + + if (ok || !isZero) && field.Updatable { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignField 
:= field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) + } + } + } else { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + return +} diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go new file mode 100644 index 00000000..b1fb86db --- /dev/null +++ b/callbacks/visit_map_test.go @@ -0,0 +1,36 @@ +package callbacks + +import ( + "reflect" + "testing" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index a58913d7..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package gorm_test - -import ( - "errors" - - "github.com/jinzhu/gorm" - - "reflect" - "testing" -) - -func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return -} - -func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = 
errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return -} - -func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return -} - -func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 -} - -func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) -} - -func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 -} - -func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return -} - -func (s *Product) BeforeDelete() (err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return -} - -func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return -} - -func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} -} - -func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) - - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } - - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 
1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } - - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } -} - -func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } - - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } - - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } - - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) - - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } - - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from 
before save callbacks happened when update with invalid value") - } - - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } - - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks happened") - } - - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } - - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } - - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } -} diff --git a/chainable_api.go b/chainable_api.go new file mode 100644 index 00000000..3dc7256e --- /dev/null +++ b/chainable_api.go @@ -0,0 +1,469 @@ +package gorm + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// Model specify the model you would like to run db operations +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Model = value + return +} + +// Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. 
+// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { + tx = db.getInstance() + var whereConds []interface{} + + for _, cond := range conds { + if c, ok := cond.(clause.Interface); ok { + tx.Statement.AddClause(c) + } else if optimizer, ok := cond.(StatementModifier); ok { + optimizer.ModifyStatement(tx.Statement) + } else { + whereConds = append(whereConds, cond) + } + } + + if len(whereConds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) + } + return +} + +var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) + +// Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").Take(&result) +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} + if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { + if results[1] != "" { + tx.Statement.Table = results[1] + } else { + tx.Statement.Table = results[2] + } + } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + } else if name != "" { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" + } + return +} + +// Distinct specify distinct fields that you want querying +// +// 
// Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Distinct = true + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + return +} + +// Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. +// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) 
+ default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } + case string: + if strings.Count(v, "?") >= len(args) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else if strings.Count(v, "@") > 0 && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.NamedExpr{SQL: v, Vars: args}, + }) + } else { + tx.Statement.Selects = []string{v} + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + } else { + tx.Statement.Omits = columns + } + return +} + +// Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. 
+// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } + return +} + +// Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. +// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } + return +} + +// Or add OR conditions +// +// Or is used to chain together queries with an OR. 
+// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) + } + return +} + +// Joins specify Joins conditions +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +// InnerJoins specify inner joins conditions +// db.InnerJoins("Account").Find(&user) +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) 
+} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + if len(args) == 1 { + if db, ok := args[0].(*DB); ok { + j := join{ + Name: query, Conds: args, Selects: db.Statement.Selects, + Omits: db.Statement.Omits, JoinType: joinType, + } + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return + } + } + + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) + return +} + +// Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) +func (db *DB) Group(name string) (tx *DB) { + tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + tx.Statement.AddClause(clause.GroupBy{ + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, + }) + return +} + +// Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Having: tx.Statement.BuildCondition(query, args...), + }) + return +} + +// Order specify order when retrieving records from database +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := value.(type) { + case clause.OrderByColumn: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, + }) + case string: + if v != "" { + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: 
clause.Column{Name: v, Raw: true}, + }}, + }) + } + } + return +} + +// Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) +func (db *DB) Limit(limit int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Limit: &limit}) + return +} + +// Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) +func (db *DB) Offset(offset int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Offset: offset}) + return +} + +// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically +// +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + tx = db.getInstance() + tx.Statement.scopes = append(tx.Statement.scopes, funcs...) 
+ return tx +} + +func (db *DB) executeScopes() (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + return tx +} + +// Preload preload associations with given conditions +// +// // get all users, and preload all non-cancelled orders +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args + return +} + +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. 
+// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.attrs = attrs + return +} + +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.assigns = attrs + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + tx.Statement.Unscoped = true + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else 
{ + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + return +} diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go new file mode 100644 index 00000000..34d5df41 --- /dev/null +++ b/clause/benchmarks_test.go @@ -0,0 +1,58 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func BenchmarkSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + for i := 0; i < b.N; i++ { + stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}} + + for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE") + _ = stmt.SQL.String() + } +} + +func BenchmarkComplexSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + limit10 := 10 + for i := 0; i < b.N; i++ { + stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{ + clause.Select{}, + clause.From{}, + clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.Gt{Column: "age", Value: 18}, + clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), + }}, + clause.Where{Exprs: []clause.Expression{ + clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), + }}, + clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, + clause.Limit{Limit: &limit10, Offset: 20}, + clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, + } + + 
for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY") + _ = stmt.SQL.String() + } +} diff --git a/clause/clause.go b/clause/clause.go new file mode 100644 index 00000000..1354fc05 --- /dev/null +++ b/clause/clause.go @@ -0,0 +1,89 @@ +package clause + +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to customize how to build clause +type ClauseBuilder func(Clause, Builder) + +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + +// Builder builder interface +type Builder interface { + Writer + WriteQuoted(field interface{}) + AddVar(Writer, ...interface{}) + AddError(error) error +} + +// Clause +type Clause struct { + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder +} + +// Build build clause +func (c Clause) Build(builder Builder) { + if c.Builder != nil { + c.Builder(c, builder) + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + + if c.Name != "" { + builder.WriteString(c.Name) + builder.WriteByte(' ') + } + + if c.AfterNameExpression != nil { + c.AfterNameExpression.Build(builder) + builder.WriteByte(' ') + } + + c.Expression.Build(builder) + + if c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) + } + } +} + +const ( + PrimaryKey string = "~~~py~~~" // primary key + CurrentTable string = "~~~ct~~~" // current table + Associations string = "~~~as~~~" // associations +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table 
quote with name +type Table struct { + Name string + Alias string + Raw bool +} diff --git a/clause/clause_test.go b/clause/clause_test.go new file mode 100644 index 00000000..6239ff39 --- /dev/null +++ b/clause/clause_test.go @@ -0,0 +1,43 @@ +package clause_test + +import ( + "reflect" + "strings" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +var db, _ = gorm.Open(tests.DummyDialector{}, nil) + +func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { + var ( + buildNames []string + buildNamesMap = map[string]bool{} + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + ) + + for _, c := range clauses { + if _, ok := buildNamesMap[c.Name()]; !ok { + buildNames = append(buildNames, c.Name()) + buildNamesMap[c.Name()] = true + } + + stmt.AddClause(c) + } + + stmt.Build(buildNames...) 
+ + if strings.TrimSpace(stmt.SQL.String()) != result { + t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(stmt.Vars, vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) + } +} diff --git a/clause/delete.go b/clause/delete.go new file mode 100644 index 00000000..fc462cd7 --- /dev/null +++ b/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.WriteString("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.WriteString(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/clause/delete_test.go b/clause/delete_test.go new file mode 100644 index 00000000..a9a659b3 --- /dev/null +++ b/clause/delete_test.go @@ -0,0 +1,31 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestDelete(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Delete{}, clause.From{}}, + "DELETE FROM `users`", nil, + }, + { + []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, + "DELETE LOW_PRIORITY FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/expression.go b/clause/expression.go new file mode 100644 index 00000000..92ac7f22 --- /dev/null +++ b/clause/expression.go @@ -0,0 +1,381 @@ +package clause + +import ( + "database/sql" + "database/sql/driver" + "go/ast" + "reflect" +) + +// Expression expression interface +type Expression interface { + Build(builder Builder) +} + +// NegationExpressionBuilder negation expression builder +type NegationExpressionBuilder interface { + 
NegationBuild(builder Builder) +} + +// Expr raw expression +type Expr struct { + SQL string + Vars []interface{} + WithoutParentheses bool +} + +// Build build raw expression +func (expr Expr) Build(builder Builder) { + var ( + afterParenthesis bool + idx int + ) + + for _, v := range []byte(expr.SQL) { + if v == '?' && len(expr.Vars) > idx { + if afterParenthesis || expr.WithoutParentheses { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } + + if idx < len(expr.Vars) { + for _, v := range expr.Vars[idx:] { + builder.AddVar(builder, sql.NamedArg{Value: v}) + } + } +} + +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + afterParenthesis bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + default: + var appendFieldsToMap func(reflect.Value) + appendFieldsToMap = func(reflectValue reflect.Value) { + reflectValue = reflect.Indirect(reflectValue) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct 
:= modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + + if fieldStruct.Anonymous { + appendFieldsToMap(reflectValue.Field(i)) + } + } + } + } + } + + appendFieldsToMap(reflect.ValueOf(value)) + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + afterParenthesis = false + builder.WriteByte(v) + } else if v == '?' && len(expr.Vars) > idx { + if afterParenthesis { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else if inName { + name = append(name, v) + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } + + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + } +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.WriteString(" IN (NULL)") + case 1: + if _, ok := 
in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +func (in IN) NegationBuild(builder Builder) { + builder.WriteQuoted(in.Column) + switch len(in.Values) { + case 0: + builder.WriteString(" IS NOT NULL") + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + switch eq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" IN (") + rv := reflect.ValueOf(eq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq(eq).Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + switch neq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" NOT IN (") + rv := reflect.ValueOf(neq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + 
builder.AddVar(builder, neq.Value) + } + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq(neq).Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte(gt).Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt(gte).Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte(lt).Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) +} + +func (lte Lte) NegationBuild(builder Builder) { + Gt(lte).Build(builder) +} + +// Like whether string matches regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) +} + +func eqNil(value interface{}) bool { + if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { + value, _ = valuer.Value() + } + + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/clause/expression_test.go 
b/clause/expression_test.go new file mode 100644 index 00000000..aaede61c --- /dev/null +++ b/clause/expression_test.go @@ -0,0 +1,232 @@ +package clause_test + +import ( + "database/sql" + "fmt" + "reflect" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } +} + +func TestNamedExpr(t *testing.T) { + type Base struct { + Name2 string + } + + type NamedArgument struct { + Name1 string + Base + } + + results := []struct { + SQL string + Result string + Vars []interface{} + ExpectedVars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }, { + SQL: "name1 = @name AND name2 = @name", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? 
AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name AND name2 = @@name", + Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}}, + Result: "name1 = ? AND name2 = @@name", + ExpectedVars: []interface{}{"jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", + }, { + SQL: "name1 = @name AND name2 = @name;", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? 
AND name2 = ?;", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r\n AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r\n AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, + Result: "`table`.`col`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}}, + Result: "table.col", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}}, + Result: "table.id", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}}, + Result: "`table`.`col` AS `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}}, + Result: "table.col AS alias", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}}, + Result: "`table` `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}}, + Result: "table alias", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is 
not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } +} + +func TestExpression(t *testing.T) { + column := "column-name" + results := []struct { + Expressions []clause.Expression + ExpectedVars []interface{} + Result string + }{{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: "column-value"}, + }, + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` = ?", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: nil}, + clause.Eq{Column: column, Value: (*string)(nil)}, + clause.Eq{Column: column, Value: (*int)(nil)}, + clause.Eq{Column: column, Value: (*bool)(nil)}, + clause.Eq{Column: column, Value: (interface{})(nil)}, + clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}}, + }, + Result: "`column-name` IS NULL", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: "column-value"}, + }, + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` <> ?", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: nil}, + clause.Neq{Column: column, Value: (*string)(nil)}, + clause.Neq{Column: column, Value: (*int)(nil)}, + clause.Neq{Column: column, Value: (*bool)(nil)}, + clause.Neq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NOT NULL", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`id`) = ?", + }, { + Expressions: []clause.Expression{ + 
clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`users`.`id`) >= ?", + }} + + for idx, result := range results { + for idy, expression := range result.Expressions { + t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + expression.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } + } +} diff --git a/clause/from.go b/clause/from.go new file mode 100644 index 00000000..1ea2d595 --- /dev/null +++ b/clause/from.go @@ -0,0 +1,37 @@ +package clause + +// From from clause +type From struct { + Tables []Table + Joins []Join +} + +// Name from clause name +func (from From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) + } + + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) + } +} + +// MergeClause merge from clause +func (from From) MergeClause(clause *Clause) { + clause.Expression = from +} diff --git a/clause/from_test.go b/clause/from_test.go new file mode 100644 index 00000000..75422f8e --- /dev/null +++ b/clause/from_test.go @@ -0,0 +1,75 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestFrom(t *testing.T) { + results := []struct { + Clauses 
[]clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{clause.Select{}, clause.From{}},
			"SELECT * FROM `users`", nil,
		},
		{
			[]clause.Interface{
				clause.Select{}, clause.From{
					Tables: []clause.Table{{Name: "users"}},
					Joins: []clause.Join{
						{
							Type:  clause.InnerJoin,
							Table: clause.Table{Name: "articles"},
							ON: clause.Where{
								[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
							},
						},
					},
				},
			},
			"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil,
		},
		{
			[]clause.Interface{
				clause.Select{}, clause.From{
					Tables: []clause.Table{{Name: "users"}},
					Joins: []clause.Join{
						{
							Type:  clause.RightJoin,
							Table: clause.Table{Name: "profiles"},
							ON: clause.Where{
								[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
							},
						},
					},
				}, clause.From{
					Joins: []clause.Join{
						{
							Type:  clause.InnerJoin,
							Table: clause.Table{Name: "articles"},
							ON: clause.Where{
								[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
							},
						}, {
							Type:  clause.LeftJoin,
							Table: clause.Table{Name: "companies"},
							Using: []string{"company_name"},
						},
					},
				},
			},
			"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil,
		},
	}

	for idx, result := range results {
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
		})
	}
}

// GroupBy represents the GROUP BY clause, optionally with HAVING
// conditions.
type GroupBy struct {
	Columns []Column
	Having  []Expression
}

// Name returns the clause identifier used by the statement builder.
func (groupBy GroupBy) Name() string {
	return "GROUP BY"
}

// Build writes the grouping columns and, when present, the HAVING
// conditions.
func (groupBy GroupBy) Build(builder Builder) {
	for i, column := range groupBy.Columns {
		if i > 0 {
			builder.WriteByte(',')
		}
		builder.WriteQuoted(column)
	}

	if len(groupBy.Having) > 0 {
		builder.WriteString(" HAVING ")
		Where{Exprs: groupBy.Having}.Build(builder)
	}
}

// MergeClause appends this clause's columns and HAVING conditions to any
// previously registered GROUP BY clause. The previous slices are copied
// first so the merged clause does not alias their backing arrays.
func (groupBy GroupBy) MergeClause(clause *Clause) {
	if prev, ok := clause.Expression.(GroupBy); ok {
		cols := make([]Column, len(prev.Columns))
		copy(cols, prev.Columns)
		groupBy.Columns = append(cols, groupBy.Columns...)

		having := make([]Expression, len(prev.Having))
		copy(having, prev.Having)
		groupBy.Having = append(having, groupBy.Having...)
	}
	clause.Expression = groupBy

	if len(groupBy.Columns) == 0 {
		clause.Name = ""
	} else {
		clause.Name = groupBy.Name()
	}
}

// TestGroupBy verifies GROUP BY / HAVING rendering and clause merging.
func TestGroupBy(t *testing.T) {
	results := []struct {
		Clauses []clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
				Columns: []clause.Column{{Name: "role"}},
				Having:  []clause.Expression{clause.Eq{"role", "admin"}},
			}},
			"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?",
			[]interface{}{"admin"},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
				Columns: []clause.Column{{Name: "role"}},
				Having:  []clause.Expression{clause.Eq{"role", "admin"}},
			}, clause.GroupBy{
				Columns: []clause.Column{{Name: "gender"}},
				Having:  []clause.Expression{clause.Neq{"gender", "U"}},
			}},
			"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ?
AND `gender` <> ?", + []interface{}{"admin", "U"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/insert.go b/clause/insert.go new file mode 100644 index 00000000..8efaa035 --- /dev/null +++ b/clause/insert.go @@ -0,0 +1,39 @@ +package clause + +type Insert struct { + Table Table + Modifier string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Modifier != "" { + builder.WriteString(insert.Modifier) + builder.WriteByte(' ') + } + + builder.WriteString("INTO ") + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } +} + +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier + } + if insert.Table.Name == "" { + insert.Table = v.Table + } + } + clause.Expression = insert +} diff --git a/clause/insert_test.go b/clause/insert_test.go new file mode 100644 index 00000000..70810bce --- /dev/null +++ b/clause/insert_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestInsert(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Insert{}}, + "INSERT INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t 
*testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/joins.go b/clause/joins.go new file mode 100644 index 00000000..879892be --- /dev/null +++ b/clause/joins.go @@ -0,0 +1,47 @@ +package clause + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin JoinType = "RIGHT" +) + +// Join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string + Expression Expression +} + +func (join Join) Build(builder Builder) { + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } + } +} diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: 
clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +} diff --git a/clause/limit.go b/clause/limit.go new file mode 100644 index 00000000..abda0055 --- 
/dev/null +++ b/clause/limit.go @@ -0,0 +1,48 @@ +package clause + +import "strconv" + +// Limit limit clause +type Limit struct { + Limit *int + Offset int +} + +// Name where clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build where clause +func (limit Limit) Build(builder Builder) { + if limit.Limit != nil && *limit.Limit >= 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(*limit.Limit)) + } + if limit.Offset > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { + builder.WriteByte(' ') + } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } +} + +// MergeClause merge order by clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && v.Offset > 0 { + limit.Offset = v.Offset + } else if limit.Offset < 0 { + limit.Offset = 0 + } + } + + clause.Expression = limit +} diff --git a/clause/limit_test.go b/clause/limit_test.go new file mode 100644 index 00000000..a9fd4e24 --- /dev/null +++ b/clause/limit_test.go @@ -0,0 +1,70 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ + Limit: &limit10, + Offset: 20, + }}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, + { + []clause.Interface{clause.Select{}, 
clause.From{}, clause.Limit{Offset: 20}}, + "SELECT * FROM `users` OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + "SELECT * FROM `users` LIMIT 10", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, + "SELECT * FROM `users` OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, + "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/locking.go b/clause/locking.go new file mode 100644 index 00000000..290aac92 --- /dev/null +++ b/clause/locking.go @@ -0,0 +1,31 @@ +package clause + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name where clause name +func (locking Locking) Name() string { + return "FOR" +} + +// Build build where clause +func (locking Locking) Build(builder Builder) { + builder.WriteString(locking.Strength) + if locking.Table.Name != "" { + builder.WriteString(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + 
builder.WriteByte(' ') + builder.WriteString(locking.Options) + } +} + +// MergeClause merge order by clauses +func (locking Locking) MergeClause(clause *Clause) { + clause.Expression = locking +} diff --git a/clause/locking_test.go b/clause/locking_test.go new file mode 100644 index 00000000..0e607312 --- /dev/null +++ b/clause/locking_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestLocking(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, + "SELECT * FROM `users` FOR UPDATE", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + "SELECT * FROM `users` FOR SHARE OF `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go new file mode 100644 index 00000000..032bf4a1 --- /dev/null +++ b/clause/on_conflict.go @@ -0,0 +1,59 @@ +package clause + +type OnConflict struct { + Columns []Column + Where Where + TargetWhere Where + OnConstraint string + DoNothing bool + DoUpdates Set + UpdateAll bool +} + +func (OnConflict) Name() string { + return "ON CONFLICT" +} + +// Build build onConflict clause +func (onConflict OnConflict) Build(builder Builder) { + if onConflict.OnConstraint != "" { + builder.WriteString("ON CONSTRAINT ") + builder.WriteString(onConflict.OnConstraint) + builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + 
builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } + } + + if onConflict.DoNothing { + builder.WriteString("DO NOTHING") + } else { + builder.WriteString("DO UPDATE SET ") + onConflict.DoUpdates.Build(builder) + } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } +} + +// MergeClause merge onConflict clauses +func (onConflict OnConflict) MergeClause(clause *Clause) { + clause.Expression = onConflict +} diff --git a/clause/order_by.go b/clause/order_by.go new file mode 100644 index 00000000..41218025 --- /dev/null +++ b/clause/order_by.go @@ -0,0 +1,54 @@ +package clause + +type OrderByColumn struct { + Column Column + Desc bool + Reorder bool +} + +type OrderBy struct { + Columns []OrderByColumn + Expression Expression +} + +// Name where clause name +func (orderBy OrderBy) Name() string { + return "ORDER BY" +} + +// Build build where clause +func (orderBy OrderBy) Build(builder Builder) { + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + } + } + } +} + +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, 
orderBy.Columns...) + } + + clause.Expression = orderBy +} diff --git a/clause/order_by_test.go b/clause/order_by_test.go new file mode 100644 index 00000000..d8b5dfbf --- /dev/null +++ b/clause/order_by_test.go @@ -0,0 +1,58 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestOrderBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }}, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, + }, + }, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, + }, + }, + "SELECT * FROM `users` ORDER BY `name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }, + }, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", + []interface{}{1, 2, 3}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/returning.go b/clause/returning.go new file mode 100644 index 00000000..d94b7a4c --- /dev/null +++ b/clause/returning.go @@ -0,0 +1,34 @@ +package clause + +type Returning struct { + Columns 
[]Column +} + +// Name where clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build where clause +func (returning Returning) Build(builder Builder) { + if len(returning.Columns) > 0 { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +// MergeClause merge order by clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, returning.Columns...) + } + + clause.Expression = returning +} diff --git a/clause/returning_test.go b/clause/returning_test.go new file mode 100644 index 00000000..bd0ecce8 --- /dev/null +++ b/clause/returning_test.go @@ -0,0 +1,36 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestReturning(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`", nil, + }, { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }, clause.Returning{ + []clause.Column{{Name: "name"}, {Name: "age"}}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/select.go b/clause/select.go new file mode 100644 index 00000000..d8e9f801 --- /dev/null +++ b/clause/select.go @@ -0,0 +1,59 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + Distinct bool + Columns []Column + Expression Expression +} + +func (s Select) 
Name() string { + return "SELECT" +} + +func (s Select) Build(builder Builder) { + if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString("DISTINCT ") + } + + for idx, column := range s.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeClause(clause *Clause) { + if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + + clause.Expression = s.Expression + } else { + clause.Expression = s + } +} + +// CommaExpression represents a group of expressions separated by commas. +type CommaExpression struct { + Exprs []Expression +} + +func (comma CommaExpression) Build(builder Builder) { + for idx, expr := range comma.Exprs { + if idx > 0 { + _, _ = builder.WriteString(", ") + } + expr.Build(builder) + } +} diff --git a/clause/select_test.go b/clause/select_test.go new file mode 100644 index 00000000..9c11b90d --- /dev/null +++ b/clause/select_test.go @@ -0,0 +1,72 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestSelect(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.From{}}, + "SELECT `users`.`id` FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.Select{ + Columns: []clause.Column{{Name: "name"}}, + }, clause.From{}}, + "SELECT `name` FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}}, + clause.NamedExpr{"?", 
[]interface{}{clause.Column{Name: "name"}}}, + clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}}, + }, + }, + }, clause.From{}}, + "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "? as name", + Vars: []interface{}{ + clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, + }, + }, + }, + }, + }, clause.From{}}, + "SELECT `age` = ? as name FROM `users`", + []interface{}{18}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/set.go b/clause/set.go new file mode 100644 index 00000000..75eb6bdd --- /dev/null +++ b/clause/set.go @@ -0,0 +1,60 @@ +package clause + +import "sort" + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.AddVar(builder, assignment.Value) + } + } else { + builder.WriteQuoted(Column{Name: PrimaryKey}) + builder.WriteByte('=') + builder.WriteQuoted(Column{Name: PrimaryKey}) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) +} + +func Assignments(values map[string]interface{}) Set { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, 
Value: values[key]} + } + return assignments +} + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/clause/set_test.go b/clause/set_test.go new file mode 100644 index 00000000..7a9ee895 --- /dev/null +++ b/clause/set_test.go @@ -0,0 +1,59 @@ +package clause_test + +import ( + "fmt" + "sort" + "strings" + "testing" + + "gorm.io/gorm/clause" +) + +func TestSet(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + }, + "UPDATE `users` SET `users`.`id`=?", + []interface{}{1}, + }, + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), + }, + "UPDATE `users` SET `name`=?", + []interface{}{"jinzhu"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} + +func TestAssignments(t *testing.T) { + set := clause.Assignments(map[string]interface{}{ + "name": "jinzhu", + "age": 18, + }) + + assignments := []clause.Assignment(set) + + sort.Slice(assignments, func(i, j int) bool { + return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0 + }) + + if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 { + t.Errorf("invalid assignments, got %v", assignments) + } +} diff --git a/clause/update.go b/clause/update.go new file mode 100644 index 00000000..f9d68ac6 --- /dev/null +++ 
b/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.WriteString(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} diff --git a/clause/update_test.go b/clause/update_test.go new file mode 100644 index 00000000..c704bf5e --- /dev/null +++ b/clause/update_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestUpdate(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Update{}}, + "UPDATE `users`", nil, + }, + { + []clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `users`", nil, + }, + { + []clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/values.go b/clause/values.go new file mode 100644 index 00000000..b2f5421b --- /dev/null +++ b/clause/values.go @@ -0,0 +1,45 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name from clause name +func (Values) Name() string { + return 
"VALUES" +} + +// Build build from clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteByte('(') + builder.AddVar(builder, value...) + builder.WriteByte(')') + } + } else { + builder.WriteString("DEFAULT VALUES") + } +} + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = values +} diff --git a/clause/values_test.go b/clause/values_test.go new file mode 100644 index 00000000..1eea8652 --- /dev/null +++ b/clause/values_test.go @@ -0,0 +1,34 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestValues(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Insert{}, + clause.Values{ + Columns: []clause.Column{{Name: "name"}, {Name: "age"}}, + Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, + }, + }, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", + []interface{}{"jinzhu", 18, "josh", 1}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/where.go b/clause/where.go new file mode 100644 index 00000000..a29401cf --- /dev/null +++ b/clause/where.go @@ -0,0 +1,190 @@ +package clause + +import ( + "strings" +) + +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + +// Where where clause +type Where struct { + Exprs []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause 
+func (where Where) Build(builder Builder) { + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] + } + break + } + } + + buildExprs(where.Exprs, builder, AndWithSpace) +} + +func buildExprs(exprs []Expression, builder Builder, joinCond string) { + wrapInParentheses := false + + for idx, expr := range exprs { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(OrWithSpace) + } else { + builder.WriteString(joinCond) + } + } + + if len(exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + case Expr: + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + case NamedExpr: + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) + } + } + + if wrapInParentheses { + builder.WriteByte('(') + expr.Build(builder) + builder.WriteByte(')') + wrapInParentheses = false + } else { + expr.Build(builder) + } + } +} + +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + exprs := make([]Expression, len(w.Exprs)+len(where.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], where.Exprs) + where.Exprs = exprs + } + + clause.Expression = where +} + +func And(exprs ...Expression) 
Expression { + if len(exprs) == 0 { + return nil + } + + if len(exprs) == 1 { + if _, ok := exprs[0].(OrConditions); !ok { + return exprs[0] + } + } + + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(and.Exprs, builder, AndWithSpace) + builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, AndWithSpace) + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(or.Exprs, builder, OrWithSpace) + builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, OrWithSpace) + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } +} diff --git a/clause/where_test.go b/clause/where_test.go new file mode 100644 index 00000000..35e3dbee --- /dev/null +++ b/clause/where_test.go @@ -0,0 +1,115 @@ +package clause_test 
+ +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhere(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", + []interface{}{"1", 18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? 
OR `name` <> ?", + []interface{}{"1", "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, + }}, + "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + []interface{}{18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) 
AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? 
AND NOT `score` <= ?)", + []interface{}{"1", 100}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/with.go b/clause/with.go new file mode 100644 index 00000000..0768488e --- /dev/null +++ b/clause/with.go @@ -0,0 +1,3 @@ +package clause + +type With struct{} diff --git a/create_test.go b/create_test.go deleted file mode 100644 index 2d71c9a6..00000000 --- a/create_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "testing" - "time" -) - -func TestCreate(t *testing.T) { - float := 35.03554004971999 - user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Save(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - if DB.NewRecord(user) || DB.NewRecord(&user) { - t.Error("User should not new record after save") - } - - var newUser User - DB.First(&newUser, user.Id) - - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { - t.Errorf("User's PasswordHash should be saved ([]byte)") - } - - if newUser.Age != 18 { - t.Errorf("User's Age should be saved (int)") - } - - if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") - } - - if newUser.Latitude != float { - t.Errorf("Float64 should not be changed after save") - } - - if user.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - if newUser.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - DB.Model(user).Update("name", "create_user_new_name") - DB.First(&user, user.Id) - if user.CreatedAt.Format(time.RFC3339Nano) != 
newUser.CreatedAt.Format(time.RFC3339Nano) { - t.Errorf("CreatedAt should not be changed after update") - } -} - -func TestCreateWithAutoIncrement(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") - } - user1 := User{} - user2 := User{} - - DB.Create(&user1) - DB.Create(&user2) - - if user2.Sequence-user1.Sequence != 1 { - t.Errorf("Auto increment should apply on Sequence") - } -} - -func TestCreateWithNoGORMPrimayKey(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { - t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") - } - - jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error - if err != nil { - t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) - } -} - -func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create a record without std primary key") - } - - if animal.Counter == 0 { - t.Errorf("No std primary key should be filled value after create") - } - - if animal.Name != "Ferdinand" { - t.Errorf("Default value should be overrided") - } - - // Test create with default value not overrided - an := Animal{From: "nerdz"} - - if DB.Save(&an).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") - } - - // We must fetch the value again, to have the default fields updated - // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. 
a time.Time fields has a default value of "now()" - DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) - - if an.Name != "galeone" { - t.Errorf("Default value should fill the field. But got %v", an.Name) - } -} - -func TestAnonymousScanner(t *testing.T) { - user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_scanner") - if user2.Role.Name != "admin" { - t.Errorf("Should be able to get anonymous scanner") - } - - if !user2.IsAdmin() { - t.Errorf("Should be able to get anonymous scanner") - } -} - -func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Company: Company{Name: "company"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_field") - DB.Model(&user2).Related(&user2.Company) - if user2.Company.Name != "company" { - t.Errorf("Should be able to get anonymous field") - } -} - -func TestSelectWithCreate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name != user.Name || queryuser.Age == user.Age { - t.Errorf("Should only create users with name column") - } - - if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || - queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { - t.Errorf("Should only create selected relationships") - } -} - -func TestOmitWithCreate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name == user.Name || queryuser.Age != user.Age { - t.Errorf("Should only create users with age column") - } - - if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || - queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { - t.Errorf("Should not create omited relationships") - } -} diff --git a/customize_column_test.go b/customize_column_test.go deleted file mode 100644 index 177b4a5d..00000000 --- a/customize_column_test.go +++ /dev/null @@ -1,280 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date time.Time `gorm:"column:mapped_time"` -} - -// Make sure an ignored field does not interfere with another field's custom -// column name that matches the ignored field. -type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` -} - -func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } - - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } - - expected := "foo" - cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, 666) - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if 
cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } -} - -type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` -} - -type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string -} - -func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } - - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } - - if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } -} - -type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` -} - -type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` -} - -func TestOneToOneWithCustomizedColumn(t *testing.T) { - 
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } - - DB.Create(&user) - DB.Create(&invitation) - - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } -} - -type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` -} - -type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` -} - -type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount -} - -type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount -} - -func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } - - var coupon PromotionCoupon - if err := 
DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } -} - -func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } - - var rule PromotionRule - if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } -} - -func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - 
if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } - - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } -} diff --git a/delete_test.go b/delete_test.go deleted file mode 100644 index d3de0a6d..00000000 --- a/delete_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" -) - -func TestDelete(t *testing.T) { - user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if err := DB.Delete(&user1).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } - - if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("Other users that not deleted should be found-able") - } -} - -func TestInlineDelete(t *testing.T) { - user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if DB.Delete(&User{}, user1.Id).Error != nil { - t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } -} - -func TestSoftDelete(t *testing.T) { - type User struct { - Id int64 - Name string - DeletedAt *time.Time - } - DB.AutoMigrate(&User{}) - - user := User{Name: "soft_delete"} - DB.Save(&user) - DB.Delete(&user) - - if 
DB.First(&User{}, "name = ?", user.Name).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index facde0d0..00000000 --- a/dialect.go +++ /dev/null @@ -1,106 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db *sql.DB) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // 
LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db *sql.DB) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// ParseFieldStructForDialect parse field struct for dialect -func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var reflectType = field.Struct.Type - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - - // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + 
field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - - return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) -} diff --git a/dialect_common.go b/dialect_common.go deleted file mode 100644 index 5b5682c5..00000000 --- a/dialect_common.go +++ /dev/null @@ -1,152 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db *sql.DB - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db *sql.DB) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := 
dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? 
AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") - return keyName -} diff --git a/dialect_mysql.go b/dialect_mysql.go deleted file mode 100644 index 11b894b3..00000000 --- a/dialect_mysql.go +++ /dev/null @@ -1,146 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "fmt" - "reflect" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. 
- if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" - } else { - sqlType = "timestamp NULL" - } - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", 
dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} diff --git a/dialect_postgres.go b/dialect_postgres.go deleted file mode 100644 index 5a6114c0..00000000 --- a/dialect_postgres.go +++ /dev/null @@ -1,134 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (postgres) DataTypeOf(field *StructField) string { 
- var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if isByteArrayOrSlice(dataValue) { - sqlType = "bytea" - } else if isUUID(dataValue) { - sqlType = "uuid" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName 
string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go deleted file mode 100644 index 2abcefa5..00000000 --- a/dialect_sqlite3.go +++ /dev/null @@ -1,108 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite", &sqlite3{}) - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func 
(sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? 
AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index a7bca6b8..00000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,151 +0,0 @@ -package mssql - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" - "time" - - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db *sql.DB - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db *sql.DB) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$" // ? 
-} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime2" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? 
AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s mssql) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 { - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 9deba48a..00000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,3 +0,0 @@ -package mysql - -import _ "github.com/go-sql-driver/mysql" diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 
index adeeec7b..00000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,54 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 069ad3a9..00000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,3 +0,0 @@ -package sqlite - -import _ "github.com/mattn/go-sqlite3" diff --git a/embedded_struct_test.go b/embedded_struct_test.go deleted file mode 100644 index 7be75d99..00000000 --- a/embedded_struct_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package gorm_test - -import "testing" - -type BasePost struct { - Id int64 - Title string - URL string -} - -type HNPost struct { - BasePost - Upvotes int32 -} - -type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - ImageUrl string -} - -func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", 
err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } - - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } -} diff --git a/errors.go b/errors.go index ce3a25c0..57e3fc5e 100644 --- a/errors.go +++ b/errors.go @@ -2,57 +2,49 @@ package gorm import ( "errors" - "strings" + + "gorm.io/gorm/logger" ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") + // ErrRecordNotFound record not found error + ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") + ErrInvalidTransaction = errors.New("invalid transaction") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + 
ErrMissingWhereClause = errors.New("WHERE conditions required") + // ErrUnsupportedRelation unsupported relations + ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + ErrModelValueRequired = errors.New("model value required") + // ErrModelAccessibleFieldsRequired model accessible fields required + ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") + // ErrSubQueryRequired sub query required + ErrSubQueryRequired = errors.New("sub query required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") + // ErrRegistered registered + ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") + // ErrInvalidDB invalid db + ErrInvalidDB = errors.New("invalid db") + // ErrInvalidValue invalid value + ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") + // ErrInvalidValueOfLength invalid values do not match length + ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") ) - -type errorsInterface interface { - GetErrors() []error -} - -// Errors contains all happened errors -type Errors struct { - errors []error -} - -// 
GetErrors get all happened errors -func (errs Errors) GetErrors() []error { - return errs.errors -} - -// Add add an error -func (errs *Errors) Add(err error) { - if errors, ok := err.(errorsInterface); ok { - for _, err := range errors.GetErrors() { - errs.Add(err) - } - } else { - for _, e := range errs.errors { - if err == e { - return - } - } - errs.errors = append(errs.errors, err) - } -} - -// Error format happened errors -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs.errors { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/field.go b/field.go deleted file mode 100644 index 11c410b0..00000000 --- a/field.go +++ /dev/null @@ -1,58 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } 
- } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/field_test.go b/field_test.go deleted file mode 100644 index 30e9a778..00000000 --- a/field_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -type CalculateField struct { - gorm.Model - Name string - Children []CalculateFieldChild - Category CalculateFieldCategory - EmbeddedField -} - -type EmbeddedField struct { - EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` -} - -type CalculateFieldChild struct { - gorm.Model - CalculateFieldID uint - Name string -} - -type CalculateFieldCategory struct { - gorm.Model - CalculateFieldID uint - Name string -} - -func TestCalculateField(t *testing.T) { - var field CalculateField - var scope = DB.NewScope(&field) - if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("embedded_name"); !ok { - t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { - t.Errorf("should find embedded field's tag settings") - } -} diff --git a/finisher_api.go b/finisher_api.go new file mode 100644 index 00000000..0e26f181 --- /dev/null +++ b/finisher_api.go @@ -0,0 +1,751 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "sync/atomic" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Create inserts value, returning the inserted data's primary key in value's id +func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + + tx = 
db.getInstance() + tx.Statement.Dest = value + return tx.callbacks.Create().Execute(tx) +} + +// CreateInBatches inserts value in batches of batchSize +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var rowsAffected int64 + tx = db.getInstance() + + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + + callFc := func(tx *DB) error { + for i := 0; i < reflectLen; i += batchSize { + ends := i + batchSize + if ends > reflectLen { + ends = reflectLen + } + + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + } + + if tx.SkipDefaultTransaction || reflectLen <= batchSize { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + + tx.RowsAffected = rowsAffected + default: + tx = db.getInstance() + tx.Statement.Dest = value + tx = tx.callbacks.Create().Execute(tx) + } + return +} + +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. 
+func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } + tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { + return tx.callbacks.Create().Execute(tx) + } + } + } + + fallthrough + default: + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) + + if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { + return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) + } + + return updateTx + } + + return +} + +// First finds the first record ordered by primary key, matching given conditions conds +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Take finds the first record 
returned by the database in no specified order, matching given conditions conds +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Last finds the last record ordered by primary key, matching given conditions conds +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// Find finds all records matching given conditions conds +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +// FindInBatches finds all records in batches of batchSize +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{}) + queryDB = tx + rowsAffected int64 + batch int + ) + + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + if limit.Limit != nil { + totalSize = 
*limit.Limit + } + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + + for { + result := queryDB.Limit(batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + fcTx := result.Session(&Session{NewDB: true}) + fcTx.RowsAffected = result.RowsAffected + tx.AddError(fc(fcTx, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } + + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + + // Optimize for-break + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } + + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } + + tx.RowsAffected = rowsAffected + return tx +} + +func (db *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := db.Statement.Schema.LookUpField(column); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := db.Statement.Schema.LookUpField(column.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) + } + } + } else if andCond, ok := 
expr.(clause.AndConditions); ok { + db.assignInterfacesToValue(andCond.Exprs) + } + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) + } + default: + if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { + if field := db.Statement.Schema.LookUpField(f.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) + } + } + } + } + } + } else if len(values) > 0 { + if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) + } + return + } + } + } +} + +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. 
+// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + return +} + +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. 
+// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + result := queryTx.Find(dest, conds...) + if result.Error != nil { + tx.Error = result.Error + return tx + } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + + return tx +} + +// Update updates column with value using callbacks. 
Reference: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + return tx.callbacks.Update().Execute(tx) +} + +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + return tx.callbacks.Update().Execute(tx) +} + +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.SkipHooks = true + return tx.callbacks.Update().Execute(tx) +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.Statement.SkipHooks = true + return tx.callbacks.Update().Execute(tx) +} + +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. 
+func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = value + return tx.callbacks.Delete().Execute(tx) +} + +func (db *DB) Count(count *int64) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() + } + + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + tx.Statement.Clauses["SELECT"] = selectClause + }() + } else { + defer delete(tx.Statement.Clauses, "SELECT") + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}}) + } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { + expr := clause.Expr{SQL: "count(*)"} + + if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } + } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else if dbName != "*" { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } + } + } + + tx.Statement.AddClause(clause.Select{Expression: expr}) + } + + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { + delete(tx.Statement.Clauses, "ORDER BY") + defer func() { + tx.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + + tx.Statement.Dest = count + tx = 
tx.callbacks.Query().Execute(tx) + + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { + *count = tx.RowsAffected + } + + return +} + +func (db *DB) Row() *sql.Row { + tx := db.getInstance().Set("rows", false) + tx = tx.callbacks.Row().Execute(tx) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row +} + +func (db *DB) Rows() (*sql.Rows, error) { + tx := db.getInstance().Set("rows", true) + tx = tx.callbacks.Row().Execute(tx) + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error +} + +// Scan scans selected value to the struct dest +func (db *DB) Scan(dest interface{}) (tx *DB) { + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + + tx = db.getInstance() + tx.Config = &config + + if rows, err := tx.Rows(); err == nil { + if rows.Next() { + tx.ScanRows(rows, dest) + } else { + tx.RowsAffected = 0 + } + tx.AddError(rows.Close()) + } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger + return +} + +// Pluck queries a single column from a model, returning in the slice dest. 
E.g.: +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + } + + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } + tx.Statement.Dest = dest + return tx.callbacks.Query().Execute(tx) +} + +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + elem := tx.Statement.ReflectValue.Elem() + if !elem.IsValid() { + elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) + tx.Statement.ReflectValue.Set(elem) + } + tx.Statement.ReflectValue = elem + } + Scan(rows, tx, ScanInitialized) + return tx.Error +} + +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. 
+func (db *DB) Connection(fc func(tx *DB) error) (err error) { + if db.Error != nil { + return db.Error + } + + tx := db.getInstance() + sqlDB, err := tx.DB() + if err != nil { + return + } + + conn, err := sqlDB.Conn(tx.Statement.Context) + if err != nil { + return + } + + defer conn.Close() + tx.Statement.ConnPool = conn + return fc(tx) +} + +var ( + savepointIdx int64 + savepointNamePool = &sync.Pool{ + New: func() interface{} { + return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) + }, + } +) + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { + panicked := true + + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + // nested transaction + if !db.DisableNestedTransaction { + poolName := savepointNamePool.Get() + defer savepointNamePool.Put(poolName) + err = db.SavePoint(poolName.(string)).Error + if err != nil { + return + } + + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(poolName.(string)) + } + }() + } + err = fc(db.Session(&Session{NewDB: db.clone == 1})) + } else { + tx := db.Begin(opts...) 
+ if tx.Error != nil { + return tx.Error + } + + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + if err = fc(tx); err == nil { + panicked = false + return tx.Commit().Error + } + } + + panicked = false + return +} + +// Begin begins a transaction with any transaction options opts +func (db *DB) Begin(opts ...*sql.TxOptions) *DB { + var ( + // clone statement + tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) + opt *sql.TxOptions + err error + ) + + if len(opts) > 0 { + opt = opts[0] + } + + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx +} + +// Commit commits the changes in a transaction +func (db *DB) Commit() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Commit()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +// Rollback rollbacks the changes in a transaction +func (db *DB) Rollback() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +func (db *DB) SavePoint(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + db.AddError(savePointer.SavePoint(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +func (db *DB) RollbackTo(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + 
db.AddError(savePointer.RollbackTo(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +// Exec executes raw sql +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + + return tx.callbacks.Raw().Execute(tx) +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..85e4242a --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module gorm.io/gorm + +go 1.16 + +require ( + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.1.5 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..bd6104c9 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go new file mode 100644 index 00000000..07a913fc --- /dev/null +++ b/gorm.go @@ -0,0 +1,496 @@ +package gorm + +import ( + "context" + "database/sql" + "fmt" + "sort" + "sync" + "time" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// for Config.cacheStore store PreparedStmtDB key +const preparedStmtDBKey = "preparedStmt" + +// Config GORM config +type Config struct { + // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool + // NamingStrategy tables, columns naming strategy + NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool + // Logger 
+ Logger logger.Interface + // NowFunc the function to be used when creating a new timestamp + NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool + // DisableForeignKeyConstraintWhenMigrating + DisableForeignKeyConstraintWhenMigrating bool + // IgnoreRelationshipsWhenMigrating + IgnoreRelationshipsWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int + // TranslateError enabling error translation + TranslateError bool + + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + // Plugins registered plugins + Plugins map[string]Plugin + + callbacks *callbacks + cacheStore *sync.Map +} + +// Apply update config to new config +func (c *Config) Apply(config *Config) error { + if config != c { + *config = *c + } + return nil +} + +// AfterInitialize initialize plugins after db connected +func (c *Config) AfterInitialize(db *DB) error { + if db != nil { + for _, plugin := range c.Plugins { + if err := plugin.Initialize(db); err != nil { + return err + } + } + } + return nil +} + +// Option gorm option interface +type Option interface { + Apply(*Config) error + AfterInitialize(*DB) error +} + +// DB GORM DB definition +type DB struct { + *Config + Error error + RowsAffected int64 + Statement *Statement + clone int +} + +// Session session config when create session with Session() method +type Session struct { + DryRun bool + PrepareStmt bool + NewDB bool + Initialized bool + SkipHooks bool + SkipDefaultTransaction bool 
+ DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int +} + +// Open initialize db session based on dialector +func Open(dialector Dialector, opts ...Option) (db *DB, err error) { + config := &Config{} + + sort.Slice(opts, func(i, j int) bool { + _, isConfig := opts[i].(*Config) + _, isConfig2 := opts[j].(*Config) + return isConfig && !isConfig2 + }) + + for _, opt := range opts { + if opt != nil { + if applyErr := opt.Apply(config); applyErr != nil { + return nil, applyErr + } + defer func(opt Option) { + if errr := opt.AfterInitialize(db); errr != nil { + err = errr + } + }(opt) + } + } + + if d, ok := dialector.(interface{ Apply(*Config) error }); ok { + if err = d.Apply(config); err != nil { + return + } + } + + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{} + } + + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + if dialector != nil { + config.Dialector = dialector + } + + if config.Plugins == nil { + config.Plugins = map[string]Plugin{} + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + + db = &DB{Config: config, clone: 1} + + db.callbacks = initializeCallbacks(db) + + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } + + if config.Dialector != nil { + err = config.Dialector.Initialize(db) + + if err != nil { + if db, err := db.DB(); err == nil { + _ = db.Close() + } + } + } + + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + + if config.PrepareStmt { + db.ConnPool = preparedStmt + } + + db.Statement = &Statement{ + DB: 
db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + + if err == nil && !config.DisableAutomaticPing { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + } + + return +} + +// Session create new db session +func (db *DB) Session(config *Session) *DB { + var ( + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + Error: db.Error, + clone: 1, + } + ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + + if config.SkipDefaultTransaction { + tx.Config.SkipDefaultTransaction = true + } + + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + + if config.Context != nil || config.PrepareStmt || config.SkipHooks { + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx + } + + if config.Context != nil { + tx.Statement.Context = config.Context + } + + if config.PrepareStmt { + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { + preparedStmt := v.(*PreparedStmtDB) + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true + } + } + + if config.SkipHooks { + tx.Statement.SkipHooks = true + } + + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + + if !config.NewDB { + tx.clone = 2 + } + + if config.DryRun { + tx.Config.DryRun = true + } + + if config.QueryFields { + tx.Config.QueryFields = true + } + + if config.Logger != nil { + tx.Config.Logger = 
config.Logger + } + + if config.NowFunc != nil { + tx.Config.NowFunc = config.NowFunc + } + + if config.Initialized { + tx = tx.getInstance() + } + + return tx +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + return db.Session(&Session{Context: ctx}) +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + tx = db.getInstance() + return tx.Session(&Session{ + Logger: db.Logger.LogMode(logger.Info), + }) +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + return db.Statement.Settings.Load(key) +} + +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) +} + +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + return db.callbacks +} + +// AddError add error to db +func (db *DB) AddError(err error) error { + if err != nil { + if db.Config.TranslateError { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } + } + + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } + } + return db.Error +} + +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + 
+ if sqldb, ok := connPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + +func (db *DB) getInstance() *DB { + if db.clone > 0 { + tx := &DB{Config: db.Config, Error: db.Error} + + if db.clone == 1 { + // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } + + return tx + } + + return db +} + +// Expr returns clause.Expr, which can be used to pass SQL expression as params +func Expr(expr string, args ...interface{}) clause.Expr { + return clause.Expr{SQL: expr, Vars: args} +} + +// SetupJoinTable setup join table schema +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + err := stmt.Parse(model) + if err != nil { + return err + } + modelSchema = stmt.Schema + + err = stmt.Parse(joinTable) + if err != nil { + return err + } + joinSchema = stmt.Schema + + relation, ok := modelSchema.Relationships.Relations[field] + isRelation := ok && relation.JoinTable != nil + if !isRelation { + return fmt.Errorf("failed to find relation: %s", field) + } + + for _, ref := range relation.References { + f := joinSchema.LookUpField(ref.ForeignKey.DBName) + if f == nil { + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) + } + + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } + ref.ForeignKey = f + } + + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + relation.JoinTable = joinSchema 
+ + return nil +} + +// Use use plugin +func (db *DB) Use(plugin Plugin) error { + name := plugin.Name() + if _, ok := db.Plugins[name]; ok { + return ErrRegistered + } + if err := plugin.Initialize(db); err != nil { + return err + } + db.Plugins[name] = plugin + return nil +} + +// ToSQL for generate SQL string. +// +// db.ToSQL(func(tx *gorm.DB) *gorm.DB { +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// }) +func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) + stmt := tx.Statement + + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) +} diff --git a/interface.go b/interface.go deleted file mode 100644 index 7b02aa66..00000000 --- a/interface.go +++ /dev/null @@ -1,19 +0,0 @@ -package gorm - -import "database/sql" - -type sqlCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/interfaces.go b/interfaces.go new file mode 100644 index 00000000..cf3d0a3d --- /dev/null +++ b/interfaces.go @@ -0,0 +1,102 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Dialector GORM database dialector +// 实现数据库的驱动程序和数据库方言。 +type Dialector interface { + // Name 返回使用该 Dialector 实例连接的数据库类型的名称,例如 "mysql"、"sqlite" 等。 + Name() string + // Initialize 用于初始化连接到数据库的 *DB 实例。此方法将在 Open 方法中调用。 + Initialize(*DB) error + // Migrator 返回用于执行数据库迁移的 Migrator 接口实例, 用于管理数据库迁移。该接口主要用于执行和管理数据模型和数据表之间的映射关系。 + Migrator(db *DB) Migrator + // DataTypeOf 返回给定 schema.Field 类型的数据库原生数据类型, 该方法通常在需要映射数据模型和数据库类型时使用。。例如,schema.Field 类型 string 可能映射到数据库中的 
VARCHAR 类型。 + DataTypeOf(*schema.Field) string + // DefaultValueOf 返回给定 schema.Field 类型的默认值表达式, 该方法通常在需要设置数据模型字段默认值时使用。如果该字段没有默认值,则返回 nil。 + DefaultValueOf(*schema.Field) clause.Expression + // BindVarTo BindVarTo 将给定的值绑定到 SQL 语句中的占位符。该方法通常用于构建动态 SQL 语句。 + // 例如,BindVarTo 方法可以将值 42 绑定到 SQL 语句 SELECT * FROM users WHERE age = ? 中的 ? 占位符上。 + BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) + // QuoteTo 将给定的标识符(例如表名、列名等)引用为数据库原生的语法, 该方法通常用于保证 SQL 语句的安全性和正确性。。例如,在 MySQL 中引用表名 users 可能需要将其引用为 `users`。 + QuoteTo(clause.Writer, string) + // Explain 生成一条解释 SQL 执行计划的 SQL 语句。该方法通常用于优化数据库的查询性能和调试 SQL 语句。 + Explain(sql string, vars ...interface{}) string +} + +// Plugin GORM plugin interface +type Plugin interface { + Name() string + Initialize(*DB) error +} + +type ParamsFilter interface { + ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) +} + +// ConnPool db conns pool interface +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// SavePointerDialectorInterface save pointer interface +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + +// TxBeginner tx beginner +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// ConnPoolBeginner conn pool beginner +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +// TxCommitter tx committer +type TxCommitter interface { + Commit() error + Rollback() error +} + +// Tx sql.Tx interface +type Tx interface { + ConnPool + TxCommitter + StmtContext(ctx context.Context, stmt *sql.Stmt) 
*sql.Stmt +} + +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr +} + +// GetDBConnector SQL db connector +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} + +// Rows rows interface +type Rows interface { + Columns() ([]string, error) + ColumnTypes() ([]*sql.ColumnType, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/join_table_handler.go b/join_table_handler.go deleted file mode 100644 index 18c12a85..00000000 --- a/join_table_handler.go +++ /dev/null @@ -1,204 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys 
[]JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return s.TableName -} - -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey 
:= range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } - } - return values -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range searchMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - ) - - for key, value := range s.getSearchMap(db, sources...) 
{ - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). 
- Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/join_table_test.go b/join_table_test.go deleted file mode 100644 index 1a83a9c8..00000000 --- a/join_table_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm_test - -import ( - "fmt" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type Person struct { - Id int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` -} - -type PersonAddress struct { - gorm.JoinTableHandler - PersonID int - AddressID int - DeletedAt *time.Time - CreatedAt time.Time -} - -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - return db.Where(map[string]interface{}{ - "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), - "address_id": db.NewScope(associationValue).PrimaryKeyValue(), - }).Assign(map[string]interface{}{ - "person_id": foreignValue, - "address_id": associationValue, - "deleted_at": gorm.Expr("NULL"), - }).FirstOrCreate(&PersonAddress{}).Error -} - -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error -} - -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { - table := pa.Table(db) - return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) -} - -func TestJoinTable(t *testing.T) { - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &Person{Name: "person", Addresses: []*Address{address1, address2}} - DB.Save(person) - 
- DB.Model(person).Association("Addresses").Delete(address1) - - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { - t.Errorf("Should found one address") - } - - if DB.Model(person).Association("Addresses").Count() != 1 { - t.Errorf("Should found one address") - } - - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { - t.Errorf("Found two addresses with Unscoped") - } - - if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 4f312087..00000000 --- a/logger.go +++ /dev/null @@ -1,110 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - if len(values) > 1 { - level := values[0] - currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - messages := []interface{}{source, currentTime} - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - var sql string - var formattedValues []string - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := 
value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d`, index+1) - subre := regexp.MustCompile(placeholder) - sql = subre.ReplaceAllString(sql, value) - } - } else { - var formattedValuesLength = len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - logger.Println(messages...) 
- } -} - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..aa0060bc --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,211 @@ +package logger + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "os" + "time" + + "gorm.io/gorm/utils" +) + +// ErrRecordNotFound record not found error +var ErrRecordNotFound = errors.New("record not found") + +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" +) + +// LogLevel log level +type LogLevel int + +const ( + // Silent silent log level + Silent LogLevel = iota + 1 + // Error error log level + Error + // Warn warn log level + Warn + // Info info log level + Info +) + +// Writer log writer interface +type Writer interface { + Printf(string, ...interface{}) +} + +// Config logger config +type Config struct { + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + ParameterizedQueries bool + LogLevel LogLevel +} + +// Interface logger interface +type Interface interface { + LogMode(LogLevel) Interface + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +var ( + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) + // Default Default logger + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + // Recorder 
Recorder logger records running SQL into a recorder instance + Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} +) + +// New initialize logger +func New(writer Writer, config Config) Interface { + var ( + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" + ) + + if config.Colorful { + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + } + + return &logger{ + Writer: writer, + Config: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, + } +} + +type logger struct { + Writer + Config + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string +} + +// LogMode log mode +func (l *logger) LogMode(level LogLevel) Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +// Info print info +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Warn print warn messages +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) 
+ } +} + +// Error print error messages +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Trace print sql message +func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + } +} + +// Trace print sql message +func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Config.ParameterizedQueries { + return sql, nil + } + return sql, params +} + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +// New new trace recorder +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} +} + 
+// Trace implement logger interface +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/logger/sql.go b/logger/sql.go new file mode 100644 index 00000000..bcacc7cf --- /dev/null +++ b/logger/sql.go @@ -0,0 +1,158 @@ +package logger + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "time" + "unicode" + + "gorm.io/gorm/utils" +) + +const ( + tmFmtWithMS = "2006-01-02 15:04:05.999" + tmFmtZero = "0000-00-00 00:00:00" + nullStr = "NULL" +) + +func isPrintable(s string) bool { + for _, r := range s { + if !unicode.IsPrint(r) { + return false + } + } + return true +} + +var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} + +var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) + +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) + + convertParams = func(v interface{}, idx int) { + switch v := v.(type) { + case bool: + vars[idx] = strconv.FormatBool(v) + case time.Time: + if v.IsZero() { + vars[idx] = escaper + tmFmtZero + escaper + } else { + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper + } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + tmFmtZero + escaper + } else { + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper + } + } else { + vars[idx] = nullStr + } + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + r, _ := 
v.Value() + convertParams(r, idx) + } else { + vars[idx] = nullStr + } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + } else { + vars[idx] = nullStr + } + } + case []byte: + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = utils.ToString(v) + case float64, float32: + vars[idx] = fmt.Sprintf("%.6f", v) + case string: + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper + default: + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { + vars[idx] = nullStr + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertibleTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + 
escaper + } + } + } + + for idx, v := range avars { + convertParams(v, idx) + } + + if numericPlaceholder == nil { + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) + } + + sql = newSQL.String() + } else { + sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") + + sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { + num := v[1 : len(v)-1] + n, _ := strconv.Atoi(num) + + // position var start from 1 ($1, $2) + n -= 1 + if n >= 0 && n <= len(vars)-1 { + return vars[n] + } + return v + }) + } + + return sql +} diff --git a/logger/sql_test.go b/logger/sql_test.go new file mode 100644 index 00000000..c5b181a9 --- /dev/null +++ b/logger/sql_test.go @@ -0,0 +1,105 @@ +package logger_test + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "regexp" + "strings" + "testing" + + "github.com/jinzhu/now" + "gorm.io/gorm/logger" +) + +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper +} + +func TestExplainSQL(t *testing.T) { + type role string + type password []byte + var ( + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + ) + + results := []struct { + SQL string + NumericRegexp *regexp.Regexp + Vars []interface{} + Result string + }{ + { + SQL: "create table users (name, age, height, actived, 
bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", + NumericRegexp: regexp.MustCompile(`@p(\d+)`), + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", + NumericRegexp: regexp.MustCompile(`\$(\d+)`), + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, + Result: `create table users (name, 
age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", + NumericRegexp: regexp.MustCompile(`@p(\d+)`), + Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values 
("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + } + + for idx, r := range results { + if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { + t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) + } + } +} diff --git a/main.go b/main.go deleted file mode 100644 index 04f39228..00000000 --- a/main.go +++ /dev/null @@ -1,703 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -// DB contains information for current db connection -type DB struct { - Value interface{} - Error error - RowsAffected int64 - callbacks *Callback - db sqlCommon - parent *DB - search *search - logMode int - logger logger - dialect Dialect - singularTable bool - source string - values map[string]interface{} - joinTableHandlers map[string]JoinTableHandler -} - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (*DB, error) { - var db DB - var err error - - if len(args) == 0 { - err = errors.New("invalid database source") - } else { - var source string - var dbSQL sqlCommon - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, 
err = sql.Open(driver, source) - case sqlCommon: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSQL = value - } - - db = DB{ - dialect: newDialect(dialect, dbSQL.(*sql.DB)), - logger: defaultLogger, - callbacks: DefaultCallback, - source: source, - values: map[string]interface{}{}, - db: dbSQL, - } - db.parent = &db - - if err == nil { - err = db.DB().Ping() // Send a ping to make sure the database connection is alive. - } - } - - return &db, err -} - -// Close close current db connection -func (s *DB) Close() error { - return s.parent.db.(*sql.DB).Close() -} - -// DB get `*sql.DB` from current connection -func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. 
-func (s *DB) CommonDB() sqlCommon { - return s.db -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = 2 - } else { - s.logMode = 1 - } - return s -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() - s.parent.singularTable = enable -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to 
`true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query string, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/curd.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) 
*DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). 
- inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - clone = s.clone() - scope = clone.NewScope(result) - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - 
return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error) - } else if len(c.search.assignAttrs) > 0 { - c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error) - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). 
- InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) - if scope.PrimaryKeyZero() { - return scope.callCallbacks(s.parent.callbacks.creates).db - } - return scope.callCallbacks(s.parent.callbacks.updates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you 
would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok { - tx, err := db.Begin() - c.db = interface{}(tx).(sqlCommon) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok { - s.AddError(db.Rollback()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db 
-} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.clone().NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) 
- return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - scope := s.clone().NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) 
*DB { - s.values[name] = value - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors{errors: s.GetErrors()} - errors.Add(err) - if len(errors.GetErrors()) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() (errors []error) { - if errs, ok := s.Error.(errorsInterface); ok { - return errs.GetErrors() - } else if s.Error != nil { - return []error{s.Error} - } - return -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} - - for key, value := range s.values { 
- db.values[key] = value - } - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = &db - return &db -} - -func (s *DB) print(v ...interface{}) { - s.logger.(logger).Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) - } -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 1344c65b..00000000 --- a/main_test.go +++ /dev/null @@ -1,790 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" - - "github.com/erikstmartin/go-testdb" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/now" -) - -var ( - DB *gorm.DB - t1, t2, t3, t4, t5 time.Time -) - -func init() { - var err error - - if DB, err = OpenTestConnection(); err != nil { - panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) - } - - runMigration() -} - -func OpenTestConnection() (db *gorm.DB, err error) { - switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) - } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) - case "postgres": - fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != 
"" { - dbhost = fmt.Sprintf("host=%v ", dbhost) - } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") - case "mssql": - fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") - default: - fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) - } - - // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) - // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { - db.LogMode(true) - } - - db.DB().SetMaxIdleConns(10) - - return -} - -func TestStringPrimaryKey(t *testing.T) { - type UUIDStruct struct { - ID string `gorm:"primary_key"` - Name string - } - DB.AutoMigrate(&UUIDStruct{}) - - data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { - t.Errorf("string primary key should not be populated") - } -} - -func TestExceptionsWithInvalidSql(t *testing.T) { - var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - var count1, count2 int64 - DB.Model(&User{}).Count(&count1) - if count1 <= 0 { - t.Errorf("Should find some users") - } - - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - DB.Model(&User{}).Count(&count2) - if count1 != count2 { - t.Errorf("No user should not be deleted by 
invalid SQL") - } -} - -func TestSetTable(t *testing.T) { - DB.Create(getPreparedUser("pluck_user1", "pluck_user")) - DB.Create(getPreparedUser("pluck_user2", "pluck_user")) - DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Error("No errors should happen if set table for pluck", err) - } - - var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - - if DB.Table("invalid_table").Find(&users).Error == nil { - t.Errorf("Should got error when table is set to an invalid table") - } - - DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { - t.Errorf("Create table with specified table") - } - - DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) - - var deletedUsers []User - DB.Table("deleted_users").Find(&deletedUsers) - if len(deletedUsers) != 1 { - t.Errorf("Query from specified table") - } - - DB.Save(getPreparedUser("normal_user", "reset_table")) - DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) - var user1, user2, user3 User - DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { - t.Errorf("unset specified table with blank string") - } -} - -type Order struct { -} - -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func TestHasTable(t *testing.T) { - type Foo struct { - Id int - Stuff string - } - DB.DropTable(&Foo{}) - - // Table should not exist at this point, HasTable should return false - if ok := DB.HasTable("foos"); ok { - t.Errorf("Table should not exist, but does") - } - if ok := DB.HasTable(&Foo{}); ok { - t.Errorf("Table should not exist, but does") - } - - // We create the 
table - if err := DB.CreateTable(&Foo{}).Error; err != nil { - t.Errorf("Table should be created") - } - - // And now it should exits, and HasTable should return true - if ok := DB.HasTable("foos"); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } - if ok := DB.HasTable(&Foo{}); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } -} - -func TestTableName(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - if DB.NewScope(&Order{}).TableName() != "orders" { - t.Errorf("&Order's table name should be orders") - } - - if DB.NewScope([]Order{}).TableName() != "orders" { - t.Errorf("[]Order's table name should be orders") - } - - if DB.NewScope(&[]Order{}).TableName() != "orders" { - t.Errorf("&[]Order's table name should be orders") - } - - DB.SingularTable(true) - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - if DB.NewScope(&Order{}).TableName() != "order" { - t.Errorf("&Order's singular table name should be order") - } - - if DB.NewScope([]Order{}).TableName() != "order" { - t.Errorf("[]Order's singular table name should be order") - } - - if DB.NewScope(&[]Order{}).TableName() != "order" { - t.Errorf("&[]Order's singular table name should be order") - } - - if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { - t.Errorf("&Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(Cart{}).TableName() != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { - t.Errorf("&[]Cart's singular table name should be shopping_cart") - } - - if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { - t.Errorf("[]Cart's singular table name should be shopping_cart") - } - DB.SingularTable(false) -} - -func TestNullValues(t *testing.T) { - 
DB.DropTable(&NullValue{}) - DB.AutoMigrate(&NullValue{}) - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv NullValue - DB.First(&nv, "name = ?", "hello") - - if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-2", Valid: true}, - Gender: &sql.NullString{String: "F", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv2 NullValue - DB.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-3", Valid: false}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { - t.Errorf("Can't save because of name can't be null") - } -} - -func TestNullValuesWithFirstOrCreate(t *testing.T) { - var nv1 = 
NullValue{ - Name: sql.NullString{String: "first_or_create", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - } - - var nv2 NullValue - if err := DB.Where(nv1).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { - t.Errorf("first or create with nullvalues") - } - - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Age.Int64 != 18 { - t.Errorf("should update age to 18") - } -} - -func TestTransaction(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - tx.Rollback() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx2 := DB.Begin() - u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx2.Commit() - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} - -func TestRow(t *testing.T) { - user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - row := 
DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() - var age int64 - row.Scan(&age) - if age != 10 { - t.Errorf("Scan with Row") - } -} - -func TestRows(t *testing.T) { - user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ - } - - if count != 2 { - t.Errorf("Should found two records") - } -} - -func TestScanRows(t *testing.T) { - user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - type Result struct { - Name string - Age int - } - - var results []Result - for rows.Next() { - var result Result - if err := DB.ScanRows(rows, &result); err != nil { - t.Errorf("should get no error, but got %v", err) - } - results = append(results, result) - } - - if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") - } -} - -func TestScan(t *testing.T) { - user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Age int - } - - var res result - DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) - if res.Name != user3.Name { - t.Errorf("Scan into struct should work") - } - - var doubleAgeRes result - DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes) - if doubleAgeRes.Age != res.Age*2 { - t.Errorf("Scan double age as age") - } - - var ress []result - DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Scan into struct map") - } -} - -func TestRaw(t *testing.T) { - user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Email string - } - - var ress []result - 
DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Raw with scan") - } - - rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() - count := 0 - for rows.Next() { - count++ - } - if count != 1 { - t.Errorf("Raw with Rows should find one record with name 3") - } - - DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { - t.Error("Raw sql to update records") - } -} - -func TestGroup(t *testing.T) { - rows, err := DB.Select("name").Table("users").Group("name").Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - rows.Scan(&name) - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestJoins(t *testing.T) { - var user = User{ - Name: "joins", - CreditCard: CreditCard{Number: "411111111111"}, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users1 []User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) - if len(users1) != 2 { - t.Errorf("should find two users using left join") - } - - var users2 []User - DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) - if len(users2) != 1 { - t.Errorf("should find one users using left join with conditions") - } - - var users3 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) - if len(users3) != 1 { - t.Errorf("should find one users using 
multiple left join conditions") - } - - var users4 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) - if len(users4) != 0 { - t.Errorf("should find no user when searching with unexisting credit card") - } - - var users5 []User - db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id:1}).Where(Email{Id:1}).Not(Email{Id:10}).First(&users5) - if db5.Error != nil { - t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) - } -} - -func TestJoinsWithSelect(t *testing.T) { - type result struct { - Name string - Email string - } - - user := User{ - Name: "joins_with_select", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var results []result - DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join select") - } -} - -func TestHaving(t *testing.T) { - rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - var total int64 - rows.Scan(&name, &total) - - if name == "2" && total != 1 { - t.Errorf("Should have one user having name 2") - } - if name == "3" && total != 2 { - t.Errorf("Should have two users having name 3") - } - } - } else { - t.Errorf("Should not raise any error") - } -} - 
-func DialectHasTzSupport() bool { - // NB: mssql and FoundationDB do not support time zones. - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { - return false - } - return true -} - -func TestTimeWithZone(t *testing.T) { - var format = "2006-01-02 15:04:05 -0700" - var times []time.Time - GMT8, _ := time.LoadLocation("Asia/Shanghai") - times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) - times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - - for index, vtime := range times { - name := "time_with_zone_" + strconv.Itoa(index) - user := User{Name: name, Birthday: vtime} - - if !DialectHasTzSupport() { - // If our driver dialect doesn't support TZ's, just use UTC for everything here. - user.Birthday = vtime.UTC() - } - - DB.Save(&user) - expectedBirthday := "2013-02-18 17:51:49 +0000" - foundBirthday := user.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - var findUser, findUser2, findUser3 User - DB.First(&findUser, "name = ?", name) - foundBirthday = findUser.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { - t.Errorf("User should be found") - } - - if !DB.Where("id = ? 
AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { - t.Errorf("User should not be found") - } - } -} - -func TestHstore(t *testing.T) { - type Details struct { - Id int64 - Bulk postgres.Hstore - } - - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip() - } - - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { - fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") - panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) - } - - DB.Exec("drop table details") - - if err := DB.CreateTable(&Details{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" - bulk := map[string]*string{ - "bankAccountId": &bankAccountId, - "phoneNumber": &phoneNumber, - "opinion": &opinion, - } - d := Details{Bulk: bulk} - DB.Save(&d) - - var d2 Details - if err := DB.First(&d2).Error; err != nil { - t.Errorf("Got error when tried to fetch details: %+v", err) - } - - for k := range bulk { - if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") - } - } else { - t.Errorf("Details should be existed") - } - } -} - -func TestSetAndGet(t *testing.T) { - if value, ok := DB.Set("hello", "world").Get("hello"); !ok { - t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } - } - - if _, ok := DB.Get("non_existing"); ok { - t.Errorf("Get non existing key should return error") - } -} - -func TestCompatibilityMode(t *testing.T) { - DB, _ := gorm.Open("testdb", "") - testdb.SetQueryFunc(func(query string) (driver.Rows, error) { - columns := []string{"id", "name", "age"} - result := ` - 1,Tim,20 - 2,Joe,25 - 3,Bob,30 - ` - return 
testdb.RowsFromCSVString(columns, result), nil - }) - - var users []User - DB.Find(&users) - if (users[0].Name != "Tim") || len(users) != 3 { - t.Errorf("Unexcepted result returned") - } -} - -func TestOpenExistingDB(t *testing.T) { - DB.Save(&User{Name: "jnfeinstein"}) - dialect := os.Getenv("GORM_DIALECT") - - db, err := gorm.Open(dialect, DB.DB()) - if err != nil { - t.Errorf("Should have wrapped the existing DB connection") - } - - var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { - t.Errorf("Should have found existing record") - } -} - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} - -func BenchmarkGorm(b *testing.B) { - b.N = 2000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - DB.Save(&email) - // Query - DB.First(&BigEmail{}, "email = ?", e) - // Update - DB.Model(&email).UpdateColumn("email", "new-"+e) - // Delete - DB.Delete(&email) - } -} - -func BenchmarkRawSql(b *testing.B) { - DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - DB.SetMaxIdleConns(10) - insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - deleteSql := "DELETE FROM orders WHERE id = $1" - - b.N = 2000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + 
"benchmark@example.org" - email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} - // Insert - DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) - // Query - rows, _ := DB.Query(querySql, email.Email) - rows.Close() - // Update - DB.Exec(updateSql, "new-"+e, time.Now(), id) - // Delete - DB.Exec(deleteSql, id) - } -} diff --git a/migration_test.go b/migration_test.go deleted file mode 100644 index 30085263..00000000 --- a/migration_test.go +++ /dev/null @@ -1,435 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role - PasswordHash []byte - Sequence uint `gorm:"AUTO_INCREMENT"` - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type NotSoLongTableName struct { - Id int64 - ReallyLongThingID int64 - ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit -} - -type ReallyLongTableNameToTestMySQLNameLengthLimit struct { - Id int64 -} - -type ReallyLongThingThatReferencesShort struct { - Id int64 - ShortID int64 - Short Short -} - -type Short struct { - Id int64 -} - -type 
CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Role struct { - Name string `gorm:"size:256"` -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 `gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time 
`sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -} - -func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } - - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } - - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, 
&ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Toy{}, &ElementWithIgnoredField{}} - for _, value := range values { - DB.DropTable(value) - } - - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } -} - -func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } - - if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if 
!scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } - - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } - - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } -} - -type BigEmail struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func (b BigEmail) TableName() string { - return "emails" -} - -func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - 
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) - - scope := DB.NewScope(&BigEmail{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { - t.Errorf("Failed to create index") - } - - var bigemail BigEmail - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } -} - -type MultipleIndexes struct { - ID int64 - UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` - Name string `sql:"unique_index:uix_multipleindexes_user_name"` - Email string `sql:"unique_index:,uix_multipleindexes_user_email"` - Other string `sql:"index:,idx_multipleindexes_user_other"` -} - -func TestMultipleIndexes(t *testing.T) { - if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { - fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) - } - - DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&BigEmail{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) - - scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { - t.Errorf("Failed to create index") - 
} - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { - t.Errorf("Failed to create index") - } - - var mutipleIndexes MultipleIndexes - DB.First(&mutipleIndexes, "name = ?", "jinzhu") - if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { - t.Error("MutipleIndexes should be saved and fetched correctly") - } - - // Check unique constraints - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } -} diff --git a/migrator.go b/migrator.go new file mode 100644 index 00000000..0e01f567 --- /dev/null +++ b/migrator.go @@ -0,0 +1,109 @@ +package gorm + +import ( + "reflect" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Migrator returns migrator +func (db *DB) Migrator() Migrator { + tx := db.getInstance() + + // apply scopes to migrator + for len(tx.Statement.scopes) > 0 { + tx = tx.executeScopes() + } + + return tx.Dialector.Migrator(tx.Session(&Session{})) +} + +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + +// ViewOption view option
type ViewOption struct { + Replace bool // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE` + CheckOption string // optional. e.g. 
`WITH [ CASCADED | LOCAL ] CHECK OPTION` + Query *DB // required subquery. +} + +// ColumnType column type interface +type ColumnType interface { + Name() string + DatabaseTypeName() string // varchar + ColumnType() (columnType string, ok bool) // varchar(64) + PrimaryKey() (isPrimaryKey bool, ok bool) + AutoIncrement() (isAutoIncrement bool, ok bool) + Length() (length int64, ok bool) + DecimalSize() (precision int64, scale int64, ok bool) + Nullable() (nullable bool, ok bool) + Unique() (unique bool, ok bool) + ScanType() reflect.Type + Comment() (value string, ok bool) + DefaultValue() (value string, ok bool) +} + +type Index interface { + Table() string + Name() string + Columns() []string + PrimaryKey() (isPrimaryKey bool, ok bool) + Unique() (unique bool, ok bool) + Option() string +} + +// TableType table type interface +type TableType interface { + Schema() string + Name() string + Type() string + Comment() (comment string, ok bool) +} + +// Migrator migrator interface +type Migrator interface { + // AutoMigrate + AutoMigrate(dst ...interface{}) error + + // Database + CurrentDatabase() string + FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string + + // Tables + CreateTable(dst ...interface{}) error + DropTable(dst ...interface{}) error + HasTable(dst interface{}) bool + RenameTable(oldName, newName interface{}) error + GetTables() (tableList []string, err error) + TableType(dst interface{}) (TableType, error) + + // Columns + AddColumn(dst interface{}, field string) error + DropColumn(dst interface{}, field string) error + AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + HasColumn(dst interface{}, field string) bool + RenameColumn(dst interface{}, oldName, field string) error + ColumnTypes(dst interface{}) ([]ColumnType, error) + + // Views + CreateView(name string, option ViewOption) error + DropView(name string) error + + // 
Constraints + CreateConstraint(dst interface{}, name string) error + DropConstraint(dst interface{}, name string) error + HasConstraint(dst interface{}, name string) bool + + // Indexes + CreateIndex(dst interface{}, name string) error + DropIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool + RenameIndex(dst interface{}, oldName, newName string) error + GetIndexes(dst interface{}) ([]Index, error) +} diff --git a/migrator/column_type.go b/migrator/column_type.go new file mode 100644 index 00000000..c6fdd6b2 --- /dev/null +++ b/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. 
like `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. 
+func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/migrator/index.go b/migrator/index.go new file mode 100644 index 00000000..8845da95 --- /dev/null +++ b/migrator/index.go @@ -0,0 +1,43 @@ +package migrator + +import "database/sql" + +// Index implements gorm.Index interface +type Index struct { + TableName string + NameValue string + ColumnList []string + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + OptionValue string +} + +// Table return the table name of the index. +func (idx Index) Table() string { + return idx.TableName +} + +// Name return the name of the index. +func (idx Index) Name() string { + return idx.NameValue +} + +// Columns return the columns of the index +func (idx Index) Columns() []string { + return idx.ColumnList +} + +// PrimaryKey returns the index is primary key or not. +func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { + return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid +} + +// Unique returns whether the index is unique or not. 
+func (idx Index) Unique() (unique bool, ok bool) { + return idx.UniqueValue.Bool, idx.UniqueValue.Valid +} + +// Option return the optional attribute of the index +func (idx Index) Option() string { + return idx.OptionValue +} diff --git a/migrator/migrator.go b/migrator/migrator.go new file mode 100644 index 00000000..de60f91c --- /dev/null +++ b/migrator/migrator.go @@ -0,0 +1,956 @@ +package migrator + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "regexp" + "strings" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) + +// Migrator m struct +type Migrator struct { + Config +} + +// Config schema config +type Config struct { + CreateIndexAfterCreateTable bool + DB *gorm.DB + gorm.Dialector +} + +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + +// GormDataTypeInterface gorm data type interface +type GormDataTypeInterface interface { + GormDBDataType(*gorm.DB, *schema.Field) string +} + +// RunWithValue run migration with statement value +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr + } + + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { + return err + } + + return fc(stmt) +} + +// DataTypeOf return field's db data type +func (m Migrator) DataTypeOf(field *schema.Field) string { + fieldValue := reflect.New(field.IndirectFieldType) + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + if dataType := 
dataTyper.GormDBDataType(m.DB, field); dataType != "" { + return dataType + } + } + + return m.Dialector.DataTypeOf(field) +} + +// FullDataTypeOf returns field's db full data type +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT NULL" + } + + if field.Unique { + expr.SQL += " UNIQUE" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + return +} + +// AutoMigrate auto migrate values +func (m Migrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + columnTypes, err := queryTx.Migrator().ColumnTypes(value) + if err != nil { + return err + } + var ( + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() + ) + for _, dbName := range stmt.Schema.DBNames { + var foundColumn gorm.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == dbName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column + if err = execTx.Migrator().AddColumn(value, dbName); err 
!= nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + return err + } + } + } + + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + } + + for _, chk := range parseCheckConstraints { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + + for _, idx := range parseIndexes { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + + return nil + }); err != nil { + return err + } + } + } + + return nil +} + +// GetTables returns tables +func (m Migrator) GetTables() (tableList []string, err error) { + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). + Scan(&tableList).Error + return +} + +// CreateTable create table in database for values +func (m Migrator) CreateTable(values ...interface{}) error { + for _, value := range m.ReorderModels(values, false) { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + if !field.IgnoreMigration { + createTableSQL += "? ?" 
+ hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if m.CreateIndexAfterCreateTable { + defer func(value interface{}, name string) { + if err == nil { + err = tx.Migrator().CreateIndex(value, name) + } + }(value, idx.Name) + } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } + createTableSQL += "INDEX ? ?" + + if idx.Comment != "" { + createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," + values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } + } + + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? 
CHECK (?)," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + err = tx.Exec(createTableSQL, values...).Error + return err + }); err != nil { + return err + } + } + return nil +} + +// DropTable drop table for values +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + return nil +} + +// HasTable returns table exists or not for value, value could be a struct or string +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? 
AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +// RenameTable rename table from oldName to newName +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable interface{} + if v, ok := oldName.(string); ok { + oldTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = m.CurrentTable(stmt) + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = m.CurrentTable(stmt) + } else { + return err + } + } + + return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error +} + +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // avoid using the same name field + f := stmt.Schema.LookUpField(name) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", name) + } + + if !f.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), + ).Error + } + + return nil + }) +} + +// DropColumn drop value's `name` column +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? 
DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, + ).Error + }) +} + +// AlterColumn alter value's `field` column' type based on schema definition +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := m.FullDataTypeOf(field) + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +// HasColumn check has column `field` for value or not +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// RenameColumn rename value's field name from oldName to newName +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? 
TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +// MigrateColumn migrate column +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + // found, smart migrate + fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + var ( + alterColumn bool + isSameType = fullDataType == realDataType + ) + + if !field.PrimaryKey { + // check type + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + + if !isSameType { + alterColumn = true + } + } + } + + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + 
alterColumn = true + } + } + + // check default value + if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && !currentDefaultNotNull { + // default value -> null + alterColumn = true + } else if !dvNotNull && currentDefaultNotNull { + // null -> default value + alterColumn = true + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { + // default value not equal + // not both null + if currentDefaultNotNull || dvNotNull { + alterColumn = true + } + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + if alterColumn && !field.IgnoreMigration { + return m.DB.Migrator().AlterColumn(value, field.DBName) + } + + return nil +} + +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + if err != nil { + return err + } + + defer func() { + err = rows.Close() + }() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) + } + + return + }) + + return columnTypes, execErr +} + +// CreateView create view from Query in gorm.ViewOption. 
+// Query in gorm.ViewOption is a [subquery] +// +// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 +// q := DB.Model(&User{}).Where("age > ?", 20) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) +// +// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION +// q := DB.Model(&User{}) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) +// +// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { + if option.Query == nil { + return gorm.ErrSubQueryRequired + } + + sql := new(strings.Builder) + sql.WriteString("CREATE ") + if option.Replace { + sql.WriteString("OR REPLACE ") + } + sql.WriteString("VIEW ") + m.QuoteTo(sql, name) + sql.WriteString(" AS ") + + m.DB.Statement.AddVar(sql, option.Query) + + if option.CheckOption != "" { + sql.WriteString(" ") + sql.WriteString(option.CheckOption) + } + return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error +} + +// DropView drop view +func (m Migrator) DropView(name string) error { + return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" 
+ if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + +// GuessConstraintAndTable guess statement's constraint and it's table based on name +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for k := range checkConstraints { + if checkConstraints[k].Field == field { + v := checkConstraints[k] + return nil, &v, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + + return nil, nil, stmt.Schema.Table +} + +// CreateConstraint create constraint +func (m Migrator) CreateConstraint(value 
interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { + return m.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + if constraint != nil { + vars := []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr + } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error + } + + return nil + }) +} + +// DropConstraint drop constraint +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + }) +} + +// HasConstraint check has constraint or not +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? 
AND constraint_name = ?", + currentDatabase, table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// BuildIndexOptions build index options +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +// BuildIndexOptionsInterface build index options interface +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +// CreateIndex create index `name` +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + if idx.Comment != "" { + createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %s", name) + }) +} + +// DropIndex drop index `name` +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ? 
ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error + }) +} + +// HasIndex check has index `name` or not +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +// RenameIndex rename index from oldName to newName +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +// CurrentDatabase returns current database name +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) + return +} + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + *gorm.Statement + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} + valuesMap = map[string]Dependency{} + insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) + ) + + parseDependence = func(value interface{}, addToList bool) { + dep := Dependency{ + Statement: &gorm.Statement{DB: m.DB, Dest: value}, + } + beDependedOn := map[*schema.Schema]bool{} + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { + 
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true + + if !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range dep.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } + + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } + } + } + + valuesMap[dep.Schema.Table] = dep + + if addToList { + modelNames = append(modelNames, dep.Schema.Table) + } + } + + insertIntoOrderedList = func(name string) { + if _, ok := orderedModelNamesMap[name]; ok { + return // avoid loop + } + orderedModelNamesMap[name] = true + + if autoAdd { + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } + } + } + + orderedModelNames = append(orderedModelNames, name) + } + + for _, value := range values { + if v, ok := value.(string); ok { + results = append(results, v) + } else { + parseDependence(value, true) + } + } + + for _, name := range modelNames { + insertIntoOrderedList(name) + } + + for _, name := range orderedModelNames { + results = append(results, 
valuesMap[name].Statement.Dest) + } + return +} + +// CurrentTable returns current statement's table expression +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +} + +// GetIndexes return Indexes []gorm.Index and execErr error +func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { + return nil, errors.New("not support") +} + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} + +// TableType return tableType gorm.TableType and execErr error +func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { + return nil, errors.New("not support") +} diff --git a/migrator/table_type.go b/migrator/table_type.go new file mode 100644 index 00000000..ed6e42a0 --- /dev/null +++ b/migrator/table_type.go @@ -0,0 +1,33 @@ +package migrator + +import ( + "database/sql" +) + +// TableType table type implements TableType interface +type TableType struct { + SchemaValue string + NameValue string + TypeValue string + CommentValue sql.NullString +} + +// Schema returns the schema of the table. +func (ct TableType) Schema() string { + return ct.SchemaValue +} + +// Name returns the name of the table. +func (ct TableType) Name() string { + return ct.NameValue +} + +// Type returns the type of the table. +func (ct TableType) Type() string { + return ct.TypeValue +} + +// Comment returns the comment of current table. 
+func (ct TableType) Comment() (comment string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} diff --git a/model.go b/model.go index f37ff7ea..fa705df1 100644 --- a/model.go +++ b/model.go @@ -2,13 +2,15 @@ package gorm import "time" -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embedded into your model or you may build your own model without it +// +// type User struct { +// gorm.Model +// } type Model struct { - ID uint `gorm:"primary_key"` + ID uint `gorm:"primarykey"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` + DeletedAt DeletedAt `gorm:"index"` } diff --git a/model_struct.go b/model_struct.go deleted file mode 100644 index d9e84c3c..00000000 --- a/model_struct.go +++ /dev/null @@ -1,559 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - 
ModelType reflect.Type - defaultTableName string -} - -// TableName get model's table name -func (s *ModelStruct) TableName(db *DB) string { - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship -} - -func (structField *StructField) clone() *StructField { - clone := &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, - TagSettings: map[string]string{}, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - Relationship: structField.Relationship, - } - - for key, value := range structField.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if 
scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value - } - - modelStruct.ModelType = reflectType - - // Set default table name - if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { - modelStruct.defaultTableName = tabler.TableName() - } else { - tableName := ToDBName(reflectType.Name()) - if scope.db == nil || !scope.db.parent.singularTable { - tableName = inflection.Plural(tableName) - } - modelStruct.defaultTableName = tableName - } - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettings["-"]; ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i 
< indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - field.TagSettings[key] = value - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetStructFields() { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if subField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - 
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { 
- // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // 
association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range 
modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - foreignField.IsForeignKey = true - // source foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - 
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); 
associationField != nil { - foreignField.IsForeignKey = true - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = value - } else { - field.DBName = ToDBName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Set(reflectType, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go 
deleted file mode 100644 index 8b275d18..00000000 --- a/multi_primary_keys_test.go +++ /dev/null @@ -1,381 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "sort" - "testing" -) - -type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` - SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` - LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` -} - -type Tag struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` -} - -func compareTags(tags []Tag, contents []string) bool { - var tagContents []string - for _, tag := range tags { - tagContents = append(tagContents, tag.Value) - } - sort.Strings(tagContents) - sort.Strings(contents) - return reflect.DeepEqual(tagContents, contents) -} - -func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - Tags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - - DB.Save(&blog) - if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) - if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("Tags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "Tags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags 
with Related") - } - - var blog1 Blog - DB.Preload("Tags").Find(&blog1) - if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog).Association("Tags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "Tags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("Tags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("Tags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "Tags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("Tags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog).Association("Tags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "Tags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog).Association("Tags").Clear() - if DB.Model(&blog).Association("Tags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("shared_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - SharedTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - 
var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) - if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("SharedTags").Find(&blog1) - if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("SharedTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - DB.Model(&blog2).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 2 
{ - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("SharedTags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "SharedTags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog2).Association("SharedTags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "SharedTags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog2).Association("SharedTags").Clear() - if DB.Model(&blog).Association("SharedTags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("locale_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - LocaleTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) - if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog should has 0 tags after ZH Blog Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should 
find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if len(tags) != 0 { - t.Errorf("Should find 0 tags with Related for EN Blog") - } - - var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) - if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("LocaleTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related for EN Blog") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag4"}) { - t.Errorf("Should find 1 tags with Related for EN Blog") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) - - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) - if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - DB.Model(&blog2).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? 
AND locale = ?", blog2.ID, blog2.Locale) - if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { - t.Errorf("EN Blog's tags should be changed after Replace") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Replace") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after Replace") - } - - // Delete - DB.Model(&blog).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") - } - - DB.Model(&blog2).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { - t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") - } - - // Clear - DB.Model(&blog2).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") - } - - DB.Model(&blog).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 0 { - t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared") - } - } -} diff --git a/pointer_test.go b/pointer_test.go deleted file mode 100644 index 2a68a5ab..00000000 --- a/pointer_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm_test - -import "testing" - -type 
PointerStruct struct { - ID int64 - Name *string - Num *int -} - -type NormalStruct struct { - ID int64 - Name string - Num int -} - -func TestPointerFields(t *testing.T) { - DB.DropTable(&PointerStruct{}) - DB.AutoMigrate(&PointerStruct{}) - var name = "pointer struct 1" - var num = 100 - pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { - t.Errorf("Failed to save pointer struct") - } - - var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { - t.Errorf("Failed to query saved pointer struct") - } - - var tableName = DB.NewScope(&PointerStruct{}).TableName() - - var normalStruct NormalStruct - DB.Table(tableName).First(&normalStruct) - if normalStruct.Name != name || normalStruct.Num != num { - t.Errorf("Failed to query saved Normal struct") - } - - var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Error("Failed to save nil pointer struct", err) - } - - var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var partialNilPointerStruct1 = PointerStruct{Num: &num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct3 NormalStruct - if err := 
DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Error("Failed to query saved partial pointer struct", err) - } - - var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Error("Failed to query saved partial pointer struct", err) - } -} diff --git a/polymorphic_test.go b/polymorphic_test.go deleted file mode 100644 index df573f97..00000000 --- a/polymorphic_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package gorm_test - -import ( - "reflect" - "sort" - "testing" -) - -type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` -} - -type Dog struct { - Id int - Name string - Toys []Toy `gorm:"polymorphic:Owner;"` -} - -type Toy struct { - Id int - Name string - OwnerId int - OwnerType string -} - -var compareToys = func(toys []Toy, contents []string) bool { - var toyContents []string - for _, toy := range toys { - toyContents = append(toyContents, toy.Name) - } - sort.Strings(toyContents) - sort.Strings(contents) - return reflect.DeepEqual(toyContents, contents) -} - -func TestPolymorphic(t *testing.T) { - cat := Cat{Name: "Mr. 
Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} - DB.Save(&cat).Save(&dog) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Dog's toys count should be 2") - } - - // Query - var catToys []Toy - if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(catToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if catToys[0].Name != cat.Toy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - var dogToys []Toy - if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { - t.Errorf("Did not find any polymorphic has many associations") - } else if len(dogToys) != len(dog.Toys) { - t.Errorf("Should have found all polymorphic has many associations") - } - - var catToy Toy - DB.Model(&cat).Association("Toy").Find(&catToy) - if catToy.Name != cat.Toy.Name { - t.Errorf("Should find has one polymorphic association") - } - - var dogToys1 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys1) - if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { - t.Errorf("Should find has many polymorphic association") - } - - // Append - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - var catToy2 Toy - DB.Model(&cat).Association("Toy").Find(&catToy2) - if catToy2.Name != "cat toy 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Should return two polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 3", - }) - - var dogToys2 
[]Toy - DB.Model(&dog).Association("Toys").Find(&dogToys2) - if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { - t.Errorf("Dog's toys should be updated with Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Replace - DB.Model(&cat).Association("Toy").Replace(&Toy{ - Name: "cat toy 3", - }) - - var catToy3 Toy - DB.Model(&cat).Association("Toy").Find(&catToy3) - if catToy3.Name != "cat toy 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Replace(&Toy{ - Name: "dog toy 4", - }, []Toy{ - {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, - }) - - var dogToys3 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys3) - if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { - t.Errorf("Dog's toys should be updated with Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Delete - DB.Model(&cat).Association("Toy").Delete(&catToy2) - - var catToy4 Toy - DB.Model(&cat).Association("Toy").Find(&catToy4) - if catToy4.Name != "cat toy 3" { - t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should be 4") - } - - DB.Model(&cat).Association("Toy").Delete(&catToy3) - - if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { - t.Errorf("Toy should be deleted with 
Delete") - } - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys count should be 0 after Delete") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete cat's toy") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys2) - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete unrelated toys") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys3) - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys count should be deleted with Delete") - } - - // Clear - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys should be added with Append") - } - - DB.Model(&cat).Association("Toy").Clear() - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys should be cleared with Clear") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 8", - }) - - if DB.Model(&dog).Association("Toys").Count() != 1 { - t.Errorf("Dog's toys should be added with Append") - } - - DB.Model(&dog).Association("Toys").Clear() - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys should be cleared with Clear") - } -} diff --git a/prepare_stmt.go b/prepare_stmt.go new file mode 100644 index 00000000..e09fe814 --- /dev/null +++ b/prepare_stmt.go @@ -0,0 +1,215 @@ +package gorm + +import ( + "context" + "database/sql" + "sync" +) + +type Stmt struct { + *sql.Stmt + Transaction bool + prepared chan struct{} + prepareErr error +} + +type PreparedStmtDB struct { + Stmts map[string]*Stmt + PreparedSQL []string + Mux *sync.RWMutex + ConnPool +} + +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := 
db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + +func (db *PreparedStmtDB) Close() { + db.Mux.Lock() + defer db.Mux.Unlock() + + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + go stmt.Close() + } + } +} + +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + + for _, stmt := range db.Stmts { + go stmt.Close() + } + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = make(map[string]*Stmt) +} + +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { + db.Mux.RLock() + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.RUnlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil + } + db.Mux.RUnlock() + + db.Mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil + } + + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. 
+ stmt, err := conn.PrepareContext(ctx, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err + } + + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + result, err = stmt.ExecContext(ctx, args...) + if err != nil { + db.Mux.Lock() + defer db.Mux.Unlock() + go stmt.Close() + delete(db.Stmts, query) + } + } + return result, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.Mux.Lock() + defer db.Mux.Unlock() + + go stmt.Close() + delete(db.Stmts, query) + } + } + return rows, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) 
+ } + return &sql.Row{} +} + +type PreparedStmtTX struct { + Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + } + } + return result, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + + go stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + } + } + return rows, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) 
+ } + return &sql.Row{} +} diff --git a/query_test.go b/query_test.go deleted file mode 100644 index 1a500465..00000000 --- a/query_test.go +++ /dev/null @@ -1,644 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - - "github.com/jinzhu/gorm" - "github.com/jinzhu/now" - - "testing" - "time" -) - -func TestFirstAndLast(t *testing.T) { - DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) - DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) - - var user1, user2, user3, user4 User - DB.First(&user1) - DB.Order("id").Limit(1).Find(&user2) - - DB.Last(&user3) - DB.Order("id desc").Limit(1).Find(&user4) - if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should by order by primary key") - } - - var users []User - DB.First(&users) - if len(users) != 1 { - t.Errorf("Find first record as slice") - } - - var user User - if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { - t.Errorf("Should not raise any error when order with Join table") - } - - if user.Email != "" { - t.Errorf("User's Email should be blank as no one set it") - } -} - -func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { - DB.Save(&Animal{Name: "animal1"}) - DB.Save(&Animal{Name: "animal2"}) - - var animal1, animal2, animal3, animal4 Animal - DB.First(&animal1) - DB.Order("counter").Limit(1).Find(&animal2) - - DB.Last(&animal3) - DB.Order("counter desc").Limit(1).Find(&animal4) - if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should work correctly") - } -} - -func TestUIntPrimaryKey(t *testing.T) { - var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } - - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should 
work, but failed") - } -} - -func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) - - var address AddressByZipCode - DB.First(&address, "00501") - if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") - } -} - -func TestFindAsSliceOfPointers(t *testing.T) { - DB.Save(&User{Name: "user"}) - - var users []User - DB.Find(&users) - - var userPointers []*User - DB.Find(&userPointers) - - if len(users) == 0 || len(users) != len(userPointers) { - t.Errorf("Find slice of pointers") - } -} - -func TestSearchWithPlainSQL(t *testing.T) { - user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") - - if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("Search with plain SQL") - } - - if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { - t.Errorf("Search with plan SQL (regexp)") - } - - var users []User - DB.Find(&users, "name LIKE ? 
and age > ?", "%PlainSqlUser%", 1) - if len(users) != 2 { - t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) - } - - DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) - } - - scopedb.Where("age <> ?", 20).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users age != 20, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", "2002-10-10").Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) - } - - scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) - } - - DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users, but got %v", len(users)) - } - - DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users, but got %v", len(users)) - } - - DB.Where("id in (?)", user1.Id).Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users, but got %v", len(users)) - } - - if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { - t.Errorf("Should not get RecordNotFound error when looking for none existing records") - } -} - 
-func TestSearchWithStruct(t *testing.T) { - user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where(user1.Id).First(&User{}).RecordNotFound() { - t.Errorf("Search with primary key") - } - - if DB.First(&User{}, user1.Id).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - var users []User - DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) - } - - var user User - DB.First(&user, &User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline pointer of struct") - } - - DB.First(&user, User{Name: user1.Name}) - if user.Id == 0 || user.Name != user.Name { - t.Errorf("Search first record with inline struct") - } - - DB.Where(&User{Name: user1.Name}).First(&user) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with where struct") - } - - DB.Find(&users, &User{Name: user2.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline struct") - } -} - -func TestSearchWithMap(t *testing.T) { - companyID := 1 - user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) - - var 
user User - DB.First(&user, map[string]interface{}{"name": user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline map") - } - - user = User{} - DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) - if user.Id == 0 || user.Name != user2.Name { - t.Errorf("Search first record with where map") - } - - var users []User - DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user3.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) - if len(users) != 0 { - t.Errorf("Search all records with inline map containing null value finding 0 records") - } - - DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) - if len(users) != 1 { - t.Errorf("Search all records with inline map containing null value finding 1 record") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) - if len(users) != 1 { - t.Errorf("Search all records with inline multiple value map") - } -} - -func TestSearchWithEmptyChain(t *testing.T) { - user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} - user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} - user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where("").Where("").First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty strings") - } - - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty struct") - } - - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - 
t.Errorf("Should not raise any error if searching with empty map") - } -} - -func TestSelect(t *testing.T) { - user1 := User{Name: "SelectUser1"} - DB.Save(&user1) - - var user User - DB.Where("name = ?", user1.Name).Select("name").Find(&user) - if user.Id != 0 { - t.Errorf("Should not have ID because only selected name, %+v", user.Id) - } - - if user.Name != user1.Name { - t.Errorf("Should have user Name when selected it") - } -} - -func TestOrderAndPluck(t *testing.T) { - user1 := User{Name: "OrderPluckUser1", Age: 1} - user2 := User{Name: "OrderPluckUser2", Age: 10} - user3 := User{Name: "OrderPluckUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") - - var user User - scopedb.Order(gorm.Expr("name = ? DESC", "OrderPluckUser2")).First(&user) - if user.Name != "OrderPluckUser2" { - t.Errorf("Order with sql expression") - } - - var ages []int64 - scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { - t.Errorf("The first age should be 20 when order with age desc") - } - - var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) - if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } - - var ages3, ages4 []int64 - scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) - if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should work") - } - - var names []string - var ages5 []int64 - scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if names != nil && ages5 != nil { - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { - t.Errorf("Order with multiple orders") - } - } else { - t.Errorf("Order with multiple orders") - } - - DB.Model(User{}).Select("name, age").Find(&[]User{}) -} - -func TestLimit(t *testing.T) { - user1 := User{Name: "LimitUser1", Age: 
1} - user2 := User{Name: "LimitUser2", Age: 10} - user3 := User{Name: "LimitUser3", Age: 20} - user4 := User{Name: "LimitUser4", Age: 10} - user5 := User{Name: "LimitUser5", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) - - var users1, users2, users3 []User - DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) - - if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") - } -} - -func TestOffset(t *testing.T) { - for i := 0; i < 20; i++ { - DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) - } - var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) - - if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should work") - } -} - -func TestOr(t *testing.T) { - user1 := User{Name: "OrUser1", Age: 1} - user2 := User{Name: "OrUser2", Age: 10} - user3 := User{Name: "OrUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users []User - DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) - if len(users) != 2 { - t.Errorf("Find users with or") - } -} - -func TestCount(t *testing.T) { - user1 := User{Name: "CountUser1", Age: 1} - user2 := User{Name: "CountUser2", Age: 10} - user3 := User{Name: "CountUser3", Age: 20} - - DB.Save(&user1).Save(&user2).Save(&user3) - var count, count1, count2 int64 - var users []User - - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { - t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) - } - - if count != int64(len(users)) { - t.Errorf("Count() method should get correct value") - } - - DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) - if count1 != 1 || count2 != 3 
{ - t.Errorf("Multiple count in chain") - } -} - -func TestNot(t *testing.T) { - DB.Create(getPreparedUser("user1", "not")) - DB.Create(getPreparedUser("user2", "not")) - DB.Create(getPreparedUser("user3", "not")) - - user4 := getPreparedUser("user4", "not") - user4.Company = Company{} - DB.Create(user4) - - DB := DB.Where("role = ?", "not") - - var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User - if DB.Find(&users1).RowsAffected != 4 { - t.Errorf("should find 4 not users") - } - DB.Not(users1[0].Id).Find(&users2) - - if len(users1)-len(users2) != 1 { - t.Errorf("Should ignore the first users with Not") - } - - DB.Not([]int{}).Find(&users3) - if len(users1)-len(users3) != 0 { - t.Errorf("Should find all users with a blank condition") - } - - var name3Count int64 - DB.Table("users").Where("name = ?", "user3").Count(&name3Count) - DB.Not("name", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name = ?", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name <> ?", "user3").Find(&users4) - if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(User{Name: "user3"}).Find(&users5) - - if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) - if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) - if len(users1)-len(users7) != 2 { // not user3 or user4 - t.Errorf("Should find all user's name not equal to 3 who do not have company id") - } - - DB.Not("name", []string{"user3"}).Find(&users8) - if len(users1)-len(users8) != int(name3Count) { - 
t.Errorf("Should find all users's name not equal 3") - } - - var name2Count int64 - DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users9) - if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") - } -} - -func TestFillSmallerStruct(t *testing.T) { - user1 := User{Name: "SmallerUser", Age: 100} - DB.Save(&user1) - type SimpleUser struct { - Name string - Id int64 - UpdatedAt time.Time - CreatedAt time.Time - } - - var simpleUser SimpleUser - DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) - - if simpleUser.Id == 0 || simpleUser.Name == "" { - t.Errorf("Should fill data correctly into smaller struct") - } -} - -func TestFindOrInitialize(t *testing.T) { - var user1, user2, user3, user4, user5, user6 User - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) - if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) - if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) - if user3.Name != "find or init 2" || user3.Id != 0 { - t.Errorf("user should be initialized with inline search value") - } - - DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and attrs") - } - - DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and assign attrs") - } - - DB.Save(&User{Name: "find or init", Age: 33}) - 
DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { - t.Errorf("user should be found with FirstOrInit") - } - - DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } -} - -func TestFindOrCreate(t *testing.T) { - var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) - if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) - if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) - if user3.Name != "find or create 2" || user3.Id == 0 { - t.Errorf("user should be created with inline search value") - } - - DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) - if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and attrs") - } - - updatedAt1 := user4.UpdatedAt - DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) - if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateAt should be changed when update values with assign") - } - - DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) - if 
user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) - if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Find(&user7) - if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) - if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { - t.Errorf("embedded struct email should be saved") - } - - if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { - t.Errorf("embedded struct credit card should be saved") - } -} - -func TestSelectWithEscapedFieldName(t *testing.T) { - user1 := User{Name: "EscapedFieldNameUser", Age: 1} - user2 := User{Name: "EscapedFieldNameUser", Age: 10} - user3 := User{Name: "EscapedFieldNameUser", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var names []string - DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) - - if len(names) != 3 { - t.Errorf("Expected 3 name, but got: %d", len(names)) - } -} - -func TestSelectWithVariables(t *testing.T) { - DB.Save(&User{Name: "jinzhu"}) - - rows, _ := DB.Table("users").Select("? 
as fake", gorm.Expr("name")).Rows() - - if !rows.Next() { - t.Errorf("Should have returned at least one row") - } else { - columns, _ := rows.Columns() - if !reflect.DeepEqual(columns, []string{"fake"}) { - t.Errorf("Should only contains one column") - } - } - - rows.Close() -} - -func TestSelectWithArrayInput(t *testing.T) { - DB.Save(&User{Name: "jinzhu", Age: 42}) - - var user User - DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) - - if user.Name != "jinzhu" || user.Age != 42 { - t.Errorf("Should have selected both age and name") - } -} diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..736db4d3 --- /dev/null +++ b/scan.go @@ -0,0 +1,342 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// prepareValues prepare values slice +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := 
mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { + for idx, field := range fields { + if field != nil { + values[idx] = field.NewValuePool.Get() + } else if len(fields) == 1 { + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() + } + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + joinedNestedSchemaMap := make(map[string]interface{}) + for idx, field := range fields { + if field == nil { + continue + } + + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } + } + currentReflectValue = relValue + } + + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) + } + } + + 
// release data to pool + field.NewValuePool.Put(values[idx]) + } +} + +// ScanMode scan data mode +type ScanMode uint8 + +// scan modes +const ( + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 +) + +// Scan scan rows into db statement +func Scan(rows Rows, db *DB, mode ScanMode) { + var ( + columns, _ = rows.Columns() + values = make([]interface{}, len(columns)) + initialized = mode&ScanInitialized != 0 + update = mode&ScanUpdate != 0 + onConflictDonothing = mode&ScanOnConflictDoNothing != 0 + ) + + db.RowsAffected = 0 + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + if initialized || rows.Next() { + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + if *v == nil { + *v = map[string]interface{}{} + } + mapValue = *v + } + } + scanIntoMap(mapValue, values, columns) + } + case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() + for initialized || rows.Next() { + prepareValues(values, db, columnTypes, columns) + + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue := map[string]interface{}{} + scanIntoMap(mapValue, values, columns) + *dest = append(*dest, mapValue) + } + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(dest)) + } + default: + var ( + fields = make([]*schema.Field, len(columns)) + joinFields [][]*schema.Field + sch = db.Statement.Schema + reflectValue = 
db.Statement.ReflectValue + ) + + if reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + + reflectValueType := reflectValue.Type() + switch reflectValueType.Kind() { + case reflect.Array, reflect.Slice: + reflectValueType = reflectValueType.Elem() + } + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + if sch != nil { + if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { + sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + if len(columns) == 1 { + // Is Pluck + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + sch = nil + } + } + + // Not Pluck + if sch != nil { + matchedFieldCount := make(map[string]int, len(columns)) + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ + fields[idx] = selectField + break + } + count-- + } + } + } else { + matchedFieldCount[column] = 1 + } + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] 
+ if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][]*schema.Field, len(columns)) + } + relFields = append(relFields, field) + joinFields[idx] = relFields + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array + ) + + if !update || reflectValue.Len() == 0 { + update = false + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else if !isArrayKind { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } + } + + for initialized || rows.Next() { + BEGIN: + initialized = false + + if update { + if int(db.RowsAffected) >= reflectValue.Len() { + return + } + elem = reflectValue.Index(int(db.RowsAffected)) + if onConflictDonothing { + for _, field := range fields { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { + db.RowsAffected++ + goto BEGIN + } + } + } + } else { + elem = reflect.New(reflectValueType) + } + + db.scanIntoStruct(rows, elem, values, fields, joinFields) + + if !update { + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } + } else { + reflectValue = reflect.Append(reflectValue, elem) + } + } + } + + if !update { + db.Statement.ReflectValue.Set(reflectValue) + } + case reflect.Struct, reflect.Ptr: + if initialized || rows.Next() { + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) + } + default: + db.AddError(rows.Scan(dest)) + } + } + + if err := rows.Err(); err != nil && err != db.Error { + db.AddError(err) + } + + if 
db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { + db.AddError(ErrRecordNotFound) + } +} diff --git a/scaner_test.go b/scaner_test.go deleted file mode 100644 index cd89ca49..00000000 --- a/scaner_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "testing" -) - -func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { - t.Errorf("Should create table with slice values correctly: %s", err) - } - - r1 := RecordWithSlice{ - Strings: ExampleStringSlice{"a", "b", "c"}, - Structs: ExampleStructSlice{ - {"name1", "value1"}, - {"name2", "value2"}, - }, - } - - if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") - } - - var r2 RecordWithSlice - - if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") - } - - if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { - t.Errorf("Should have serialised and deserialised a string array") - } - - if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { - t.Errorf("Should have serialised and deserialised a struct array") - } -} - -type RecordWithSlice struct { - ID uint64 - Strings ExampleStringSlice `sql:"type:text"` - Structs ExampleStructSlice `sql:"type:text"` -} - -type ExampleStringSlice []string - -func (l ExampleStringSlice) Value() (driver.Value, error) { - return json.Marshal(l) -} - -func (l *ExampleStringSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ExampleStruct struct { - Name string - Value string -} - -type ExampleStructSlice []ExampleStruct - -func 
(l ExampleStructSlice) Value() (driver.Value, error) { - return json.Marshal(l) -} - -func (l *ExampleStructSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go new file mode 100644 index 00000000..4583a207 --- /dev/null +++ b/schema/callbacks_test.go @@ -0,0 +1,39 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type UserWithCallback struct{} + +func (UserWithCallback) BeforeSave(*gorm.DB) error { + return nil +} + +func (UserWithCallback) AfterCreate(*gorm.DB) error { + return nil +} + +func TestCallback(t *testing.T) { + user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user with callback, got error %v", err) + } + + for _, str := range []string{"BeforeSave", "AfterCreate"} { + if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be true", str) + } + } + + for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { + if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be false", str) + } + } +} diff --git a/schema/check.go b/schema/check.go new file mode 100644 index 00000000..89e732d3 --- /dev/null +++ b/schema/check.go @@ -0,0 +1,35 @@ +package schema + +import ( + "regexp" + "strings" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() 
map[string]Check { + checks := map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/schema/check_test.go b/schema/check_test.go new file mode 100644 index 00000000..eda043b7 --- /dev/null +++ b/schema/check_test.go @@ -0,0 +1,55 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm/schema" +) + +type UserCheck struct { + Name string `gorm:"check:name_checker,name <> 'jinzhu'"` + Name2 string `gorm:"check:name <> 'jinzhu'"` + Name3 string `gorm:"check:,name <> 'jinzhu'"` +} + +func TestParseCheck(t *testing.T) { + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user check, got error %v", err) + } + + results := map[string]schema.Check{ + "name_checker": { + Name: "name_checker", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name2": { + Name: "chk_user_checks_name2", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name3": { + Name: "chk_user_checks_name3", + Constraint: "name <> 'jinzhu'", + }, + } + + checks := user.ParseCheckConstraints() + + for k, result := range results { + v, ok := checks[k] + if !ok { + t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) + } + + for _, name := range []string{"Name", "Constraint"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "check %v %v should equal, expects %v, got %v", + k, name, 
reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + } +} diff --git a/schema/field.go b/schema/field.go new file mode 100644 index 00000000..50fe8da1 --- /dev/null +++ b/schema/field.go @@ -0,0 +1,1014 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) + +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) + +// GORM time types +const ( + UnixTime TimeType = 1 + UnixSecond TimeType = 2 + UnixMillisecond TimeType = 3 + UnixNanosecond TimeType = 4 +) + +// GORM fields types +const ( + Bool DataType = "bool" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" +) + +// Field is the representation of model schema's field +type Field struct { + Name string // 结构体的名字 + DBName string // 结构体对应的 db COLUMN 名字 + BindNames []string // 带结构体层级的 Name, 然后是嵌套结构体,倒数第一个值是字段名,上一个值是上级结构体名 + DataType DataType // 表示数据库字段类型 + GORMDataType DataType // 用于处理数据库字段类型和 Golang 类型之间映射 + PrimaryKey bool // 该字段是否是主键 + AutoIncrement bool // 该字段是否自增 + AutoIncrementIncrement int64 // 自增开始值,用 AUTOINCREMENTINCREMENT 注解定义 + Creatable bool // 创建的时候可见 + Updatable bool // 更新的时候可见 + Readable bool // 读取的时候可见 + AutoCreateTime TimeType // 在创建的时候自动设置创建时间,及其设置形式 + AutoUpdateTime TimeType // 在创建和更新的时候自动设置更新时间,及其设置形式 + HasDefaultValue bool // 该字段是否有默认值,带有 default 注解,或者是自增的注解 + DefaultValue string // 该字段的默认值 + DefaultValueInterface interface{} // 解析后的默认值 + NotNull bool // 是否是 NOT NULL + Unique bool // 是否是唯一的 + Comment string // 表字段注释 + Size int 
// 字段的大小 + Precision int // 精度 + Scale int // 小数位数的精度 + IgnoreMigration bool // migration 时忽略该字段 + FieldType reflect.Type // 字段的类型,可能是指针 + IndirectFieldType reflect.Type // 字段的真实类型 + StructField reflect.StructField // 从当前字段所属结构体里面取出来的字段定义,如果是嵌套结构体,则 Index 会有多层 + Tag reflect.StructTag // 字段的 tag + TagSettings map[string]string // 从字段 gorm 注解里面解析出来的配置 + Schema *Schema // 字段所属的 model 结构体的 schema, (最外层) + EmbeddedSchema *Schema // 如果当前字段是一个嵌套结构体,其 Schema 保存在这里 + OwnerSchema *Schema // 嵌入的结构体解析出来的 Schema + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + // 该方法返回当前字段的 interface 值和是否是 zero, 如果当前 字段定义是嵌套结构体,会返回嵌套结构体的 Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface // 该字段配置的序列化器 + NewValuePool FieldNewValuePool +} + +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") +} + +// ParseField parses reflect.StructField to Field +func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") // 解析当前字段的 gorm 注解到 tagSetting map 里面 + ) + + field := &Field{ + Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, + Creatable: true, + Updatable: true, + Readable: true, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], + AutoIncrementIncrement: 1, + } + + for 
field.IndirectFieldType.Kind() == reflect.Ptr { // 如果字段是指针,会通过 Elem 拿到实际类型 + field.IndirectFieldType = field.IndirectFieldType.Elem() + } + + fieldValue := reflect.New(field.IndirectFieldType) // 创建一个实际类型实例 + // if field is valuer, used its value or first field as data type + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { // 如果实现了 driver.Valuer 接口 + if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { + fieldValue = reflect.ValueOf(v) // 如果没有实现 GormDataTypeInterface, 则当做 driver.Valuer 对待,调用 Value() 方法,获取 value + } + + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { // 如果当前值是结构体,并且不能被转换为 time.Time + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value // 解析结构体的所有字段的 gorm 注解,添加到 field.TagSettings 里面 + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } // 如果该类型是指针,取出实际类型 + + fieldValue = reflect.New(newFieldType) + if rvType != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) // 递归解析 + } + + if fieldValue.IsValid() { // 遇到第一个解析成功的类型,作为该字段类型 + return + } + } + } + } + + getRealFieldValue(fieldValue) + } + } + + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String // 如果实现了 SerializerInterface 接口,则将字段的数据类型设置为 String + field.Serializer = v + } else { + serializerName := field.TagSettings["JSON"] + if serializerName == "" { + serializerName = 
field.TagSettings["SERIALIZER"] + } // SERIALIZER 注解优先级比 JSON 注解高 + if serializerName != "" { // 如果配置了 JSON 或者 SERIALIZER 注解 + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String // 从全局注册的序列化器中根据名字找到对应的序列化器 + field.Serializer = serializer + } else { // 找不到序列化器,报错 + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } + } + } + + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { // 设置了 AUTOINCREMENTINCREMENT 注解,指定了自增的起始值 + field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) + } + + if v, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + field.DefaultValue = v // 配置了 DEFAULT 注解,设置默认值 + } + + if num, ok := field.TagSettings["SIZE"]; ok { + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 // 配置了 SIZE 注解,设置 Size + } + } + + if p, ok := field.TagSettings["PRECISION"]; ok { + field.Precision, _ = strconv.Atoi(p) // 精度 + } + + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) // 小数位数的精度 + } + + // default value is function or null or blank (primary keys) + field.DefaultValue = strings.TrimSpace(field.DefaultValue) + // 如果默认值包含 ( ), 或者是 null, "" , 不解析默认值 + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Bool: + field.DataType = Bool + if field.HasDefaultValue && !skipParseDefaultValue { // 解析默认值到 DefaultValueInterface + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.DataType = Int + if field.HasDefaultValue && 
!skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.DataType = Uint + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) + } + } + case reflect.Float32, reflect.Float64: + field.DataType = Float + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) + } + } + case reflect.String: + field.DataType = String + if field.HasDefaultValue && !skipParseDefaultValue { + field.DefaultValue = strings.Trim(field.DefaultValue, "'") + field.DefaultValue = strings.Trim(field.DefaultValue, `"`) + field.DefaultValueInterface = field.DefaultValue + } + case reflect.Struct: + if _, ok := fieldValue.Interface().(*time.Time); ok { // 各种形式的 time, 及其衍生类型 + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { + field.DataType = Time + } + if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { + if t, err := now.Parse(field.DefaultValue); err == nil { + field.DefaultValueInterface = t + } + } + case reflect.Array, reflect.Slice: + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { + field.DataType = Bytes + } + } + + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + field.DataType = 
DataType(dataTyper.GormDataType()) // 如果实现 GormDataTypeInterface ,可指定 DataType + } + + // 以下情况会自动设置创建时间 + // 1. 带有 AUTOCREATETIME 注解, + // 2. 属性名叫做:CreatedAt 并且类型在 (Time, Int, Uint) 里面 + if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if field.DataType == Time { + field.AutoCreateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + // 以下情况之一会在创建和更新的时候自动设置更新时间 + // 1. 带有 AUTOUPDATETIME 注解 + // 2. 名字为 UpdatedAt,并且类型在 (Time, Int, Uint) 里面 + if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if field.DataType == Time { + field.AutoUpdateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + + // 如果带了 TYPE 注解 + // 根据解析出来的 type 来设置 DataType + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + + if field.Size == 0 { // Size 没有设置, 根据数据类型生成 size + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + field.Size = 64 + case reflect.Int8, reflect.Uint8: + field.Size = 8 + case reflect.Int16, reflect.Uint16: + field.Size = 16 + case reflect.Int32, reflect.Uint32, reflect.Float32: 
+ field.Size = 32 + } + } + + // setup permission + if val, ok := field.TagSettings["-"]; ok { + val = strings.ToLower(strings.TrimSpace(val)) + switch val { + case "-": // 任何情况都忽略该字段 + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + case "all": // 任何情况都忽略该字段 + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + field.IgnoreMigration = true + case "migration": // 只在 migration 时忽略该字段 + field.IgnoreMigration = true + } + } + + if v, ok := field.TagSettings["->"]; ok { // 不可写,读取看配置 + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { // 配置先权限 + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { // 不能创建 + field.Creatable = false + } + + if !strings.Contains(v, "update") { // 不能更新 + field.Updatable = false + } + } + } + + // Normal anonymous field or having `EMBEDDED` tag + // 以下情况之一会当做 EMBEDDED model, + // 1. 带有 EMBEDDED 注解 + // 2. 
类型不为 (Time, Bytes), 并且没实现 driver.Valuer 接口,并且为嵌入字段,并且有(可创建,可更新,可读)权限之一) + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { + kind := reflect.Indirect(fieldValue).Kind() + switch kind { + case reflect.Struct: // 如果是结构体,是嵌套结构 + var err error + // 后续操作忽略该字段 + field.Creatable = false + field.Updatable = false + field.Readable = false + + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) + // 解析该嵌入类型的 schema + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + schema.err = err + } + + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { // 嵌套的是一个指针 + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) 
+ } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { + ef.DBName = prefix + ef.DBName // 如果定义了 EMBEDDEDPREFIX 注解,给 DBName 加一个前缀 + } + + if ef.PrimaryKey { + // 嵌套结构体被解析为主键(可能是名字叫 ID) + if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) { + // 只要不是显式有 PRIMARYKEY 注解,都不算注解 + ef.PrimaryKey = false + + // 没有显式定义 AUTOINCREMENT, 也不算自增 + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + // 由于 AUTOINCREMENT 会被当做有默认值,如果自增被取消了,这里的 HasDefaultValue 也要被取消 + if !ef.AutoIncrement && ef.DefaultValue == "" { + ef.HasDefaultValue = false + } + } + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v // 嵌套结构体字段的 tag Setting 也会收集到嵌套结构体的 TagSetting 里面 + } + } + case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: + schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + } + } + + return field +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // Setup NewValuePool + field.setupNewValuePool() + + // ValueOf returns field's value and if it is zero + fieldIndex := field.StructField.Index[0] + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: // 非嵌套结构体场景 + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: // 嵌套结构体 + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + // 嵌套结构体的 v 倒序存在 Index 里面 + for _, fieldIdx := range field.StructField.Index { + // 该字段是嵌套的, 传进来的 v 是最外层 model 结构体,Index 就是每一层对应的下标 + // 如果上一层是结构体 + if fieldIdx >= 0 { 
// 字段是一个结构体 + v = v.Field(fieldIdx) + } else { // 如果上一层是一个指针 + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } + } + } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } + } + + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer + } + + return &serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, zero + } + } + + // ReflectValueOf returns field's reflect value + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) + } + default: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } + } + } + return v + } + } + + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { + if v == nil { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + // Optimal value type acquisition for v + reflectValType := reflectV.Type() + + if reflectValType.AssignableTo(field.FieldType) { + if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr { + reflectV = reflect.Indirect(reflectV) + } + field.ReflectValueOf(ctx, value).Set(reflectV) + return + } else if reflectValType.ConvertibleTo(field.FieldType) { + 
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) + return + } else if field.FieldType.Kind() == reflect.Ptr { + fieldValue := field.ReflectValueOf(ctx, value) + fieldType := field.FieldType.Elem() + + if reflectValType.AssignableTo(fieldType) { + if !fieldValue.IsValid() { + fieldValue = reflect.New(fieldType) + } else if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldType)) + } + fieldValue.Elem().Set(reflectV) + return + } else if reflectValType.ConvertibleTo(fieldType) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldType)) + } + + fieldValue.Elem().Set(reflectV.Convert(fieldType)) + return + } + } + + if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return + } else { + err = setter(ctx, value, reflectV.Elem().Interface()) + } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(ctx, value, v) + } + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) + } + } + + return + } + + // Set + switch field.FieldType.Kind() { + case reflect.Bool: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) + } + case bool: + field.ReflectValueOf(ctx, value).SetBool(data) + case int64: + field.ReflectValueOf(ctx, value).SetBool(data > 0) + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(ctx, value).SetBool(b) + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Set = func(ctx context.Context, value reflect.Value, v 
interface{}) (err error) { + switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } + case **int: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case int64: + field.ReflectValueOf(ctx, value).SetInt(data) + case int: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int16: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint16: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case uint64: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case float32: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case float64: + field.ReflectValueOf(ctx, value).SetInt(int64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValueOf(ctx, value).SetInt(i) + } else { + return err + } + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + } else { + 
field.ReflectValueOf(ctx, value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) + } + } else { + field.ReflectValueOf(ctx, value).SetInt(0) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } + case **uint: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case uint64: + field.ReflectValueOf(ctx, value).SetUint(data) + case uint: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint16: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int64: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case int16: + field.ReflectValueOf(ctx, 
value).SetUint(uint64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case float32: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case float64: + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) + } + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValueOf(ctx, value).SetUint(i) + } else { + return err + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.Float32, reflect.Float64: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } + case **float32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(float64(**data)) + } + case float64: + field.ReflectValueOf(ctx, value).SetFloat(data) + case float32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int64: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int8: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int16: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case int32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint8: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + 
case uint16: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint32: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case uint64: + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) + case []byte: + return field.Set(ctx, value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValueOf(ctx, value).SetFloat(i) + } else { + return err + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + case reflect.String: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } + case string: + field.ReflectValueOf(ctx, value).SetString(data) + case []byte: + field.ReflectValueOf(ctx, value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) + case float64, float32: + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return err + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Elem().Interface().(type) { + case time.Time: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } + case time.Time: + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) + case *time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) + } else { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) + } + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) + } else { + return 
fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + case *time.Time: + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { + switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } + case time.Time: + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(ctx, value, v, field.Set) + } + return nil + } + default: + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(ctx, value, reflectV.Elem().Interface()) + } + } else { + fieldValue := field.ReflectValueOf(ctx, value) + if fieldValue.IsNil() { + 
fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = fieldValue.Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(ctx, value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) + } + } + } + } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, 
value).Set(reflect.ValueOf(s.Serializer).Elem()) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = si.Interface().(SerializerInterface) + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } +} + +func (field *Field) setupNewValuePool() { + if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + return &serializer{ + Field: field, + Serializer: si.Interface().(SerializerInterface), + } + }, + } + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} diff --git a/schema/field_test.go b/schema/field_test.go new file mode 100644 index 00000000..be9b50c2 --- /dev/null +++ b/schema/field_test.go @@ -0,0 +1,335 @@ +package schema_test + +import ( + "context" + "database/sql" + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestFieldValuerAndSetter(t *testing.T) { + var ( + p = &tests.User{} + userSchema, _ = schema.Parse(&p, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ + Model: gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, + }, + Name: "valuer_and_setter", + Age: 18, + Birthday: tests.Now(), + Active: true, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "updated_at": user.UpdatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + "active": true, + } + checkField(t, userSchema, reflectValue, values) + + 
var f *bool + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "updated_at": nil, + "deleted_at": time.Now(), + "age": 20, + "birthday": time.Now(), + "active": f, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + newValues["updated_at"] = time.Time{} + newValues["active"] = false + checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age := myint(10) + var nilTime *time.Time + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "updated_at": nilTime, + "deleted_at": time.Now(), + "age": &age, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + newValues2["updated_at"] = time.Time{} + checkField(t, userSchema, reflectValue, newValues2) +} + +func TestPointerFieldValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ + Model: &gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, + }, + Name: &name, + Age: &age, + Birthday: tests.Now(), + Active: &active, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + 
"active": true, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "deleted_at": time.Now(), + "age": 20, + "birthday": time.Now(), + "active": false, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age2 := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age2, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) +} + +func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ + ID: sql.NullInt64{Int64: 10, Valid: true}, + Name: &sql.NullString{String: name, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + RegisteredAt: mytime(time.Now()), + DeletedAt: &deletedAt, + Active: mybool(true), + Admin: &isAdmin, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "id": user.ID, + "name": user.Name, + "birthday": user.Birthday, + "registered_at": user.RegisteredAt, + "deleted_at": 
user.DeletedAt, + "active": user.Active, + "admin": user.Admin, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newDeletedAt := mytime(time.Now()) + newIsAdmin := mybool(true) + newValues := map[string]interface{}{ + "id": sql.NullInt64{Int64: 1, Valid: true}, + "name": &sql.NullString{String: name + "rename", Valid: true}, + "birthday": time.Now(), + "registered_at": mytime(time.Now()), + "deleted_at": &newDeletedAt, + "active": mybool(false), + "admin": &newIsAdmin, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues) + + newValues2 := map[string]interface{}{ + "id": 5, + "name": name + "rename2", + "birthday": time.Now(), + "registered_at": time.Now(), + "deleted_at": time.Now(), + "active": true, + "admin": false, + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) +} + +type UserWithPermissionControl struct { + ID uint + Name string `gorm:"-"` + Name2 string `gorm:"->"` + Name3 string `gorm:"<-"` + Name4 string `gorm:"<-:create"` + Name5 string `gorm:"<-:update"` + Name6 string `gorm:"<-:create,update"` + Name7 string `gorm:"->:false;<-:create,update"` + Name8 string `gorm:"->;-:migration"` +} + +func TestParseFieldWithPermission(t *testing.T) { + user, err := schema.Parse(&UserWithPermissionControl{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse user with permission, got error %v", err) + } + + fields := []*schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: 
true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, + {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, + {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, + {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true}, + {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, + {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + {Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true}, + } + + for _, f := range fields { + checkSchemaField(t, user, f, func(f *schema.Field) {}) + } +} + +type ( + ID int64 + INT int + INT8 int8 + INT16 int16 + INT32 int32 + INT64 int64 + UINT uint + UINT8 uint8 + UINT16 uint16 + UINT32 uint32 + UINT64 uint64 + FLOAT32 float32 + FLOAT64 float64 + BOOL bool + STRING string + TIME time.Time + BYTES []byte + + TypeAlias struct { + ID + INT `gorm:"column:fint"` + INT8 `gorm:"column:fint8"` + INT16 `gorm:"column:fint16"` + INT32 `gorm:"column:fint32"` + INT64 `gorm:"column:fint64"` + UINT `gorm:"column:fuint"` + 
UINT8 `gorm:"column:fuint8"` + UINT16 `gorm:"column:fuint16"` + UINT32 `gorm:"column:fuint32"` + UINT64 `gorm:"column:fuint64"` + FLOAT32 `gorm:"column:ffloat32"` + FLOAT64 `gorm:"column:ffloat64"` + BOOL `gorm:"column:fbool"` + STRING `gorm:"column:fstring"` + TIME `gorm:"column:ftime"` + BYTES `gorm:"column:fbytes"` + } +) + +func TestTypeAliasField(t *testing.T) { + alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err) + } + + fields := []*schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true}, + {Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`}, + {Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`}, + {Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`}, + {Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`}, + {Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`}, + {Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`}, + {Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: 
`gorm:"column:fuint8"`}, + {Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`}, + {Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`}, + {Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`}, + {Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`}, + {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, + {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, + {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, + {Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`}, + {Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`}, + } + + for _, f := range fields { + checkSchemaField(t, alias, f, func(f *schema.Field) {}) + } +} diff --git a/schema/index.go b/schema/index.go new file mode 100644 index 00000000..f5ac5dd2 --- /dev/null +++ b/schema/index.go @@ -0,0 +1,166 @@ +package schema + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // 
UNIQUE | FULLTEXT | SPATIAL + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string + Option string // WITH PARSER parser_name + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + priority int +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + indexes := map[string]Index{} + + for _, field := range schema.Fields { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { + fieldIndexes, err := parseFieldIndexes(field) + if err != nil { + schema.err = err + break + } + for _, index := range fieldIndexes { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } + if idx.Option == "" { + idx.Option = index.Option + } + + idx.Fields = append(idx.Fields, index.Fields...) 
+ sort.Slice(idx.Fields, func(i, j int) bool { + return idx.Fields[i].priority < idx.Fields[j].priority + }) + + indexes[index.Name] = idx + } + } + } + for _, index := range indexes { + if index.Class == "UNIQUE" && len(index.Fields) == 1 { + index.Fields[0].Field.Unique = true + } + } + return indexes +} + +func (schema *Schema) LookIndex(name string) *Index { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + } + + return nil +} + +func parseFieldIndexes(field *Field) (indexes []Index, err error) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUEINDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") + settings = ParseTagSetting(tagSetting, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) + ) + + if idx == -1 { + idx = len(tag) + } + + if idx != -1 { + name = tag[0:idx] + } + + if name == "" { + subName := field.Name + const key = "COMPOSITE" + if composite, found := settings[key]; found { + if len(composite) == 0 || composite == key { + err = fmt.Errorf( + "The composite tag of %s.%s cannot be empty", + field.Schema.Name, + field.Name) + return + } + subName = composite + } + name = field.Schema.namer.IndexName( + field.Schema.Table, subName) + } + + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + priority, err := strconv.Atoi(settings["PRIORITY"]) + if err != nil { + priority = 10 + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], + Option: settings["OPTION"], + 
Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Length: length, + priority: priority, + }}, + }) + } + } + } + + err = nil + return +} diff --git a/schema/index_test.go b/schema/index_test.go new file mode 100644 index 00000000..890327de --- /dev/null +++ b/schema/index_test.go @@ -0,0 +1,185 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm/schema" +) + +type UserIndex struct { + Name string `gorm:"index"` + Name2 string `gorm:"index:idx_name,unique"` + Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` + Name4 string `gorm:"uniqueIndex"` + Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` + Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` + Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` + OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` + MemberNumber string `gorm:"index:idx_id,priority:1"` + Name7 string `gorm:"index:type"` + + // Composite Index: Flattened structure. + Data0A string `gorm:"index:,composite:comp_id0"` + Data0B string `gorm:"index:,composite:comp_id0"` + + // Composite Index: Nested structure. + Data1A string `gorm:"index:,composite:comp_id1"` + CompIdxLevel1C + + // Composite Index: Unique and priority. 
+ Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"` + CompIdxLevel2C +} + +type CompIdxLevel1C struct { + CompIdxLevel1B + Data1C string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel1B struct { + Data1B string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel2C struct { + CompIdxLevel2B + Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"` +} + +type CompIdxLevel2B struct { + Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"` +} + +func TestParseIndex(t *testing.T) { + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index, got error %v", err) + } + + results := map[string]schema.Index{ + "idx_user_indices_name": { + Name: "idx_user_indices_name", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, + }, + "idx_name": { + Name: "idx_name", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}}, + }, + "idx_user_indices_name3": { + Name: "idx_user_indices_name3", + Type: "btree", + Where: "name3 != 'jinzhu'", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Name3"}, + Sort: "desc", + Collate: "utf8", + Length: 10, + }}, + }, + "idx_user_indices_name4": { + Name: "idx_user_indices_name4", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}}, + }, + "idx_user_indices_name5": { + Name: "idx_user_indices_name5", + Class: "FULLTEXT", + Comment: "hello , world", + Where: "age > 10", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, + }, + "profile": { + Name: "profile", + Comment: "hello , world", + Where: "age > 10", + Option: "WITH PARSER parser_name", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, { + Field: &schema.Field{Name: "Age"}, + Expression: "ABS(age)", + }}, + }, + "idx_id": { + Name: "idx_id", + Fields: []schema.IndexOption{{Field: 
&schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}}, + }, + "idx_oid": { + Name: "idx_oid", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}}, + }, + "type": { + Name: "type", + Type: "", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, + }, + "idx_user_indices_comp_id0": { + Name: "idx_user_indices_comp_id0", + Type: "", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data0A"}, + }, { + Field: &schema.Field{Name: "Data0B"}, + }}, + }, + "idx_user_indices_comp_id1": { + Name: "idx_user_indices_comp_id1", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data1A"}, + }, { + Field: &schema.Field{Name: "Data1B"}, + }, { + Field: &schema.Field{Name: "Data1C"}, + }}, + }, + "idx_user_indices_comp_id2": { + Name: "idx_user_indices_comp_id2", + Class: "UNIQUE", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data2C"}, + }, { + Field: &schema.Field{Name: "Data2A"}, + }, { + Field: &schema.Field{Name: "Data2B"}, + }}, + }, + } + + indices := user.ParseIndexes() + + for k, result := range results { + v, ok := indices[k] + if !ok { + t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) + } + + for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "index %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + + for idx, ef := range result.Fields { + rf := v.Fields[idx] + if rf.Field.Name != ef.Field.Name { + t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) + } + if rf.Field.Unique != ef.Field.Unique { + t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, 
ef.Field.Unique) + } + + for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { + if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { + t.Errorf( + "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, + reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), + ) + } + } + } + } +} diff --git a/schema/interfaces.go b/schema/interfaces.go new file mode 100644 index 00000000..a75a33c0 --- /dev/null +++ b/schema/interfaces.go @@ -0,0 +1,36 @@ +package schema + +import ( + "gorm.io/gorm/clause" +) + +// GormDataTypeInterface gorm data type interface +type GormDataTypeInterface interface { + GormDataType() string +} + +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface +type CreateClausesInterface interface { + CreateClauses(*Field) []clause.Interface +} + +// QueryClausesInterface query clauses interface +type QueryClausesInterface interface { + QueryClauses(*Field) []clause.Interface +} + +// UpdateClausesInterface update clauses interface +type UpdateClausesInterface interface { + UpdateClauses(*Field) []clause.Interface +} + +// DeleteClausesInterface delete clauses interface +type DeleteClausesInterface interface { + DeleteClauses(*Field) []clause.Interface +} diff --git a/schema/model_test.go b/schema/model_test.go new file mode 100644 index 00000000..9e6c3590 --- /dev/null +++ b/schema/model_test.go @@ -0,0 +1,64 @@ +package schema_test + +import ( + "database/sql" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +type User struct { + *gorm.Model + Name *string + Age *uint + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []*tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *uint + Manager *User + Team []*User 
`gorm:"foreignkey:ManagerID"` + Languages []*tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool +} + +type ( + mytime time.Time + myint int + mybool = bool +) + +type AdvancedDataTypeUser struct { + ID sql.NullInt64 + Name *sql.NullString + Birthday sql.NullTime + RegisteredAt mytime + DeletedAt *mytime + Active mybool + Admin *mybool +} + +type BaseModel struct { + ID uint + CreatedAt time.Time + CreatedBy *int + Created *VersionUser `gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/naming.go b/schema/naming.go new file mode 100644 index 00000000..dfd2b9ff --- /dev/null +++ b/schema/naming.go @@ -0,0 +1,188 @@ +package schema + +import ( + "crypto/sha1" + "encoding/hex" + "regexp" + "strings" + "unicode/utf8" + + "github.com/jinzhu/inflection" +) + +// Namer namer interface +type Namer interface { + // TableName 用于将结构体名称转换为表名。 + TableName(table string) string + // SchemaName 定的表名转换为对应的模式(schema)名称。 + SchemaName(table string) string + // ColumnName 用于将结构体字段名和表名转换为列名。 + ColumnName(table, column string) string + // JoinTableName 将指定的联接表名转换为对应的表名。 + JoinTableName(joinTable string) string + // RelationshipFKName 用于将指定的关系名称转换为对应的外键名称。 + RelationshipFKName(Relationship) string + // CheckerName 用于将指定的表名和列名转换为对应的检查约束名称。 + CheckerName(table, column string) string + // IndexName 用于将表名和列名转换为索引名。 + IndexName(table, column string) string +} + +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + +// NamingStrategy tables, columns naming strategy +type NamingStrategy struct { + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool +} + +// TableName convert string to table name +func (ns NamingStrategy) TableName(str 
string) string { + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName +func (ns NamingStrategy) SchemaName(table string) string { + table = strings.TrimPrefix(table, ns.TablePrefix) + + if ns.SingularTable { + return ns.toSchemaName(table) + } + return ns.toSchemaName(inflection.Singular(table)) +} + +// ColumnName convert string to column name +func (ns NamingStrategy) ColumnName(table, column string) string { + return ns.toDBName(column) +} + +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + if !ns.NoLowerCase && strings.ToLower(str) == str { + return ns.TablePrefix + str + } + + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return ns.formatName("chk", table, column) +} + +// IndexName generate index name +func (ns NamingStrategy) IndexName(table, column string) string { + return ns.formatName("idx", table, ns.toDBName(column)) +} + +func (ns NamingStrategy) formatName(prefix, table, name string) string { + formattedName := strings.ReplaceAll(strings.Join([]string{ + prefix, table, name, + }, "_"), ".", "_") + + if utf8.RuneCountInString(formattedName) > 64 { + h := sha1.New() + h.Write([]byte(formattedName)) + bs := h.Sum(nil) + + formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + } + return formattedName +} + +var ( + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = 
[]string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) +} + +func (ns NamingStrategy) toDBName(name string) string { + if name == "" { + return "" + } + + if ns.NameReplacer != nil { + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName + } + + if ns.NoLowerCase { + return name + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + return ret +} + +func (ns NamingStrategy) toSchemaName(name string) string { + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") + for _, initialism := range commonInitialisms { + result = 
regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + } + return result +} diff --git a/schema/naming_test.go b/schema/naming_test.go new file mode 100644 index 00000000..3f598c33 --- /dev/null +++ b/schema/naming_test.go @@ -0,0 +1,210 @@ +package schema + +import ( + "strings" + "testing" +) + +func TestToDBName(t *testing.T) { + maps := map[string]string{ + "": "", + "x": "x", + "X": "x", + "userRestrictions": "user_restrictions", + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "SKU_ID": "sku_id", + "FieldX": "field_x", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "SHA256Hash": "sha256_hash", + "SHA256HASH": "sha256_hash", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + ns := NamingStrategy{} + for key, value := range maps { + if ns.toDBName(key) != value { + t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) + } + } + + maps = map[string]string{ + "x": "X", + "user_restrictions": "UserRestriction", + "this_is_a_test": "ThisIsATest", + "abc_and_jkl": "AbcAndJkl", + "employee_id": "EmployeeID", + "field_x": "FieldX", + "http_and_smtp": "HTTPAndSMTP", + "http_server_handler_for_url_id": "HTTPServerHandlerForURLID", + "uuid": "UUID", + "http_url": "HTTPURL", + "sha256_hash": "Sha256Hash", + "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID", + } + for key, value := range maps { + if ns.SchemaName(key) != value { + t.Errorf("%v schema name should equal %v, but 
got %v", key, value, ns.SchemaName(key)) + } + } +} + +func TestNamingStrategy(t *testing.T) { + ns := NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: strings.NewReplacer("CID", "Cid"), + } + idxName := ns.IndexName("public.table", "name") + + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +type CustomReplacer struct { + f func(string) string +} + +func (r CustomReplacer) Replace(name string) string { + return r.f(name) +} + +func TestCustomReplacer(t *testing.T) { + ns := NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: false, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_replaced_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := 
ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.replaced_userlanguage" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.replaced_company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "replaced_name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestCustomReplacerWithNoLowerCase(t *testing.T) { + ns := NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: true, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_REPLACED_NAME" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.REPLACED_USER_LANGUAGES" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.REPLACED_USERLANGUAGE" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.REPLACED_COMPANY" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "REPLACED_NAME_Cid" { + t.Errorf("invalid column 
name generated, got %v", columdName) + } +} + +func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { + ns := NamingStrategy{} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + +func TestReplaceEmptyTableName(t *testing.T) { + ns := NamingStrategy{ + SingularTable: true, + NameReplacer: strings.NewReplacer("Model", ""), + } + tableName := ns.TableName("Model") + if tableName != "Model" { + t.Errorf("invalid table name generated, got %v", tableName) + } +} diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 00000000..fa62fe22 --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,19 @@ +package schema + +import ( + "reflect" + "sync" +) + +// sync pools +var ( + normalPool sync.Map + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/schema/relationship.go b/schema/relationship.go new file mode 100644 index 00000000..e03dcc52 --- /dev/null +++ b/schema/relationship.go @@ -0,0 +1,699 @@ +package schema + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/jinzhu/inflection" + "gorm.io/gorm/clause" +) + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + has RelationshipType = "has" +) + +type Relationships struct { + HasOne []*Relationship + 
BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships +} + +type Relationship struct { + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []*Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + foreignKeys, primaryKeys []string +} + +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string +} + +type Reference struct { + PrimaryKey *Field + PrimaryValue string + ForeignKey *Field + OwnPrimaryKey bool +} + +func (schema *Schema) parseRelation(field *Field) *Relationship { + var ( + err error + fieldValue = reflect.New(field.IndirectFieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), + } + ) + + cacheStore := schema.cacheStore + + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return nil + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) + } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { + schema.guessRelation(relation, field, guessBelongs) + } else { + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + schema.guessRelation(relation, field, guessGuess) + case reflect.Slice: + schema.guessRelation(relation, field, guessHas) + default: + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + } + } + + if relation.Type == has { + // don't add relations to embedded schema, which might be shared + if 
relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + } + + if schema.err == nil { + schema.setRelation(relation) + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } + + return relation +} + +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + } +} + +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// +// type 
User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, &Reference{ + PrimaryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + } + } + + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + return + } + + // use same data type for foreign keys + if copyableDataType(primaryKeyField.DataType) { + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + } + 
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, + }) + } + + relation.Type = has +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + err error + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) + ) + + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields + + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) + return + } + } + } + + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := strings.Title(schema.Name) + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = strings.Title(joinForeignKeys[idx]) + } + + ownFieldsMap[joinFieldName] = ownField + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, 
reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + } + + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: strings.Title(schema.Name) + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) + + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + 
Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + + // build references + for _, f := range relation.JoinTable.Fields { + if f.Creatable || f.Readable || f.Updatable { + // use same data type for foreign keys + if copyableDataType(fieldsMap[f.Name].DataType) { + f.DataType = fieldsMap[f.Name].DataType + } + f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + + if of, ok := ownFieldsMap[f.Name]; ok { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, + ForeignKey: f, + }) + } + } + } +} + +type guessLevel int + +const ( + guessGuess guessLevel = iota + guessBelongs + guessEmbeddedBelongs + guessHas + guessEmbeddedHas +) + +func (schema *Schema) guessRelation(relation *Relationship, 
field *Field, cgl guessLevel) { + var ( + primaryFields, foreignFields []*Field + primarySchema, foreignSchema = schema, relation.FieldSchema + gl = cgl + ) + + if gl == guessGuess { + if field.Schema == relation.FieldSchema { + gl = guessBelongs + } else { + gl = guessHas + } + } + + reguessOrErr := func() { + switch cgl { + case guessGuess: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: + default: + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + } + } + + switch gl { + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema == nil { + reguessOrErr() + return + } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema == nil { + reguessOrErr() + return + } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } + + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { + reguessOrErr() + return + } + foreignFields = append(foreignFields, f) + } + } else { + primarySchemaName := primarySchema.Name + if primarySchemaName == "" { + primarySchemaName = relation.FieldSchema.Name + } + + if len(relation.primaryKeys) > 0 { + for _, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + primaryFields = append(primaryFields, f) + } + } + } else { + primaryFields = primarySchema.PrimaryFields + } + + primaryFieldLoop: + for _, primaryField := range primaryFields 
{ + lookUpName := primarySchemaName + primaryField.Name + if gl == guessBelongs { + lookUpName = field.Name + primaryField.Name + } + + lookUpNames := []string{lookUpName} + if len(primaryFields) == 1 { + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + } + + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpField(name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } + } + } + + switch { + case len(foreignFields) == 0: + reguessOrErr() + return + case len(relation.primaryKeys) > 0: + for idx, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + if len(primaryFields) < idx+1 { + primaryFields = append(primaryFields, f) + } else if f != primaryFields[idx] { + reguessOrErr() + return + } + } else { + reguessOrErr() + return + } + } + case len(primaryFields) == 0: + if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { + primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) + } else if len(primarySchema.PrimaryFields) == len(foreignFields) { + primaryFields = append(primaryFields, primarySchema.PrimaryFields...) 
+ } else { + reguessOrErr() + return + } + } + + // build references + for idx, foreignField := range foreignFields { + // use same data type for foreign keys + if copyableDataType(primaryFields[idx].DataType) { + foreignField.DataType = primaryFields[idx].DataType + } + foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryFields[idx], + ForeignKey: foreignField, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), + }) + } + + if gl == guessHas || gl == guessEmbeddedHas { + relation.Type = has + } else { + relation.Type = BelongsTo + } +} + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + if rel.Type == BelongsTo { + for _, r := range rel.FieldSchema.Relationships.Relations { + if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + matched := true + for idx, ref := range r.References { + if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && + rel.References[idx].PrimaryValue == ref.PrimaryValue) { + matched = false + } + } + + if matched { + return nil + } + } + } + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + // optimize match english letters and midline + // The following code is basically called in for. + // In order to avoid the performance problems caused by repeated compilation of regular expressions, + // it only needs to be done once outside, so optimization is done here. 
+ if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + } + + return &constraint +} + +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + table = rel.JoinTable.Table + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, 
Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) + + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} + +func copyableDataType(str DataType) bool { + for _, s := range []string{"auto_increment", "primary key"} { + if strings.Contains(strings.ToLower(string(str)), s) { + return false + } + } + return true +} diff --git a/schema/relationship_test.go b/schema/relationship_test.go new file mode 100644 index 00000000..732f6f75 --- /dev/null +++ b/schema/relationship_test.go @@ -0,0 +1,787 @@ +package schema_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { + if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Errorf("Failed to parse schema, got error %v", err) + } else { + for _, rel := range relations { + checkSchemaRelation(t, s, rel) + } + } +} + +func TestBelongsToOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", 
Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestBelongsToWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestSelfReferentialBelongsTo(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatorID *int32 + Creator *User + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatorID", "User", "", false}}, + }) +} + +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + +func TestHasOneOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + 
gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasOneOverrideReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + ProfileID uint `gorm:"column:profile_id"` + Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: 
[]Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` + Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserReferID", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }, Relation{ + Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"}, + 
References: []Reference{ + {"Refer", "User", "User_refer_id", "user_profiles2", "", true}, + {"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false}, + }, + }) +} + +func TestMany2ManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;References:UserRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileUserRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + 
checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + +func TestBuildReadonlyMany2ManyRelation(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + +func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { + type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + } + + type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` + } + + checkStructRelation(t, &Blog{}, + Relation{ + Name: "Tags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "blog_tags", Table: "blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "blog_tags", "", true}, + {"ID", "Tag", "TagID", "blog_tags", "", false}, + {"Locale", "Tag", "TagLocale", "blog_tags", "", false}, + }, + }, + Relation{ + Name: "SharedTags", 
Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "shared_blog_tags", Table: "shared_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "shared_blog_tags", "", true}, + {"ID", "Tag", "TagID", "shared_blog_tags", "", false}, + }, + }, + Relation{ + Name: "LocaleTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "locale_blog_tags", Table: "locale_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "locale_blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "locale_blog_tags", "", true}, + {"ID", "Tag", "TagID", "locale_blog_tags", "", false}, + }, + }, + ) +} + +func TestMultipleMany2Many(t *testing.T) { + type Thing struct { + ID int + } + + type Person struct { + ID int + Likes []Thing `gorm:"many2many:likes"` + Dislikes []Thing `gorm:"many2many:dislikes"` + } + + checkStructRelation(t, &Person{}, + Relation{ + Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "likes", Table: "likes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "likes", "", true}, + {"ID", "Thing", "ThingID", "likes", "", false}, + }, + }, + Relation{ + Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "dislikes", "", true}, + {"ID", "Thing", "ThingID", "dislikes", "", false}, + }, + }, + ) +} + +func TestSelfReferentialMany2Many(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy int32 + Creators []User `gorm:"foreignKey:CreatedBy"` + AnotherPro interface{} `gorm:"-"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}}, + }) + + user, err := schema.Parse(&User{}, 
&sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema") + } + + relSchema := user.Relationships.Relations["Creators"].FieldSchema + if user != relSchema { + t.Fatalf("schema should be same, expects %p but got %p", user, relSchema) + } +} + +type CreatedByModel struct { + CreatedByID uint + CreatedBy *CreatedUser +} + +type CreatedUser struct { + gorm.Model + CreatedByModel +} + +func TestEmbeddedRelation(t *testing.T) { + checkStructRelation(t, &CreatedUser{}, Relation{ + Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser", + References: []Reference{ + {"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false}, + }, + }) + + userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema, got error %v", err) + } + + if len(userSchema.Relationships.Relations) != 1 { + t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations)) + } + + if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok { + if createdByRel.FieldSchema != userSchema { + t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema) + } + } else { + t.Fatalf("expects created by relations, but not found") + } +} + +func TestEmbeddedHas(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + type User struct { + ID int + Cat struct { + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } `gorm:"embedded;embeddedPrefix:cat_"` + Dog struct { + ID int + Name string + UserID int + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, 
s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) +} + +func TestEmbeddedBelongsTo(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type Address struct { + CountryID int + Country Country + } + type NestedAddress struct { + Address + } + type Org struct { + ID int + PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID int + Address struct { + ID int + Address + } + NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Errorf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "PostalAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "VisitingAddress": { + Relations: map[string]Relation{ + "Country": { + Name: 
"Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "NestedAddress": { + EmbeddedRelations: map[string]EmbeddedRelations{ + "Address": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + }, + }, + }) +} + +func TestVariableRelation(t *testing.T) { + var result struct { + User + } + + checkStructRelation(t, &result, Relation{ + Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account", + References: []Reference{ + {"ID", "", "UserID", "Account", "", true}, + }, + }) + + checkStructRelation(t, &result, Relation{ + Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company", + References: []Reference{ + {"ID", "Company", "CompanyID", "", "", false}, + }, + }) +} + +func TestSameForeignKey(t *testing.T) { + type UserAux struct { + gorm.Model + Aux string + UUID string + } + + type User struct { + gorm.Model + Name string + UUID string + Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"` + } + + checkStructRelation(t, &User{}, + Relation{ + Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", true}, + }, + }, + ) +} + +func TestBelongsToSameForeignKey(t *testing.T) { + type User struct { + gorm.Model + Name string + UUID string + } + + type UserAux struct { + gorm.Model + Aux string + UUID string + User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"` + } + + checkStructRelation(t, &UserAux{}, + Relation{ + Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", 
false}, + }, + }, + ) +} + +func TestHasOneWithSameForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + ProfileRefer int // not used in relationship + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}}, + }) +} + +func TestHasManySameForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + UserRefer uint + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +type Author struct { + gorm.Model +} + +type Book struct { + gorm.Model + Author Author + AuthorID uint +} + +func (Book) TableName() string { + return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" +} + +func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { + s, err := schema.Parse( + &Book{}, + &sync.Map{}, + schema.NamingStrategy{}, + ) + if err != nil { + t.Fatalf("Failed to parse schema") + } + + expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" + constraint := s.Relationships.Relations["Author"].ParseConstraint() + + if constraint.Name != expectedConstraintName { + t.Fatalf( + "expected constraint name %s, got %s", + expectedConstraintName, + constraint.Name, + ) + } +} diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 00000000..b0621d69 --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,370 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "go/ast" + "reflect" + "strings" + "sync" + + 
"gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + +type Schema struct { + Name string // model 结构体的 Name + ModelType reflect.Type // model 结构体的类型 + Table string // 该 schema 结构体对应的 db 的表名 + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + PrimaryFieldDBNames []string + Fields []*Field + FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue []*Field // fields with default value assigned by database + Relationships Relationships + CreateClauses []clause.Interface + QueryClauses []clause.Interface + UpdateClauses []clause.Interface + DeleteClauses []clause.Interface + BeforeCreate, AfterCreate bool + BeforeUpdate, AfterUpdate bool + BeforeDelete, AfterDelete bool + BeforeSave, AfterSave bool + AfterFind bool + err error + initialized chan struct{} + namer Namer + cacheStore *sync.Map +} + +func (schema Schema) String() string { + if schema.ModelType.Name() == "" { + return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) + } + return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) +} + +func (schema Schema) MakeSlice() reflect.Value { + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + return results +} + +func (schema Schema) LookUpField(name string) *Field { + if field, ok := schema.FieldsByDBName[name]; ok { + return field + } + if field, ok := schema.FieldsByName[name]; ok { + return field + } + return nil +} + +// LookUpFieldByBindName looks for the closest field in the embedded struct. 
+// +// type Struct struct { +// Embedded struct { +// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") +// } +// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") +// } +func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { + if len(bindNames) == 0 { + return nil + } + for i := len(bindNames) - 1; i >= 0; i-- { + find := strings.Join(bindNames[:i], ".") + "." + name + if field, ok := schema.FieldsByBindName[find]; ok { + return field + } + } + return nil +} + +type Tabler interface { + TableName() string +} + +type TablerWithNamer interface { + TableName(Namer) string +} + +// Parse get data type from dialector +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + return ParseWithSpecialTableName(dest, cacheStore, namer, "") +} + +// ParseWithSpecialTableName get data type from dialector with extra schema table +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + + value := reflect.ValueOf(dest) + if value.Kind() == reflect.Ptr && value.IsNil() { + value = reflect.New(value.Type().Elem()) // 如果是类型非空,但是指为空的指针,new 一个实例 + } + modelType := reflect.Indirect(value).Type() // 如果 dest 的 type 是指针,取出实际的类型 + + if modelType.Kind() == reflect.Interface { // 如果 dest 是一个接口,取出接口的实际类型 + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() // 如果是 slice 或者 array, 或者指针, 取出实际的类型,可以取多层 + } + + if modelType.Kind() != reflect.Struct { // 经过上面的处理,这里 modelType 一定是一个结构体了 + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, 
modelType.PkgPath(), modelType.Name()) + } + + // Cache the Schema for performance, + // Use the modelType or modelType + schemaTable (if it present) as cache key. + var schemaCacheKey interface{} + if specialTableName != "" { // 生成 model 缓存的 key, + schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) // 如果指定了别名,使用 type+别名作为 key + } else { + schemaCacheKey = modelType // 如果没指定别名,直接使用 modelType 作为 key + } + + // Load exist schema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { // 如果找到缓存,就直接用缓存 + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized // 缓存里面的 Schema 可能没初始化,需要等待初始化完成或者失败 + return s, s.err + } + + modelValue := reflect.New(modelType) // 根据结构体的 type, New 一个 结构体 + tableName := namer.TableName(modelType.Name()) // 调用 namer.TableName 生成一个表名 + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() // 如果 model 结构体实现了 Tabler 接口,优先使用 TableName 方法指定的名字 + } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) // 如果 model 结构体实现了 TablerWithNamer 接口,优先使用 TableName 方法指定的名字 + } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table // 如果这个结构体是一个嵌套结构体,使用所在结构体的 tableName + } + if specialTableName != "" && specialTableName != tableName { + tableName = specialTableName // 如果指定了 specialTableName,优先用指定的 specialTableName 作为 tableName + } + + schema := &Schema{ + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), + } + // When the schema initialization is completed, the channel will be closed + defer close(schema.initialized) + + // Load exist schema cache, return if exists + if v, ok := 
cacheStore.Load(schemaCacheKey); ok { // 再次检查,如果已经在缓存里面存在了,就等待初始化完成,然后返还结果 + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } + + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { // 解析每一个导出的字段 + if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { + schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) // 如果有嵌套结构体字段,将其所有字段的 schema 合并到当前结构体 + } else { + schema.Fields = append(schema.Fields, field) // 如果不是嵌套结构体,添加到 Fileds + } + } + } + + for _, field := range schema.Fields { + if field.DBName == "" && field.DataType != "" { + field.DBName = namer.ColumnName(schema.Table, field.Name) + } + + bindName := field.BindName() + if field.DBName != "" { + // nonexistence or shortest path or first appear prioritized if has permission + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { + if _, ok := schema.FieldsByDBName[field.DBName]; !ok { + schema.DBNames = append(schema.DBNames, field.DBName) + } + schema.FieldsByDBName[field.DBName] = field + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[bindName] = field + + if v != nil && v.PrimaryKey { + for idx, f := range schema.PrimaryFields { + if f == v { + schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) 
+ } + } + } + + if field.PrimaryKey { + schema.PrimaryFields = append(schema.PrimaryFields, field) + } + } + } + + if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByName[field.Name] = field + } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } + + field.setupValuerAndSetter() + } + + prioritizedPrimaryField := schema.LookUpField("id") + if prioritizedPrimaryField == nil { + prioritizedPrimaryField = schema.LookUpField("ID") + } + + if prioritizedPrimaryField != nil { + if prioritizedPrimaryField.PrimaryKey { + schema.PrioritizedPrimaryField = prioritizedPrimaryField + } else if len(schema.PrimaryFields) == 0 { + prioritizedPrimaryField.PrimaryKey = true + schema.PrioritizedPrimaryField = prioritizedPrimaryField + schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) + } + } + + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } + } + + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + + for _, field := range schema.Fields { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + } + + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.GORMDataType { + case Int, Uint: + if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = 
append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + field.AutoIncrement = true + } + } + } + + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) + } + } + } + + // Cache the schema + if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } else { + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + } + + fieldValue := reflect.New(field.IndirectFieldType) + fieldInterface := fieldValue.Interface() + if fc, ok := fieldInterface.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldInterface.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) 
+ } + + if fc, ok := fieldInterface.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldInterface.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } + } + + return schema, schema.err +} + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go new file mode 100644 index 00000000..605aa03a --- /dev/null +++ b/schema/schema_helper_test.go @@ -0,0 +1,242 @@ +package schema_test + +import ( + "context" + "fmt" + "reflect" + "strings" + "testing" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { + t.Run("CheckSchema/"+s.Name, func(t *testing.T) { + tests.AssertObjEqual(t, s, v, "Name", "Table") + + for idx, field := range primaryFields { + var found bool + for _, f := range s.PrimaryFields { + if f.Name == field { + found = true + } + } + + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + + if !found { + t.Errorf("schema %v failed to found primary key: %v", s, field) + } + } + }) +} + +func 
checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + t.Run("CheckField/"+f.Name, func(t *testing.T) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag.Get("gorm"), ";") + } else { + f.TagSettings = map[string]string{} + } + } + + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") + + if f.DBName != "" { + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + for _, name := range []string{f.DBName, f.Name} { + if name != "" { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } + }) +} + +type Relation struct { + Name string + Type schema.RelationshipType + Schema string + FieldSchema string + Polymorphic Polymorphic + JoinTable JoinTable + References []Reference +} + +type Polymorphic struct { + ID string + Type string + Value string +} + +type JoinTable struct { + Name string + Table string + Fields []schema.Field +} + +type Reference struct { + PrimaryKey string + PrimarySchema string + ForeignKey string + ForeignSchema string + PrimaryValue string + OwnPrimaryKey bool +} + +func 
checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { + t.Run("CheckRelation/"+relation.Name, func(t *testing.T) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) + } + + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) + } + + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name) + } + + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } + + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } + + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) + } + } + + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's 
reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool + for _, rf := range r.References { + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true + } + } + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } + } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) + } + }) +} + +type EmbeddedRelations struct { + Relations map[string]Relation + EmbeddedRelations map[string]EmbeddedRelations +} + +func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { + for name, relations := range actual { + rs := expected[name] + t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { + if len(relations.Relations) != len(rs.Relations) { + t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) + } + if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { + t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), 
len(relations.EmbeddedRelations)) + } + for n, rel := range relations.Relations { + if r, ok := rs.Relations[n]; !ok { + t.Errorf("failed to find relation by name %s", n) + } else { + checkSchemaRelation(t, &schema.Schema{ + Relationships: schema.Relationships{ + Relations: map[string]*schema.Relationship{n: rel}, + }, + }, r) + } + } + checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) + }) + } +} + +func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { + for k, v := range values { + t.Run("CheckField/"+k, func(t *testing.T) { + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) + tests.AssertEqual(t, v, fv) + }) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go new file mode 100644 index 00000000..5bc0fb83 --- /dev/null +++ b/schema/schema_test.go @@ -0,0 +1,336 @@ +package schema_test + +import ( + "strings" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestParseSchema(t *testing.T) { + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func TestParseSchemaWithPointerFields(t *testing.T) { + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func checkUserSchema(t *testing.T, user *schema.Schema) { + // check schema + checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedAt", DBName: 
"created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, + {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } + + // check relations + relations := []Relation{ + { + Name: "Account", Type: schema.HasOne, Schema: "User", FieldSchema: "Account", + References: []Reference{{"ID", "User", "UserID", "Account", "", true}}, + }, + { + Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", + References: []Reference{{"ID", "User", "UserID", "Pet", "", true}}, + }, + { + Name: "Toys", Type: schema.HasMany, Schema: "User", FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{{"ID", "User", "OwnerID", "Toy", "", true}, {"", "", "OwnerType", "Toy", "users", false}}, + }, + { + Name: "Company", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Company", + References: []Reference{{"ID", "Company", "CompanyID", "User", "", false}}, + }, + { + Name: "Manager", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: 
[]Reference{{"ID", "User", "ManagerID", "User", "", false}}, + }, + { + Name: "Team", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", true}}, + }, + { + Name: "Languages", Type: schema.Many2Many, Schema: "User", FieldSchema: "Language", + JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + { + Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}}, + }, + { + Name: "Friends", Type: schema.Many2Many, Schema: "User", FieldSchema: "User", + JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + { + Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, + }, + } + + for _, relation := range relations { + checkSchemaRelation(t, user, relation) + } +} + +func TestParseSchemaWithAdvancedDataType(t *testing.T) { + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to 
parse pointer user, got error %v", err) + } + + // check schema + checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} + +type CustomizeTable struct{} + +func (CustomizeTable) TableName() string { + return "customize" +} + +func TestCustomizeTableName(t *testing.T) { + customize, err := schema.Parse(&CustomizeTable{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + if customize.Table != "customize" { + t.Errorf("Failed to customize table with TableName method") + } +} + +func TestNestedModel(t *testing.T) { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse nested user, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: 
"CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} + +func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, 
DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: 
map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} + +func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { + type Product struct { + ProductID uint `gorm:"primaryKey;autoIncrement"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + type ProductNonAutoIncrement struct { + ProductID uint `gorm:"primaryKey;autoIncrement:false"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + prioritizedPrimaryField := schema.Field{ + Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"}, + } + + product.Fields = []*schema.Field{product.PrioritizedPrimaryField} + + checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + + productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err) + } + + if 
productNonAutoIncrement.PrioritizedPrimaryField != nil { + t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") + } +} diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 00000000..397edff0 --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,170 @@ +package schema + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/gob" + "encoding/json" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) + RegisterSerializer("gob", GobSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface 
interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct{} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) + } + + if len(bytes) > 0 { + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } + return nil, err + } + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct{} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil && t.Valid { + err = field.Set(ctx, dst, t.Time.Unix()) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + rv := reflect.ValueOf(fieldValue) + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.Indirect(rv).Int(), 0) + case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + if rv.IsZero() { + return nil, nil 
+ } + result = time.Unix(reflect.Indirect(rv).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} + +// GobSerializer gob serializer +type GobSerializer struct{} + +// Scan implements serializer interface +func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + if len(bytesValue) > 0 { + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return buf.Bytes(), err +} diff --git a/schema/utils.go b/schema/utils.go new file mode 100644 index 00000000..aaef7741 --- /dev/null +++ b/schema/utils.go @@ -0,0 +1,208 @@ +package schema + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +var embeddedCacheKey = "embedded_cache_store" + +func ParseTagSetting(str string, sep string) map[string]string { + settings := map[string]string{} + names := strings.Split(str, sep) // 按风格符分隔注解内容 + + for i := 0; i < len(names); i++ { + j := i + if len(names[j]) > 0 { // 跳过空内容(两个分隔符紧挨着)或者是注解是空的 + for { + if names[j][len(names[j])-1] == '\\' { // 如果第j行最后一个字符是 \, 和下一行合并 + i++ + names[j] = names[j][0:len(names[j])-1] + sep + names[i] + names[i] = "" + } else { + break + } + } + } + + values := strings.Split(names[j], ":") // 将解析出来的一组注解再使用 : 分隔 + k := 
strings.TrimSpace(strings.ToUpper(values[0])) // 将第一部分转大写,作为 k + + if len(values) >= 2 { // 如果是一对,就将 : 前面的部分作为 k, 后面的部分作为 Value, 存储到 settings 里面 + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k // 如果没有一对,则将 value 也存成 k, 存储到 settings 里面 + } + } + + return settings +} + +func toColumns(val string) (results []string) { + if val != "" { + for _, v := range strings.Split(val, ",") { + results = append(results, strings.TrimSpace(v)) + } + } + return +} + +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag +} + +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + +// GetRelationsValues get relations's values from a reflect value +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result.Addr()) + case reflect.Slice, reflect.Array: + for i := 0; i < result.Len(); i++ { + if elem := result.Index(i); elem.Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, elem) + } else { + reflectResults = reflect.Append(reflectResults, elem.Addr()) + } + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: 
+ for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +// GetIdentityFieldValuesMap get identity map from fields +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + loaded = map[interface{}]bool{} + notZero, zero bool + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + results[0][idx], zero = field.ValueOf(ctx, reflectValue) + notZero = notZero || !zero + } + + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + elem := reflectValue.Index(i) + elemKey := elem.Interface() + if elem.Kind() != reflect.Ptr && elem.CanAddr() { + elemKey = elem.Addr().Interface() + } + + if _, ok := loaded[elemKey]; ok { + continue + } + loaded[elemKey] = true + + fieldValues := make([]interface{}, len(fields)) + notZero = false + for idx, field := range fields { + fieldValues[idx], zero = field.ValueOf(ctx, elem) + notZero = notZero || !zero + } + + if notZero { + dataKey := utils.ToStringKey(fieldValues...) 
+ if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues) + dataResults[dataKey] = []reflect.Value{elem} + } else { + dataResults[dataKey] = append(dataResults[dataKey], elem) + } + } + } + } + + return dataResults, results +} + +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) + } + return resultsMap, results +} + +// ToQueryValues to query values +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues + } + + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + + return columns, queryValues +} + +type embeddedNamer struct { + Table string + Namer +} diff --git a/schema/utils_test.go b/schema/utils_test.go new file mode 100644 index 00000000..1b47ef25 --- /dev/null +++ b/schema/utils_test.go @@ -0,0 +1,24 @@ +package schema + +import ( + "reflect" + "testing" +) + +func TestRemoveSettingFromTag(t *testing.T) { + tags := map[string]string{ + `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + 
`gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + } + + for k, v := range tags { + if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v { + t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column")) + } + } +} diff --git a/scope.go b/scope.go deleted file mode 100644 index 23a5701b..00000000 --- a/scope.go +++ /dev/null @@ -1,1270 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - "time" - - "reflect" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) 
New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() sqlCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) 
-} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope 
*Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if 
scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - if expr, ok := value.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) 
- } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, 
value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := 
reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - selectFields []*Field - values = make([]interface{}, len(columns)) - selectedColumnsMap = map[string]int{} - resetFields = map[*Field]int{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[field] = index - } - - selectedColumnsMap[column] = fieldIndex - break - } - } - } - - scope.Err(rows.Scan(values...)) - - for field, index := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), 
scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - // if string is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", newScope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, 
scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - // is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } - return "" - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case 
interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", newScope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = 
scope.QuotedTableName() - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { - sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.countingQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := 
order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) 
- } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal { - hasUpdate = true - if err == ErrUnaddressable { - fmt.Println(err) - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) 
-} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !regexp.MustCompile("(?i)^count(.+)$").MatchString(fmt.Sprint(query)) { - scope.Search.Select("count(*)") - } - scope.Search.countingQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) 
- } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { - return false - } - return true && !scope.HasError() -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(query.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - query = 
query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - scope.Err(query.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - 
foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. 
- if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s 
ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } - } - - for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - } - - for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) 
- } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - for _, value := range values { - indirectValue := reflect.ValueOf(value) - for indirectValue.Kind() == reflect.Ptr { - indirectValue = indirectValue.Elem() - } - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - for _, column := range columns { - result = append(result, object.FieldByName(column).Interface()) - } - results = append(results, result) - } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, indirectValue.FieldByName(column).Interface()) - } - results = append(results, result) - } - } - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := 
indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} diff --git a/scope_test.go b/scope_test.go deleted file mode 100644 index 42458995..00000000 --- a/scope_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package gorm_test - -import ( - "github.com/jinzhu/gorm" - "testing" -) - -func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) -} - -func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) -} - -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { - return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) - } -} - -func TestScopes(t *testing.T) { - user1 := User{Name: "ScopeUser1", Age: 1} - user2 := User{Name: "ScopeUser2", Age: 1} - user3 := User{Name: "ScopeUser3", Age: 2} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users1, users2, users3 []User - DB.Scopes(NameIn1And2).Find(&users1) - if len(users1) != 2 { - t.Errorf("Should found two users's name in 1, 2") - } - - DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) - if len(users2) != 1 { - t.Errorf("Should found one user's name is 2") - } - - DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) - if len(users3) != 2 { - t.Errorf("Should found two users's name in 1, 3") - } -} diff --git a/search.go b/search.go deleted file mode 100644 index 8ddc5b29..00000000 --- a/search.go +++ /dev/null @@ -1,147 +0,0 @@ -package gorm - -import "fmt" - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName 
string - raw bool - Unscoped bool - countingQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query string, values ...interface{}) *search { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - 
-func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/search_test.go b/search_test.go deleted file mode 100644 index 4db7ab6a..00000000 --- a/search_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import ( - "reflect" - "testing" -) - -func TestCloneSearch(t *testing.T) { - s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") - - s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") - - if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { - t.Errorf("Where should be copied") - } - - if reflect.DeepEqual(s.orders, s1.orders) { - t.Errorf("Order should be copied") - } - - if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { - t.Errorf("InitAttrs should be copied") - } - - if reflect.DeepEqual(s.Select, s1.Select) { - t.Errorf("selectStr should be copied") - } -} diff --git a/soft_delete.go b/soft_delete.go new file mode 100644 index 00000000..5673d3b8 --- /dev/null +++ 
b/soft_delete.go @@ -0,0 +1,170 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + + "github.com/jinzhu/now" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. +func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (n DeletedAt) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Time) + } + return json.Marshal(nil) +} + +func (n *DeletedAt) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Valid = false + return nil + } + err := json.Unmarshal(b, &n.Time) + if err == nil { + n.Valid = true + } + return err +} + +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +func parseZeroValueTag(f *schema.Field) sql.NullString { + if v, ok := f.TagSettings["ZEROVALUE"]; ok { + if _, err := now.Parse(v); err == nil { + return sql.NullString{String: v, Valid: true} + } + } + return sql.NullString{Valid: false} +} + +type SoftDeleteQueryClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteQueryClause) Name() string { + return "" +} + +func (sd SoftDeleteQueryClause) Build(clause.Builder) { +} + +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = 
[]clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +type SoftDeleteUpdateClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteUpdateClause) Name() string { + return "" +} + +func (sd SoftDeleteUpdateClause) Build(clause.Builder) { +} + +func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +type SoftDeleteDeleteClause struct { + ZeroValue sql.NullString + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, 
queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build(stmt.DB.Callback().Update().Clauses...) + } +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..59c0b772 --- /dev/null +++ b/statement.go @@ -0,0 +1,728 @@ +package gorm + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Statement statement +type Statement struct { + *DB + TableExpr *clause.Expr + Table string + Model interface{} + Unscoped bool + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + BuildClauses []string + Distinct bool + Selects []string // selected columns + Omits []string // omit columns + Joins []join + Preloads map[string][]interface{} + Settings sync.Map + ConnPool ConnPool + Schema *schema.Schema + Context context.Context + RaiseErrorOnNotFound bool + SkipHooks bool + SQL strings.Builder + Vars []interface{} + CurDestIndex int + attrs []interface{} + assigns []interface{} + scopes []func(*DB) *DB +} + +type join struct { + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType +} + +// StatementModifier statement modifier interface +type 
StatementModifier interface { + ModifyStatement(*Statement) +} + +// WriteString write string +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) +} + +// WriteByte write byte +func (stmt *Statement) WriteByte(c byte) error { + return stmt.SQL.WriteByte(c) +} + +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) { + stmt.QuoteTo(&stmt.SQL, value) +} + +// QuoteTo write quoted value to writer +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + write := func(raw bool, str string) { + if raw { + writer.WriteString(str) + } else { + stmt.DB.Dialector.QuoteTo(writer, str) + } + } + + switch v := field.(type) { + case clause.Table: + if v.Name == clause.CurrentTable { + if stmt.TableExpr != nil { + stmt.TableExpr.Build(stmt) + } else { + write(v.Raw, stmt.Table) + } + } else { + write(v.Raw, v.Name) + } + + if v.Alias != "" { + writer.WriteByte(' ') + write(v.Raw, v.Alias) + } + case clause.Column: + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck + } + } else { + write(v.Raw, v.Name) + } + + if v.Alias != "" { + writer.WriteString(" AS ") + write(v.Raw, v.Alias) + } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteByte(',') + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') + case clause.Expr: + v.Build(stmt) + case string: + stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d 
:= range v { + if idx > 0 { + writer.WriteByte(',') + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') + default: + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) + } +} + +// Quote returns quoted value +func (stmt *Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() +} + +// AddVar add var +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { + for idx, v := range vars { + if idx > 0 { + writer.WriteByte(',') + } + + switch v := v.(type) { + case sql.NamedArg: + stmt.Vars = append(stmt.Vars, v.Value) + case clause.Column, clause.Table: + stmt.QuoteTo(writer, v) + case Valuer: + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { + stmt.AddVar(writer, nil) + } else { + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) + case clause.Expression: + v.Build(stmt) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []byte: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []interface{}: + if len(v) > 0 { + writer.WriteByte('(') + stmt.AddVar(writer, v...) 
+ writer.WriteByte(')') + } else { + writer.WriteString("(NULL)") + } + case *DB: + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + default: + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + writer.WriteString("(NULL)") + } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) { + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } else { + writer.WriteByte('(') + for i := 0; i < rv.Len(); i++ { + if i > 0 { + writer.WriteByte(',') + } + stmt.AddVar(writer, rv.Index(i).Interface()) + } + writer.WriteByte(')') + } + default: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } + } + } +} + +// AddClause add clause +func (stmt *Statement) AddClause(v clause.Interface) { + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } else { + name := v.Name() + c := stmt.Clauses[name] + c.Name = name + v.MergeClause(&c) + stmt.Clauses[name] = c + } +} + +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v 
clause.Interface) { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { + stmt.AddClause(v) + } +} + +// BuildCondition build condition +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { + if s, ok := query.(string); ok { + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { + return nil + } + + if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + + if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} + } + + if strings.Contains(strings.TrimSpace(s), " ") { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + + if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} + } + } + } + + conds := make([]clause.Expression, 0, 4) + args = append([]interface{}{query}, args...) 
+ for idx, arg := range args { + if arg == nil { + continue + } + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Expression: + conds = append(conds, v) + case *DB: + v.executeScopes() + + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + if where, ok := cs.Expression.(clause.Where); ok { + if len(where.Exprs) == 1 { + if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { + where.Exprs[0] = clause.AndConditions(orConds) + } + } + conds = append(conds, clause.And(where.Exprs...)) + } else { + conds = append(conds, cs.Expression) + } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } + } + case map[interface{}]interface{}: + for i, j := range v { + conds = append(conds, clause.Eq{Column: i, Value: j}) + } + case map[string]string: + keys := make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + case map[string]interface{}: + keys := make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: key, Values: values}) + } + default: + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + } + default: + reflectValue := 
reflect.Indirect(reflect.ValueOf(arg)) + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + } + } + + if restricted { + break + } + } else if !reflectValue.IsValid() { + stmt.AddError(ErrInvalidData) + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, 
valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + } + } + } + + return conds +} + +// Build build sql with clauses names +func (stmt *Statement) Build(clauses ...string) { + var firstClauseWritten bool + + for _, name := range clauses { + if c, ok := stmt.Clauses[name]; ok { + if firstClauseWritten { + stmt.WriteByte(' ') + } + + firstClauseWritten = true + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b(c, stmt) + } else { + c.Build(stmt) + } + } + } +} + +func (stmt *Statement) Parse(value interface{}) (err error) { + return stmt.ParseWithSpecialTableName(value, "") +} + +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + + stmt.Table = stmt.Schema.Table + } + return err +} + +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + TableExpr: stmt.TableExpr, + Table: stmt.Table, + Model: stmt.Model, + Unscoped: stmt.Unscoped, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, + Selects: stmt.Selects, + Omits: stmt.Omits, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + SkipHooks: stmt.SkipHooks, + } + + if stmt.SQL.Len() > 0 { + newStmt.SQL.WriteString(stmt.SQL.String()) + newStmt.Vars = make([]interface{}, 0, 
len(stmt.Vars)) + newStmt.Vars = append(newStmt.Vars, stmt.Vars...) + } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) + } + + if len(stmt.scopes) > 0 { + newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) + copy(newStmt.scopes, stmt.scopes) + } + + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + + return newStmt +} + +// SetColumn set column's value +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + stmt.AddError(field.Set(stmt.Context, destValue, value)) + default: + stmt.AddError(ErrInvalidData) + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) + } + } else { + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) + } + case reflect.Struct: + if 
!stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) + } + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := stmt.ReflectValue + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if mv, mok := stmt.Dest.(map[string]interface{}); mok { + if fv, ok := mv[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := mv[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } + } else { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + 
processColumn := func(column string, result bool) { + if stmt.Schema == nil { + results[column] = result + } else if column == "*" { + notRestricted = result + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = result + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = result + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { + if matches[2] == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else { + results[matches[2]] = result + } + } else { + results[column] = result + } + } + + // select columns + for _, column := range stmt.Selects { + processColumn(column, true) + } + + // omit columns + for _, column := range stmt.Omits { + processColumn(column, false) + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByName { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..648bc875 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,64 @@ +package gorm + +import ( + "fmt" + "reflect" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(Statement) + for w := 0; w < whereCount; w++ { + s = s.clone() + s.AddClause(clause.Where{ + Exprs: s.BuildCondition(fmt.Sprintf("where%d", w)), + }) + } + + s1 := s.clone() + s1.AddClause(clause.Where{ + Exprs: 
s.BuildCondition("FINAL1"), + }) + s2 := s.clone() + s2.AddClause(clause.Where{ + Exprs: s.BuildCondition("FINAL2"), + }) + + if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { + t.Errorf("Where conditions should be different") + } + }) + } +} + +func TestNilCondition(t *testing.T) { + s := new(Statement) + if len(s.BuildCondition(nil)) != 0 { + t.Errorf("Nil condition should be empty") + } +} + +func TestNameMatcher(t *testing.T) { + for k, v := range map[string][]string{ + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, + } { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + } + } +} diff --git a/test_all.sh b/test_all.sh deleted file mode 100755 index 6c5593b3..00000000 --- a/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "sqlite") - -for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test -done diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..08cb523c --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +go.sum diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..6ae3337f --- /dev/null +++ b/tests/README.md @@ -0,0 +1,10 @@ +# Test Guide + +```bash +cd tests +# prepare test databases +docker-compose up + +# run all tests +./tests_all.sh +``` diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go new file mode 100644 index 00000000..6befb5f2 --- /dev/null +++ 
b/tests/associations_belongs_to_test.go @@ -0,0 +1,308 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestBelongsToAssociation(t *testing.T) { + user := *GetUser("belongs-to", Config{Company: true, Manager: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + pointerOfUser := &user2 + if err := DB.Model(&pointerOfUser).Association("Company").Find(&user2.Company); err != nil { + t.Errorf("failed to query users, got error %#v", err) + } + user2.Manager = &User{} + DB.Model(&user2).Association("Manager").Find(user2.Manager) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Company", 1, "") + AssertAssociationCount(t, user, "Manager", 1, "") + + // Append + company := Company{Name: "company-belongs-to-append"} + manager := GetUser("manager-belongs-to-append", Config{}) + + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if company.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if manager.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company + user.Manager = manager + user.CompanyID = &company.ID + user.ManagerID = &manager.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") + AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") + + // Replace + company2 := Company{Name: "company-belongs-to-replace"} + manager2 := GetUser("manager-belongs-to-replace", Config{}) + + if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { + t.Fatalf("Error happened when 
replace Company, got %v", err) + } + + if company2.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { + t.Fatalf("Error happened when replace Manager, got %v", err) + } + + if manager2.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company2 + user.Manager = manager2 + user.CompanyID = &company2.ID + user.ManagerID = &manager2.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") + AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 0, "after delete") + + if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 1, "after prepare data") + AssertAssociationCount(t, user2, "Manager", 1, "after prepare 
data") + + // Clear + if err := DB.Model(&user2).Association("Company").Clear(); err != nil { + t.Errorf("Error happened when clear Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { + t.Errorf("Error happened when clear Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 0, "after clear") + AssertAssociationCount(t, user2, "Manager", 0, "after clear") + + // unexist company id + unexistCompanyID := company.ID + 9999999 + user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} + if err := DB.Create(&user).Error; err == nil { + tidbSkip(t, "not support the foreign key feature") + t.Errorf("should have gotten foreign key violation error") + } +} + +func TestBelongsToAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), + *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), + *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), + } + + DB.Create(&users) + + AssertAssociationCount(t, users, "Company", 3, "") + AssertAssociationCount(t, users, "Manager", 2, "") + + // Find + var companies []Company + if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { + t.Errorf("companies count should be %v, but got %v", 3, len(companies)) + } + + var managers []User + if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { + t.Errorf("managers count should be %v, but got %v", 2, len(managers)) + } + + // Append + DB.Model(&users).Association("Company").Append( + &Company{Name: "company-slice-append-1"}, + &Company{Name: "company-slice-append-2"}, + &Company{Name: "company-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Company", 3, "After Append") + + DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-1", Config{}), + GetUser("manager-slice-belongs-to-2", 
Config{}), + GetUser("manager-slice-belongs-to-3", Config{}), + ) + AssertAssociationCount(t, users, "Manager", 3, "After Append") + + if err := DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-test-1", Config{}), + ).Error; err == nil { + t.Errorf("unmatched length when update user's manager") + } + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { + t.Errorf("no error should happened when deleting company, but got %v", err) + } + + if users[0].CompanyID != nil || users[0].Company.ID != 0 { + t.Errorf("users[0]'s company should be deleted'") + } + + AssertAssociationCount(t, users, "Company", 2, "After Delete") + + // Clear + DB.Model(&users).Association("Company").Clear() + AssertAssociationCount(t, users, "Company", 0, "After Clear") + + DB.Model(&users).Association("Manager").Clear() + AssertAssociationCount(t, users, "Manager", 0, "After Clear") + + // shared company + company := Company{Name: "shared"} + if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { + t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) + } + + DB.Model(&users[0]).Association("Company").Delete(&company) + AssertAssociationCount(t, users[0], "Company", 0, "After Delete") + AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") +} + +func TestBelongsToDefaultValue(t *testing.T) { + type Org struct { + ID string + } + type BelongsToUser struct { + OrgID string + Org Org `gorm:"default:NULL"` + } + + tx := 
DB.Session(&gorm.Session{}) + tx.Config.DisableForeignKeyConstraintWhenMigrating = true + AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) + + tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) + tx.AutoMigrate(&BelongsToUser{}, &Org{}) + + user := &BelongsToUser{ + Org: Org{ + ID: "BelongsToUser_Org_1", + }, + } + err := DB.Create(&user).Error + AssertEqual(t, err, nil) +} + +func TestBelongsToAssociationUnscoped(t *testing.T) { + type ItemParent struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + } + type ItemChild struct { + gorm.Model + Name string `gorm:"type:varchar(50)"` + ItemParentID uint + ItemParent ItemParent + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) + tx.AutoMigrate(&ItemParent{}, &ItemChild{}) + + item := ItemChild{ + Name: "name", + ItemParent: ItemParent{ + Logo: "logo", + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + tx = tx.Debug() + + // test replace + if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ + Logo: "updated logo", + }); err != nil { + t.Errorf("failed to replace item parent, got error: %v", err) + } + + var parents []ItemParent + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 1 { + t.Errorf("expected %d parents, got %d", 1, len(parents)) + } + + // test delete + if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { + t.Errorf("failed to delete item parent, got error: %v", err) + } + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 0 { + t.Errorf("expected %d parents, got %d", 0, len(parents)) + } +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go new file mode 100644 index 
00000000..c31c4b40 --- /dev/null +++ b/tests/associations_has_many_test.go @@ -0,0 +1,547 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestHasManyAssociation(t *testing.T) { + user := *GetUser("hasmany", Config{Pets: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Pets").Find(&user2.Pets) + CheckUser(t, user2, user) + + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + + // Count + AssertAssociationCount(t, user, "Pets", 2, "") + + // Append + pet := Pet{Name: "pet-has-many-append"} + + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") + + pets2 := []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + for _, pet := range pets2 { + pet := pet + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = 
append(user.Pets, &pet) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") + + // Replace + pet2 := Pet{Name: "pet-has-many-replace"} + + if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + if pet2.ID == 0 { + t.Fatalf("pet2's ID should be created") + } + + user.Pets = []*Pet{&pet2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { + t.Fatalf("Error happened when delete pet, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { + t.Fatalf("Error happened when delete Pets, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { + t.Errorf("Error happened when clear Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 0, "after clear") +} + +func TestSingleTableHasManyAssociation(t *testing.T) { + user := *GetUser("hasmany", Config{Team: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Team").Find(&user2.Team) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Team", 2, "") + + // Append + team := *GetUser("team", Config{}) + + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + 
t.Fatalf("Error happened when append account, got %v", err) + } + + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 3, "AfterAppend") + + teams := []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + + if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + for _, team := range teams { + team := team + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") + + // Replace + team2 := *GetUser("team-replace", Config{}) + + if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + if team2.ID == 0 { + t.Fatalf("team2's ID should be created") + } + + user.Team = []User{team2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { + t.Fatalf("Error happened when delete Team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Team").Clear(); err != nil { + t.Errorf("Error happened when clear Team, got %v", err) + } + + 
AssertAssociationCount(t, user2, "Team", 0, "after clear") +} + +func TestHasManyAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-hasmany-1", Config{Pets: 2}), + *GetUser("slice-hasmany-2", Config{Pets: 0}), + *GetUser("slice-hasmany-3", Config{Pets: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Pets", 6, "") + + // Find + var pets []Pet + if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { + t.Errorf("pets count should be %v, but got %v", 6, len(pets)) + } + + // Append + DB.Model(&users).Association("Pets").Append( + &Pet{Name: "pet-slice-append-1"}, + []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &Pet{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Pets").Replace( + []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &Pet{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 4, "after delete") + + if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 2, "after delete") + + // Clear + DB.Model(&users).Association("Pets").Clear() + AssertAssociationCount(t, users, "Pets", 0, "After Clear") +} + +func TestSingleTableHasManyAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-hasmany-1", Config{Team: 2}), + *GetUser("slice-hasmany-2", Config{Team: 0}), + *GetUser("slice-hasmany-3", Config{Team: 4}), + 
} + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + DB.Model(&users).Association("Team").Append( + &User{Name: "pet-slice-append-1"}, + []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &User{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Team", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Team").Replace( + []*User{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &User{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Team", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} + +func TestPolymorphicHasManyAssociation(t *testing.T) { + user := *GetUser("hasmany", Config{Toys: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Toys").Find(&user2.Toys) + CheckUser(t, user2, user) + + // Count + 
AssertAssociationCount(t, user, "Toys", 2, "") + + // Append + toy := Toy{Name: "toy-has-many-append"} + + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") + + toys := []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + for _, toy := range toys { + toy := toy + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") + + // Replace + toy2 := Toy{Name: "toy-has-many-replace"} + + if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + user.Toys = []Toy{toy2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toys, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") + + // Clear + 
if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { + t.Errorf("Error happened when clear Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 0, "after clear") +} + +func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-hasmany-1", Config{Toys: 2}), + *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-3", Config{Toys: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Toys", 6, "") + + // Find + var toys []Toy + if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { + t.Errorf("toys count should be %v, but got %v", 6, len(toys)) + } + + // Append + DB.Model(&users).Association("Toys").Append( + &Toy{Name: "toy-slice-append-1"}, + []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Toys").Replace( + []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, + []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, + &Toy{Name: "toy-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 4, "after delete") + + if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 2, "after delete") + + // Clear + DB.Model(&users).Association("Toys").Clear() + AssertAssociationCount(t, users, "Toys", 0, "After Clear") +} + +func TestHasManyAssociationUnscoped(t *testing.T) { + type ItemContent struct { + 
gorm.Model + ItemID uint `gorm:"not null"` + Name string `gorm:"not null;type:varchar(50)"` + LanguageCode string `gorm:"not null;type:varchar(2)"` + } + type Item struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + Contents []ItemContent `gorm:"foreignKey:ItemID"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemContent{}, &Item{}) + tx.AutoMigrate(&ItemContent{}, &Item{}) + + item := Item{ + Logo: "logo", + Contents: []ItemContent{ + {Name: "name", LanguageCode: "en"}, + {Name: "ar name", LanguageCode: "ar"}, + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + // test Replace + if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{ + {Name: "updated name", LanguageCode: "en"}, + {Name: "ar updated name", LanguageCode: "ar"}, + {Name: "le nom", LanguageCode: "fr"}, + }); err != nil { + t.Errorf("failed to replace item content, got error: %v", err) + } + + if count := tx.Model(&item).Association("Contents").Count(); count != 3 { + t.Errorf("expected %d contents, got %d", 3, count) + } + + var contents []ItemContent + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 3 { + t.Errorf("expected %d contents, got %d", 3, len(contents)) + } + + // test delete + if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil { + t.Errorf("failed to delete Contents, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 2 { + t.Errorf("expected %d contents, got %d", 2, count) + } + + // test clear + if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil { + t.Errorf("failed to clear contents association, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 0 { + t.Errorf("expected %d contents, got %d", 0, 
count) + } + + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 0 { + t.Errorf("expected %d contents, got %d", 0, len(contents)) + } +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go new file mode 100644 index 00000000..a2c07509 --- /dev/null +++ b/tests/associations_has_one_test.go @@ -0,0 +1,257 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestHasOneAssociation(t *testing.T) { + user := *GetUser("hasone", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Account").Find(&user2.Account) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Account", 1, "") + + // Append + account := Account{Number: "account-has-one-append"} + + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if account.ID == 0 { + t.Fatalf("Account's ID should be created") + } + + user.Account = account + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Account", 1, "AfterAppend") + + // Replace + account2 := Account{Number: "account-has-one-replace"} + + if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + if account2.ID == 0 { + t.Fatalf("account2's ID should be created") + } + + user.Account = account2 + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { + t.Fatalf("Error happened when delete account, got %v", err) + } + AssertAssociationCount(t, user2, 
"Account", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { + t.Fatalf("Error happened when delete Account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 0, "after delete") + + // Prepare Data for Clear + account = Account{Number: "account-has-one-append"} + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Account").Clear(); err != nil { + t.Errorf("Error happened when clear Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 0, "after clear") +} + +func TestHasOneAssociationWithSelect(t *testing.T) { + user := *GetUser("hasone", Config{Account: true}) + + DB.Omit("Account.Number").Create(&user) + + AssertAssociationCount(t, user, "Account", 1, "") + + var account Account + DB.Model(&user).Association("Account").Find(&account) + if account.Number != "" { + t.Errorf("account's number should not be saved") + } +} + +func TestHasOneAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-hasone-1", Config{Account: true}), + *GetUser("slice-hasone-2", Config{Account: false}), + *GetUser("slice-hasone-3", Config{Account: true}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Account", 2, "") + + // Find + var accounts []Account + if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { + t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) + } + + // Append + DB.Model(&users).Association("Account").Append( + &Account{Number: "account-slice-append-1"}, + &Account{Number: "account-slice-append-2"}, + &Account{Number: "account-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Account", 3, "After Append") + + // Replace -> same as append + + // 
Delete + if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { + t.Errorf("no error should happened when deleting account, but got %v", err) + } + + AssertAssociationCount(t, users, "Account", 2, "after delete") + + // Clear + DB.Model(&users).Association("Account").Clear() + AssertAssociationCount(t, users, "Account", 0, "After Clear") +} + +func TestPolymorphicHasOneAssociation(t *testing.T) { + pet := Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + // Find + var pet2 Pet + DB.Find(&pet2, "id = ?", pet.ID) + DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) + CheckPet(t, pet2, pet) + + // Count + AssertAssociationCount(t, pet, "Toy", 1, "") + + // Append + toy := Toy{Name: "toy-has-one-append"} + + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + pet.Toy = toy + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") + + // Replace + toy2 := Toy{Name: "toy-has-one-replace"} + + if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + pet.Toy = toy2 + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") + + // Delete + if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") + + if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 0, "after delete") + + // 
Prepare Data for Clear + toy = Toy{Name: "toy-has-one-append"} + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") + + // Clear + if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { + t.Errorf("Error happened when clear Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 0, "after clear") +} + +func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { + pets := []Pet{ + {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, + {Name: "hasone-2", Toy: Toy{}}, + {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, + } + + DB.Create(&pets) + + // Count + AssertAssociationCount(t, pets, "Toy", 2, "") + + // Find + var toys []Toy + if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { + t.Errorf("toys count should be %v, but got %v", 3, len(toys)) + } + + // Append + DB.Model(&pets).Association("Toy").Append( + &Toy{Name: "toy-slice-append-1"}, + &Toy{Name: "toy-slice-append-2"}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, pets, "Toy", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, pets, "Toy", 2, "after delete") + + // Clear + DB.Model(&pets).Association("Toy").Clear() + AssertAssociationCount(t, pets, "Toy", 0, "After Clear") +} diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go new file mode 100644 index 00000000..b69d668a --- /dev/null +++ b/tests/associations_many2many_test.go @@ -0,0 +1,425 @@ +package tests_test + +import ( + "fmt" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestMany2ManyAssociation(t *testing.T) { + user := *GetUser("many2many", Config{Languages: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Languages").Find(&user2.Languages) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Languages", 2, "") + + // Append + language := Language{Code: "language-many2many-append", Name: "language-many2many-append"} + DB.Create(&language) + + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Languages = append(user.Languages, language) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") + + languages := []Language{ + {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, + {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, + } + DB.Create(&languages) + + if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = append(user.Languages, languages...) 
+ + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") + + // Replace + language2 := Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} + DB.Create(&language2) + + if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = []Language{language2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { + t.Fatalf("Error happened when delete language, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { + t.Fatalf("Error happened when delete Languages, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { + t.Errorf("Error happened when clear Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 0, "after clear") +} + +func TestMany2ManyOmitAssociations(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + user := *GetUser("many2many_omit_associations", Config{Languages: 2}) + + if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { + t.Fatalf("should raise error when create users without languages reference") + } + + if err := DB.Create(&user.Languages).Error; err != nil { + t.Fatalf("no error should happen when create languages, but got %v", err) + } + + if err := 
DB.Omit("Languages.*").Create(&user).Error; err != nil { + t.Fatalf("no error should happen when create user when languages exists, but got %v", err) + } + + // Find + var languages []Language + if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { + t.Errorf("languages count should be %v, but got %v", 2, len(languages)) + } + + newLang := Language{Code: "omitmany2many", Name: "omitmany2many"} + if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { + t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) + } +} + +func TestMany2ManyAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-many2many-1", Config{Languages: 2}), + *GetUser("slice-many2many-2", Config{Languages: 0}), + *GetUser("slice-many2many-3", Config{Languages: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Languages", 6, "") + + // Find + var languages []Language + if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { + t.Errorf("languages count should be %v, but got %v", 6, len(languages)) + } + + // Append + languages1 := []Language{ + {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, + } + languages2 := []Language{} + languages3 := []Language{ + {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, + {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, + } + DB.Create(&languages1) + DB.Create(&languages3) + + DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) + + AssertAssociationCount(t, users, "Languages", 9, "After Append") + + languages2_1 := []*Language{ + {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, + {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, + } + languages2_2 := []*Language{ + {Code: "language-slice-replace-2-1", Name: 
"language-slice-replace-2-1"}, + {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, + } + languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} + DB.Create(&languages2_1) + DB.Create(&languages2_2) + DB.Create(&languages2_3) + + // Replace + DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) + + AssertAssociationCount(t, users, "Languages", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { + t.Errorf("no error should happened when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 4, "after delete") + + if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { + t.Errorf("no error should happened when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 2, "after delete") + + // Clear + DB.Model(&users).Association("Languages").Clear() + AssertAssociationCount(t, users, "Languages", 0, "After Clear") +} + +func TestSingleTableMany2ManyAssociation(t *testing.T) { + user := *GetUser("many2many", Config{Friends: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Friends").Find(&user2.Friends) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Friends", 2, "") + + // Append + friend := *GetUser("friend", Config{}) + + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Friends = append(user.Friends, &friend) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") + + friends := 
[]*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + + if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = append(user.Friends, friends...) + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") + + // Replace + friend2 := *GetUser("friend-replace-2", Config{}) + + if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = []*User{&friend2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete friend, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { + t.Fatalf("Error happened when delete Friends, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { + t.Errorf("Error happened when clear Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 0, "after clear") +} + +func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-many2many-1", Config{Team: 2}), + *GetUser("slice-many2many-2", Config{Team: 0}), + *GetUser("slice-many2many-3", Config{Team: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + 
// Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + teams1 := []User{*GetUser("friend-append-1", Config{})} + teams2 := []User{} + teams3 := []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + + DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) + + AssertAssociationCount(t, users, "Team", 9, "After Append") + + teams2_1 := []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + teams2_2 := []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + teams2_3 := GetUser("friend-replace-3-1", Config{}) + + // Replace + DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) + + AssertAssociationCount(t, users, "Team", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happened when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happened when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} + +func TestDuplicateMany2ManyAssociation(t *testing.T) { + user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-2"}, + }} + + user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: 
"TestDuplicateMany2ManyAssociation-language-3"}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} + +func TestConcurrentMany2ManyAssociation(t *testing.T) { + db, err := OpenTestConnection() + if err != nil { + t.Fatalf("open test connection failed, err: %+v", err) + } + + count := 3 + + var languages []Language + for i := 0; i < count; i++ { + language := Language{Code: fmt.Sprintf("consurrent %d", i)} + db.Create(&language) + languages = append(languages, language) + } + + user := User{} + db.Create(&user) + db.Preload("Languages").FirstOrCreate(&user) + + var wg sync.WaitGroup + for i := 0; i < count; i++ { + wg.Add(1) + go func(user User, language Language) { + err := db.Model(&user).Association("Languages").Append(&language) + AssertEqual(t, err, nil) + + wg.Done() + }(user, languages[i]) + } + wg.Wait() + + var find User + err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error + AssertEqual(t, err, nil) + AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") +} + +func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { + user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + + user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + 
users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} diff --git a/tests/associations_test.go b/tests/associations_test.go new file mode 100644 index 00000000..4e8862e5 --- /dev/null +++ b/tests/associations_test.go @@ -0,0 +1,396 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { + if count := DB.Model(data).Association(name).Count(); count != result { + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + + var newUser User + if user, ok := data.(User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } else if user, ok := data.(*User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } + + if newUser.ID != 0 { + if count := DB.Model(&newUser).Association(name).Count(); count != result { + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + } +} + +func TestInvalidAssociation(t *testing.T) { + user := *GetUser("invalid", Config{Company: true, Manager: true}) + if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { + t.Fatalf("should return errors for invalid association, but got nil") + } +} + +func TestAssociationNotNullClear(t *testing.T) { + type Profile struct { + gorm.Model + Number string + MemberID uint `gorm:"not null"` + } + + type Member struct { + gorm.Model + Profiles []Profile + } + + 
DB.Migrator().DropTable(&Member{}, &Profile{}) + + if err := DB.AutoMigrate(&Member{}, &Profile{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := &Member{ + Profiles: []Profile{{ + Number: "1", + }, { + Number: "2", + }}, + } + + if err := DB.Create(&member).Error; err != nil { + t.Fatalf("Failed to create test data, got error: %v", err) + } + + if err := DB.Model(member).Association("Profiles").Clear(); err == nil { + t.Fatalf("No error occurred during clearind not null association") + } +} + +func TestForeignKeyConstraints(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type Profile struct { + ID uint + Name string + MemberID uint + } + + type Member struct { + ID uint + Refer uint `gorm:"uniqueIndex"` + Name string + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Refer: 1, Name: "foreign_key_constraints", Profile: Profile{Name: "my_profile"}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.MemberID != member.ID { + t.Fatalf("member id is not equal: expects: %v, got: %v", member.ID, profile.MemberID) + } + + member.Profile = Profile{} + DB.Model(&member).Update("Refer", 100) + + var profile2 Profile + if err := DB.First(&profile2, "id = ?", profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile2.MemberID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, profile2.MemberID) + } + + if r := DB.Delete(&member); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, 
r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile2, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} + +func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type Profile struct { + ID uint + Name string + Refer uint `gorm:"uniqueIndex"` + } + + type Member struct { + ID uint + Name string + ProfileID uint + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:ProfileID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Name: "foreign_key_constraints_belongs_to", Profile: Profile{Name: "my_profile_belongs_to", Refer: 1}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.Refer != member.ProfileID { + t.Fatalf("member id is not equal: expects: %v, got: %v", profile.Refer, member.ProfileID) + } + + DB.Model(&profile).Update("Refer", 100) + + var member2 Member + if err := DB.First(&member2, "id = ?", member.ID).Error; err != nil { + t.Fatalf("failed to find member, got error: %v", err) + } else if member2.ProfileID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, member2.ProfileID) + } + + if r := DB.Delete(&profile); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted 
profile") + } +} + +func TestFullSaveAssociations(t *testing.T) { + coupon := &Coupon{ + AppliesToProduct: []*CouponProduct{ + {ProductId: "full-save-association-product1"}, + }, + AmountOff: 10, + PercentOff: 0.0, + } + + err := DB. + Session(&gorm.Session{FullSaveAssociations: true}). + Create(coupon).Error + if err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if DB.First(&Coupon{}, "id = ?", coupon.ID).Error != nil { + t.Errorf("Failed to query saved coupon") + } + + if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", coupon.ID, "full-save-association-product1").Error != nil { + t.Errorf("Failed to query saved association") + } + + orders := []Order{{Num: "order1", Coupon: coupon}, {Num: "order2", Coupon: coupon}} + if err := DB.Create(&orders).Error; err != nil { + t.Errorf("failed to create orders, got %v", err) + } + + coupon2 := Coupon{ + AppliesToProduct: []*CouponProduct{{Desc: "coupon-description"}}, + } + + DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&coupon2) + var result Coupon + if err := DB.Preload("AppliesToProduct").First(&result, "id = ?", coupon2.ID).Error; err != nil { + t.Errorf("Failed to create coupon w/o name, got error: %v", err) + } + + if len(result.AppliesToProduct) != 1 { + t.Errorf("Failed to preload AppliesToProduct") + } +} + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + // Save and Updates is the same + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, 
Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} + +func TestAssociationError(t *testing.T) { + user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) + DB.Create(&user) + + var user1 User + DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1) + + var emptyUser User + var err error + // belongs to + err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has many + err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has one + err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // many to many + err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) +} + +type ( + myType string + emptyQueryClause struct { + Field *schema.Field + } +) + +func (myType) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{emptyQueryClause{Field: f}} +} + +func (sd emptyQueryClause) Name() string { + return "empty" +} + +func (sd emptyQueryClause) Build(clause.Builder) { +} + +func (sd emptyQueryClause) MergeClause(*clause.Clause) { +} + +func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) { + // do nothing +} + +func TestAssociationEmptyQueryClause(t *testing.T) { + type Organization struct { + gorm.Model + 
Name string + } + type Region struct { + gorm.Model + Name string + Organizations []Organization `gorm:"many2many:region_orgs;"` + } + type RegionOrg struct { + RegionId uint + OrganizationId uint + Empty myType + } + if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil { + t.Fatalf("Failed to set up join table, got error: %s", err) + } + if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %s", err) + } + if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + region := &Region{Name: "Region1"} + if err := DB.Create(region).Error; err != nil { + t.Fatalf("fail to create region %v", err) + } + var orgs []Organization + + if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil { + t.Fatalf("fail to find region organizations %v", err) + } else { + AssertEqual(t, len(orgs), 0) + } +} + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to 
find, got error: %v", err) + } + + AssertEqual(t, result, user) +} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go new file mode 100644 index 00000000..22d15898 --- /dev/null +++ b/tests/benchmark_test.go @@ -0,0 +1,84 @@ +package tests_test + +import ( + "fmt" + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func BenchmarkCreate(b *testing.B) { + user := *GetUser("bench", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + } +} + +func BenchmarkFind(b *testing.B) { + user := *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Find(&User{}, "id = ?", user.ID) + } +} + +func BenchmarkScan(b *testing.B) { + user := *GetUser("scan", Config{}) + DB.Create(&user) + + var u User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users where id = ?", user.ID).Scan(&u) + } +} + +func BenchmarkScanSlice(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + +func BenchmarkScanSlicePointer(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []*User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + +func BenchmarkUpdate(b *testing.B) { + user := *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Model(&user).Updates(map[string]interface{}{"Age": x}) + } +} + +func BenchmarkDelete(b *testing.B) { + user := *GetUser("find", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + DB.Delete(&user) + } +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go new file mode 100644 index 00000000..4479da4c --- /dev/null +++ b/tests/callbacks_test.go @@ -0,0 +1,208 
@@ +package tests_test + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" + + "gorm.io/gorm" +) + +func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { + var ( + got []string + funcs = reflect.ValueOf(v).Elem().FieldByName("fns") + ) + + for i := 0; i < funcs.Len(); i++ { + got = append(got, getFuncName(funcs.Index(i))) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc interface{}) string { + reflectValue, ok := fc.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(fc) + } + + fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*gorm.DB) {} +func c2(*gorm.DB) {} +func c3(*gorm.DB) {} +func c4(*gorm.DB) {} +func c5(*gorm.DB) {} +func c6(*gorm.DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before string + after string + remove bool + replace bool + err string + match func(*gorm.DB) bool + h func(*gorm.DB) + } + + datas := []struct { + callbacks []callback + err string + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: 
c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + err: "conflicting", + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c3", "c5", "c1", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + } + + for idx, data := range datas { + db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } + callbacks := db.Callback() + + for _, c := range data.callbacks { + var v interface{} = callbacks.Create() + callMethod := func(s interface{}, name string, args ...interface{}) { + var argValues []reflect.Value + for _, arg := range args { + argValues = append(argValues, reflect.ValueOf(arg)) + } + + results := reflect.ValueOf(s).MethodByName(name).Call(argValues) + if len(results) > 0 { + v = results[0].Interface() + } + } + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + callMethod(v, "Before", c.before) + } + + if c.after != "" { + 
callMethod(v, "After", c.after) + } + + if c.match != nil { + callMethod(v, "Match", c.match) + } + + if c.remove { + callMethod(v, "Remove", c.name) + } else if c.replace { + callMethod(v, "Replace", c.name, c.h) + } else { + callMethod(v, "Register", c.name, c.h) + } + + if e, ok := v.(error); !ok || e != nil { + err = e + } + } + + if len(data.err) > 0 && err == nil { + t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) + } else if len(data.err) == 0 && err != nil { + t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err) + } + + if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +} + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} diff --git a/tests/connection_test.go 
b/tests/connection_test.go new file mode 100644 index 00000000..7bc23009 --- /dev/null +++ b/tests/connection_test.go @@ -0,0 +1,46 @@ +package tests_test + +import ( + "fmt" + "testing" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func TestWithSingleConnection(t *testing.T) { + expectedName := "test" + var actualName string + + setSQL, getSQL := getSetSQL(DB.Dialector.Name()) + if len(setSQL) == 0 || len(getSQL) == 0 { + return + } + + err := DB.Connection(func(tx *gorm.DB) error { + if err := tx.Exec(setSQL, expectedName).Error; err != nil { + return err + } + + if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { + return err + } + return nil + }) + if err != nil { + t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) + } + + if actualName != expectedName { + t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) + } +} + +func getSetSQL(driverName string) (string, string) { + switch driverName { + case mysql.Dialector{}.Name(): + return "SET @testName := ?", "SELECT @testName" + default: + return "", "" + } +} diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 00000000..e0e1c771 --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,173 @@ +package tests_test + +import ( + "context" + "database/sql" + "os" + "reflect" + "testing" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) 
+} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) 
+} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? 
AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} diff --git a/tests/count_test.go b/tests/count_test.go new file mode 100644 index 00000000..b0dfb0b5 --- /dev/null +++ 
b/tests/count_test.go @@ -0,0 +1,191 @@ +package tests_test + +import ( + "fmt" + "regexp" + "sort" + "strings" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCountWithGroup(t *testing.T) { + DB.Create([]Company{ + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_b"}, + {Name: "company_count_group_c"}, + }) + + var count1 int64 + if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count1 != 1 { + t.Errorf("Count with group should be 1, but got count: %v", count1) + } + + var count2 int64 + if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count2 != 2 { + t.Errorf("Count with group should be 2, but got count: %v", count2) + } +} + +func TestCount(t *testing.T) { + var ( + user1 = *GetUser("count-1", Config{}) + user2 = *GetUser("count-2", Config{}) + user3 = *GetUser("count-3", Config{}) + users []User + count, count1, count2 int64 + ) + + DB.Save(&user1).Save(&user2).Save(&user3) + + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + + if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, 
len(users)) + } + + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("multiple count in chain should works") + } + + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) + tx.Count(&count1) + tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("count after new session should works") + } + + var count3 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("Error happened when count with group, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count for count with group, but got %v", count3) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + result := dryDB.Table("users").Select("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Table("users").Distinct("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + var count4 int64 + if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count4) + } + + var count5 int64 + if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } + + var count6 int64 + if err := DB.Model(&User{}).Where("name in ?", 
[]string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other", + ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count7 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", + ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { + t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count8 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", + ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + expects = []User{{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count9 int64 + if err := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("users") + }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + var count10 int64 + if err := DB.Model(&User{}).Select("*").Where("name in 
?", []string{user1.Name, user2.Name, user3.Name}).Count(&count10).Error; err != nil || count10 != 3 { + t.Fatalf("Count should be 3, but got count: %v err %v", count10, err) + } + + var count11 int64 + sameUsers := make([]*User, 0) + for i := 0; i < 3; i++ { + sameUsers = append(sameUsers, GetUser("count-4", Config{})) + } + DB.Create(sameUsers) + + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + t.Fatalf("Count should be 1, but got count: %v err %v", count11, err) + } + + var count12 int64 + if err := DB.Table("users"). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count12).Error; err == nil { + t.Errorf("error should raise when using preload without schema") + } + + var count13 int64 + if err := DB.Model(User{}). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count13).Error; err != nil { + t.Errorf("no error should raise when using count with preload, but got %v", err) + } +} diff --git a/tests/create_test.go b/tests/create_test.go new file mode 100644 index 00000000..02613b72 --- /dev/null +++ b/tests/create_test.go @@ -0,0 +1,617 @@ +package tests_test + +import ( + "errors" + "regexp" + "testing" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestCreate(t *testing.T) { + user := *GetUser("create", Config{}) + + if results := DB.Create(&user); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + if user.CreatedAt.IsZero() { + t.Errorf("user's created at should be not zero") + } + + if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should be not zero") + } + + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } +} + +func TestCreateInBatches(t *testing.T) { + users := []User{ + *GetUser("create_in_batches_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_in_batches_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_in_batches_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_in_batches_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_in_batches_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.CreateInBatches(&users, 2) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to 
fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateInBatchesWithDefaultSize(t *testing.T) { + users := []User{ + *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; 
err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + +func TestCreateWithAssociations(t *testing.T) { + user := *GetUser("create_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + +func 
TestBulkCreateWithAssociations(t *testing.T) { + users := []User{ + *GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + *GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } + + if results := DB.Create(&users); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != int64(len(users)) { + t.Fatalf("rows affected expects: %v, got %v", len(users), results.RowsAffected) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, user, user) + } + + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestBulkCreatePtrDataWithAssociations(t *testing.T) { + users := []*User{ + GetUser("bulk_ptr_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + 
GetUser("bulk_ptr_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + GetUser("bulk_ptr_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + GetUser("bulk_ptr_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, *user, *user) + } + + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, *users[idx]) + } +} + +func TestPolymorphicHasOne(t *testing.T) { + t.Run("Struct", func(t *testing.T) { + pet := Pet{ + Name: "PolymorphicHasOne", + Toy: Toy{Name: "Toy-PolymorphicHasOne"}, + } + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) + + t.Run("Slice", func(t *testing.T) { + pets := []Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, + }, { + 
Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var petIDs []uint + for _, pet := range pets { + petIDs = append(petIDs, pet.ID) + CheckPet(t, pet, pet) + } + + var pets2 []Pet + DB.Preload("Toy").Find(&pets2, "id IN ?", petIDs) + for idx, pet := range pets2 { + CheckPet(t, pet, pets[idx]) + } + }) + + t.Run("SliceOfPtr", func(t *testing.T) { + pets := []*Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, + }, { + Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) + + t.Run("Array", func(t *testing.T) { + pets := [...]Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, pet, pet) + } + }) + + t.Run("ArrayPtr", func(t *testing.T) { + pets := [...]*Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors 
happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) +} + +func TestCreateEmptyStruct(t *testing.T) { + type EmptyStruct struct { + ID uint + } + DB.Migrator().DropTable(&EmptyStruct{}) + + if err := DB.AutoMigrate(&EmptyStruct{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + if err := DB.Create(&EmptyStruct{}).Error; err != nil { + t.Errorf("No error should happen when creating user, but got %v", err) + } +} + +func TestCreateEmptySlice(t *testing.T) { + data := []User{} + if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } + + sliceMap := []map[string]interface{}{} + if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } +} + +func TestCreateInvalidSlice(t *testing.T) { + users := []*User{ + GetUser("invalid_slice_1", Config{}), + GetUser("invalid_slice_2", Config{}), + nil, + } + + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error invalid data when creating from slice that contains invalid data") + } +} + +func TestCreateWithExistingTimestamp(t *testing.T) { + user := User{Name: "CreateUserExistingTimestamp"} + curTime := now.MustParse("2016-01-01") + user.CreatedAt = curTime + user.UpdatedAt = curTime + DB.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + DB.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func TestCreateWithNowFuncOverride(t *testing.T) { + user := User{Name: "CreateUserTimestampOverride"} + curTime := now.MustParse("2016-01-01") + + NEW := DB.Session(&gorm.Session{ + NowFunc: func() time.Time { + return curTime + }, + }) + + NEW.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, 
user.UpdatedAt, curTime) + + var newUser User + NEW.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func TestCreateWithNoGORMPrimaryKey(t *testing.T) { + type JoinTable struct { + UserID uint + FriendID uint + } + + DB.Migrator().DropTable(&JoinTable{}) + if err := DB.AutoMigrate(&JoinTable{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + jt := JoinTable{UserID: 1, FriendID: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} + +func TestSelectWithCreate(t *testing.T) { + user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + + user.Birthday = nil + user.Pets = nil + user.Company = Company{} + user.Team = nil + user.Friends = nil + + CheckUser(t, user2, user) +} + +func TestOmitWithCreate(t *testing.T) { + user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) + + var result User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result, user.ID) + + user.Birthday = nil + user.Account = Account{} + user.Toys = nil + user.Manager = nil + + CheckUser(t, result, user) + + user2 := *GetUser("omit_create", 
Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit(clause.Associations).Create(&user2) + + var result2 User + DB.Preload(clause.Associations).First(&result2, user2.ID) + + user2.Account = Account{} + user2.Toys = nil + user2.Manager = nil + user2.Company = Company{} + user2.Pets = nil + user2.Team = nil + user2.Languages = nil + user2.Friends = nil + + CheckUser(t, result2, user2) +} + +func TestFirstOrCreateNotExistsTable(t *testing.T) { + company := Company{Name: "first_or_create_if_not_exists_table"} + if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil { + t.Errorf("not exists table, but err is nil") + } +} + +func TestFirstOrCreateWithPrimaryKey(t *testing.T) { + company := Company{ID: 100, Name: "company100_with_primarykey"} + DB.FirstOrCreate(&company) + + if company.ID != 100 { + t.Errorf("invalid primary key after creating, got %v", company.ID) + } + + companies := []Company{ + {ID: 101, Name: "company101_with_primarykey"}, + {ID: 102, Name: "company102_with_primarykey"}, + } + DB.Create(&companies) + + if companies[0].ID != 101 || companies[1].ID != 102 { + t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) + } +} + +func TestCreateFromSubQuery(t *testing.T) { + user := User{Name: "jinzhu"} + + DB.Create(&user) + + subQuery := DB.Table("users").Where("name=?", user.Name).Select("id") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&Pet{}).Create([]map[string]interface{}{ + { + "name": "cat", + "user_id": gorm.Expr("(?)", DB.Table("(?) as tmp", subQuery).Select("@uid:=id")), + }, + { + "name": "dog", + "user_id": gorm.Expr("@uid"), + }, + }) + + if !regexp.MustCompile(`INSERT INTO .pets. \(.name.,.user_id.\) .*VALUES \(.+,\(SELECT @uid:=id FROM \(SELECT id FROM .users. 
WHERE name=.+\) as tmp\)\),\(.+,@uid\)`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestCreateNilPointer(t *testing.T) { + var user *User + + err := DB.Create(user).Error + if err == nil || err != gorm.ErrInvalidValue { + t.Fatalf("it is not ErrInvalidValue") + } +} + +func TestFirstOrCreateRowsAffected(t *testing.T) { + user := User{Name: "TestFirstOrCreateRowsAffected"} + + res := DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 1 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } + + res = DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 0 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } +} + +func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { + type CompositeKeyProduct struct { + ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key + LanguageCode int `gorm:"primaryKey;"` // primary key + Code string + Name string + } + + if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + prod := &CompositeKeyProduct{ + LanguageCode: 56, + Code: "Code56", + Name: "ProductName56", + } + if err := DB.Create(&prod).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + newProd := &CompositeKeyProduct{} + if err := DB.First(&newProd).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name") + } +} + +func TestCreateOnConfilctWithDefalutNull(t *testing.T) { + type OnConfilctUser struct { + ID string + Name string `gorm:"default:null"` + Email string + Mobile string 
`gorm:"default:'133xxxx'"` + } + + err := DB.Migrator().DropTable(&OnConfilctUser{}) + AssertEqual(t, err, nil) + err = DB.AutoMigrate(&OnConfilctUser{}) + AssertEqual(t, err, nil) + + u := OnConfilctUser{ + ID: "on-confilct-user-id", + Name: "on-confilct-user-name", + Email: "on-confilct-user-email", + Mobile: "on-confilct-user-mobile", + } + err = DB.Create(&u).Error + AssertEqual(t, err, nil) + + u.Name = "on-confilct-user-name-2" + u.Email = "on-confilct-user-email-2" + u.Mobile = "" + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error + AssertEqual(t, err, nil) + + var u2 OnConfilctUser + err = DB.Where("id = ?", u.ID).First(&u2).Error + AssertEqual(t, err, nil) + AssertEqual(t, u2.Name, "on-confilct-user-name-2") + AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Mobile, "133xxxx") +} diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go new file mode 100644 index 00000000..7802eb11 --- /dev/null +++ b/tests/customize_field_test.go @@ -0,0 +1,192 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestCustomizeColumn(t *testing.T) { + type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` + } + + DB.Migrator().DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, "mapped_name = ?", "foo") + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, "mapped_id = ?", 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + // Make sure an ignored field does not interfere with another field's custom + // column name that matches the ignored field. 
+ type CustomColumnAndIgnoredFieldClash struct { + Body string `gorm:"-"` + RawBody string `gorm:"column:body"` + } + + DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) + + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { + t.Errorf("Should not raise error: %v", err) + } +} + +func TestCustomizeField(t *testing.T) { + type CustomizeFieldStruct struct { + gorm.Model + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` + } + + DB.Migrator().DropTable(&CustomizeFieldStruct{}) + + if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil { + t.Errorf("Failed to migrate, got error: %v", err) + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") { + t.Errorf("FieldIgnore should not be created") + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") { + t.Errorf("FieldIgnore should not be created") + } + + generateStruct := func(name string) *CustomizeFieldStruct { + return &CustomizeFieldStruct{ + Name: name, + FieldAllowCreate: name + "_allow_create", + FieldAllowUpdate: name + "_allow_update", + FieldAllowSave: name + "_allow_save", + FieldAllowSave2: name + "_allow_save2", + FieldAllowSave3: name + "_allow_save3", + FieldReadonly: name + "_allow_readonly", + FieldIgnore: name + "_allow_ignore", + } + } + + create := generateStruct("create") + DB.Create(&create) + + var result 
CustomizeFieldStruct + DB.Find(&result, "name = ?", "create") + + AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", "FieldAllowSave2") + + if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" { + t.Fatalf("invalid result: %#v", result) + } + + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { + t.Fatalf("invalid create/update unix time: %#v", result) + } + + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { + t.Fatalf("invalid create/update unix milli time: %#v", result) + } + + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { + t.Fatalf("invalid create/update unix nano time: %#v", result) + } + + result.FieldAllowUpdate = "field_allow_update_updated" + result.FieldReadonly = "field_readonly_updated" + result.FieldIgnore = "field_ignore_updated" + DB.Save(&result) + + var result2 CustomizeFieldStruct + DB.Find(&result2, "name = ?", "create") + + if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" { + t.Fatalf("invalid updated result: %#v", result2) + } + + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: create.FieldReadonly, FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err == nil { + t.Fatalf("Should failed to find result") + } + + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { + t.Fatalf("failed to update field_readonly column") + } + + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: "readonly", FieldIgnore: 
create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err != nil { + t.Fatalf("Should find result") + } + + var result3 CustomizeFieldStruct + DB.Find(&result3, "name = ?", "create") + + if result3.FieldReadonly != "readonly" { + t.Fatalf("invalid updated result: %#v", result3) + } + + var result4 CustomizeFieldStruct + if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil { + t.Fatalf("failed to query with inserted field, got error %v", err) + } + + AssertEqual(t, result3, result4) + + createWithDefaultTime := generateStruct("create_with_default_time") + createWithDefaultTime.AutoUnixCreateTime = 100 + createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixMilliCreateTime = 100 + createWithDefaultTime.AutoUnixMilliUpdateTime = 100 + createWithDefaultTime.AutoUnixNanoCreateTime = 100 + createWithDefaultTime.AutoUnixNanoUpdateTime = 100 + DB.Create(&createWithDefaultTime) + + var createWithDefaultTimeResult CustomizeFieldStruct + DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) + + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) + } + + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) + } + + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) + } +} diff --git a/tests/default_value_test.go b/tests/default_value_test.go new file mode 100644 
index 00000000..918f0796 --- /dev/null +++ b/tests/default_value_test.go @@ -0,0 +1,41 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" +) + +func TestDefaultValue(t *testing.T) { + type Harumph struct { + gorm.Model + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Created time.Time `gorm:"default:2000-01-02"` + Enabled bool `gorm:"default:true"` + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate with default value, got error: %v", err) + } + + harumph := Harumph{Email: "hello@gorm.io"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("Failed to create data with default value, got error: %v", err) + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { + t.Fatalf("Failed to create data with default value, got: %+v", harumph) + } + + var result Harumph + if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { + t.Fatalf("Failed to find created data, got error: %v", err) + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { + t.Fatalf("Failed to find created data with default data, got %+v", result) + } +} diff --git a/tests/delete_test.go b/tests/delete_test.go new file mode 100644 index 00000000..5cb4b91e --- /dev/null +++ b/tests/delete_test.go @@ -0,0 +1,258 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestDelete(t *testing.T) { + users := []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if res := DB.Delete(&users[1]); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when delete: %v, affected: %v", res.Error, res.RowsAffected) + } + + var result User + if err := DB.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + result = User{} + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + for _, user := range []User{users[0], users[2]} { + result = User{} + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := DB.Delete(&users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Delete(&User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } +} + +func TestDeleteWithTable(t *testing.T) { + type UserWithDelete struct { + gorm.Model + Name string + } + + DB.Table("deleted_users").Migrator().DropTable(UserWithDelete{}) + DB.Table("deleted_users").AutoMigrate(UserWithDelete{}) + + user := UserWithDelete{Name: "delete1"} + 
DB.Table("deleted_users").Create(&user) + + var result UserWithDelete + if err := DB.Table("deleted_users").First(&result).Error; err != nil { + t.Errorf("failed to find deleted user, got error %v", err) + } + + AssertEqual(t, result, user) + + if err := DB.Table("deleted_users").Delete(&result).Error; err != nil { + t.Errorf("failed to delete user, got error %v", err) + } + + var result2 UserWithDelete + if err := DB.Table("deleted_users").First(&result2, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } + + var result3 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result3, user.ID).Error; err != nil { + t.Fatalf("failed to find record, got error %v", err) + } + + if err := DB.Table("deleted_users").Unscoped().Delete(&result).Error; err != nil { + t.Errorf("failed to delete user with unscoped, got error %v", err) + } + + var result4 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result4, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } +} + +func TestInlineCondDelete(t *testing.T) { + user1 := *GetUser("inline_delete_1", Config{}) + user2 := *GetUser("inline_delete_2", Config{}) + DB.Save(&user1).Save(&user2) + + if DB.Delete(&User{}, user1.ID).Error != nil { + t.Errorf("No error should happen when delete a record") + } else if err := DB.Where("name = ?", user1.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("User can't be found after delete") + } + + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } else if err := DB.Where("name = ?", user2.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("User can't be found after delete") + } +} + +func TestBlockGlobalDelete(t *testing.T) { + if err := 
DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } +} + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteAssociationsWithUnscoped(t *testing.T) { + user := GetUser("unscoped_delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Unscoped().Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, 
"Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value 
{ + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +// only sqlite, postgres support returning +func TestSoftDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("delete-returning-1", Config{}), + GetUser("delete-returning-2", Config{}), + GetUser("delete-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} + +func TestDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + companies := []Company{ + {Name: "delete-returning-1"}, + {Name: "delete-returning-2"}, + {Name: "delete-returning-3"}, + } + DB.Create(&companies) + + var results []Company + DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} diff --git a/tests/distinct_test.go b/tests/distinct_test.go new file mode 100644 index 00000000..8c8298ae --- /dev/null +++ b/tests/distinct_test.go @@ -0,0 +1,74 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestDistinct(t *testing.T) { + users := []User{ + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct-2", Config{}), + *GetUser("distinct-3", Config{}), + } + users[0].Age = 20 + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + var names []string + DB.Table("users").Where("name like ?", "distinct%").Order("name").Pluck("name", &names) + AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + + var names1 []string + DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) + + AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + + var names2 []string + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("users") + }).Where("name like ?", "distinct%").Order("name").Pluck("name", &names2) + AssertEqual(t, names2, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + + var results []User + if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "distinct", Age: 20}, + {Name: "distinct", Age: 18}, + {Name: "distinct-2", Age: 18}, + {Name: "distinct-3", Age: 18}, + } + + if len(results) != 4 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + for idx, expect := range expects { + AssertObjEqual(t, results[idx], expect, "Name", "Age") + } + + var count int64 + if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { + t.Errorf("failed to query users count, got error: %v, count: %v", err, count) + } + + if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != 
nil || count != 3 { + t.Errorf("failed to query users count, got error: %v, count %v", err, count) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) + if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) + } +} diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 00000000..866a4d62 --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,36 @@ +version: '3' + +services: + mysql: + image: 'mysql/mysql-server:latest' + ports: + - "9910:3306" + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - "9920:5432" + environment: + - TZ=Asia/Shanghai + - POSTGRES_DB=gorm + - POSTGRES_USER=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' + ports: + - "9930:1433" + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 + tidb: + image: 'pingcap/tidb:v6.5.0' + ports: + - "9940:4000" + command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go new file mode 100644 index 00000000..3747dad9 --- /dev/null +++ b/tests/embedded_struct_test.go @@ -0,0 +1,221 @@ +package tests_test + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestEmbeddedStruct(t *testing.T) { + type ReadOnly struct { + ReadOnly *bool + } + + type BasePost struct { + Id int64 + Title string + URL string + ReadOnly + } + + type Author struct { + ID string + Name string + Email string + } + + type HNPost struct { + BasePost + Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct + Upvotes int32 + } + + type EngadgetPost struct { + BasePost BasePost `gorm:"Embedded"` + Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + ImageUrl string + } + + DB.Migrator().DropTable(&HNPost{}, &EngadgetPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}, &EngadgetPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + for _, name := range []string{"author_id", "author_name", "author_email"} { + if !DB.Migrator().HasColumn(&EngadgetPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + stmt := gorm.Statement{DB: DB} + if err := stmt.Parse(&EngadgetPost{}); err != nil { + t.Fatalf("failed to parse embedded struct") + } else if len(stmt.Schema.PrimaryFields) != 1 { + t.Errorf("should have only one primary field with embedded struct, but got %v", len(stmt.Schema.PrimaryFields)) + } + + for _, name := range []string{"user_id", "user_name", "user_email"} { + if !DB.Migrator().HasColumn(&HNPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + // save embedded struct + DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) + var news HNPost + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if news.Title != "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}) + DB.Save(&EngadgetPost{BasePost: 
BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}) + var egNews EngadgetPost + if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if egNews.BasePost.Title != "engadget_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + var egPosts []EngadgetPost + if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil { + t.Fatalf("no error should happen when query with embedded struct, but got %v", err) + } + expectAuthors := []string{"Edward", "George"} + for i, post := range egPosts { + t.Log(i, post.Author) + if want := expectAuthors[i]; post.Author.Name != want { + t.Errorf("expected author %s got %s", want, post.Author.Name) + } + } +} + +func TestEmbeddedPointerTypeStruct(t *testing.T) { + type BasePost struct { + Id int64 + Title string + URL string + } + + type Author struct { + ID string + Name string + Email string + Age int + } + + type HNPost struct { + *BasePost + Upvotes int32 + *Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author != nil { + t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) + } +} + +type Content struct { + Content interface{} `gorm:"type:String"` +} + +func (c Content) Value() (driver.Value, error) { + return json.Marshal(c) +} + +func (c *Content) Scan(src 
interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + + var value Content + if err := json.Unmarshal(b, &value); err != nil { + return err + } + + *c = value + + return nil +} + +func TestEmbeddedScanValuer(t *testing.T) { + type HNPost struct { + gorm.Model + Content + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + hnPost := HNPost{Content: Content{Content: "hello world"}} + + if err := DB.Create(&hnPost).Error; err != nil { + t.Errorf("Failed to create got error %v", err) + } +} + +func TestEmbeddedRelations(t *testing.T) { + type AdvancedUser struct { + User `gorm:"embedded"` + Advanced bool + } + + DB.Migrator().DropTable(&AdvancedUser{}) + + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { + if DB.Dialector.Name() != "sqlite" { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } + } +} + +func TestEmbeddedTagSetting(t *testing.T) { + type Tag1 struct { + Id int64 `gorm:"autoIncrement"` + } + type Tag2 struct { + Id int64 + } + + type EmbeddedTag struct { + Tag1 Tag1 `gorm:"Embedded;"` + Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"` + Name string + } + + DB.Migrator().DropTable(&EmbeddedTag{}) + err := DB.Migrator().AutoMigrate(&EmbeddedTag{}) + AssertEqual(t, err, nil) + + t1 := EmbeddedTag{Name: "embedded_tag"} + err = DB.Save(&t1).Error + AssertEqual(t, err, nil) + if t1.Tag1.Id == 0 { + t.Errorf("embedded struct's primary field should be rewrited") + } +} diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go new file mode 100644 index 00000000..ead26fce --- /dev/null +++ b/tests/error_translator_test.go @@ -0,0 +1,29 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +func TestDialectorWithErrorTranslatorSupport(t *testing.T) { + // it shouldn't translate error 
when the TranslateError flag is false + translatedErr := errors.New("translated error") + untranslatedErr := errors.New("some random error") + db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) + + err := db.AddError(untranslatedErr) + if errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) + } + + // it should translate error when the TranslateError flag is true + db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true}) + + err = db.AddError(untranslatedErr) + if !errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) + } +} diff --git a/tests/go.mod b/tests/go.mod new file mode 100644 index 00000000..edb715d5 --- /dev/null +++ b/tests/go.mod @@ -0,0 +1,17 @@ +module gorm.io/gorm/tests + +go 1.16 + +require ( + github.com/google/uuid v1.3.0 + github.com/jinzhu/now v1.1.5 + github.com/lib/pq v1.10.8 + github.com/mattn/go-sqlite3 v1.14.16 // indirect + gorm.io/driver/mysql v1.5.0 + gorm.io/driver/postgres v1.5.2 + gorm.io/driver/sqlite v1.5.0 + gorm.io/driver/sqlserver v1.4.3 + gorm.io/gorm v1.25.0 +) + +replace gorm.io/gorm => ../ diff --git a/tests/gorm_test.go b/tests/gorm_test.go new file mode 100644 index 00000000..9827465c --- /dev/null +++ b/tests/gorm_test.go @@ -0,0 +1,93 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestReturningWithNullToZeroValues(t *testing.T) { + dialect := DB.Dialector.Name() + switch dialect { + case "mysql", "sqlserver": + // these dialects do not support the "returning" clause + return + default: + // This user struct will leverage the existing users table, but override + // the Name field to default to null. 
+ type user struct { + gorm.Model + Name string `gorm:"default:null"` + } + u1 := user{} + + if results := DB.Create(&u1); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } + + got := user{} + results := DB.First(&got, "id = ?", u1.ID) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("first expects: %v, got %v", u1, got) + } + + results = DB.Select("id, name").Find(&got) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1, got) + } + + u1.Name = "jinzhu" + if results := DB.Save(&u1); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + u1 = user{} // important to reinitialize this before creating it again + u2 := user{} + db := DB.Session(&gorm.Session{CreateBatchSize: 10}) + + if results := db.Create([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } else if u2.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u2.ID) + } + + var gotUsers []user + results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, 
name").Find(&gotUsers) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) + } else if gotUsers[0].ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) + } else if gotUsers[1].ID != u2.ID { + t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) + } + + u1.Name = "Jinzhu" + u2.Name = "Zhang" + if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + } +} diff --git a/tests/group_by_test.go b/tests/group_by_test.go new file mode 100644 index 00000000..5335fed1 --- /dev/null +++ b/tests/group_by_test.go @@ -0,0 +1,109 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestGroupBy(t *testing.T) { + users := []User{{ + Name: "groupby", + Age: 10, + Birthday: Now(), + Active: true, + }, { + Name: "groupby", + Age: 20, + Birthday: Now(), + }, { + Name: "groupby", + Age: 30, + Birthday: Now(), + Active: true, + }, { + Name: "groupby1", + Age: 110, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 220, + Birthday: Now(), + Active: true, + }, { + Name: "groupby1", + Age: 330, + Birthday: Now(), + Active: true, + }} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + var name string + var total int + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", 
"groupby").Group("users.name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby1" || total != 660 { + t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) + } + + result := struct { + Name string + Total int64 + }{} + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Find(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + + var active bool + if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? 
and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || active != true || total != 40 { + t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) + } + + if DB.Dialector.Name() == "mysql" { + if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 330 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + } +} diff --git a/tests/helper_test.go b/tests/helper_test.go new file mode 100644 index 00000000..c34e357c --- /dev/null +++ b/tests/helper_test.go @@ -0,0 +1,274 @@ +package tests_test + +import ( + "os" + "sort" + "strconv" + "strings" + "testing" + "time" + + "gorm.io/gorm" + + . "gorm.io/gorm/utils/tests" +) + +type Config struct { + Account bool + Pets int + Toys int + Company bool + Manager bool + Team int + Languages int + Friends int + NamedPet bool +} + +func GetUser(name string, config Config) *User { + var ( + birthday = time.Now().Round(time.Second) + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) + + if config.Account { + user.Account = Account{Number: name + "_account"} + } + + for i := 0; i < config.Pets; i++ { + user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) + } + + for i := 0; i < config.Toys; i++ { + user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) + } + + if config.Company { + user.Company = Company{Name: "company-" + name} + } + + if config.Manager { + user.Manager = GetUser(name+"_manager", Config{}) + } + + for i := 0; i < config.Team; i++ { + user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + } + + for i := 0; i < 
config.Languages; i++ { + name := name + "_locale_" + strconv.Itoa(i+1) + language := Language{Code: name, Name: name} + user.Languages = append(user.Languages, language) + } + + for i := 0; i < config.Friends; i++ { + user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) + } + + if config.NamedPet { + user.NamedPet = &Pet{Name: name + "_namepet"} + } + + return &user +} + +func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, true) +} + +func CheckPet(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, false) +} + +func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) { + if pet.ID != 0 { + var newPet Pet + if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + } + } + + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + + AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") + + if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) + } +} + +func CheckUserUnscoped(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, true) +} + +func CheckUser(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, false) +} + +func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { + if user.ID != 0 { + var newUser User + if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", 
"Birthday", "CompanyID", "ManagerID", "Active") + } + } + + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + + t.Run("Account", func(t *testing.T) { + AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + + if user.Account.Number != "" { + if !user.Account.UserID.Valid { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + db(unscoped).First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + } + } + }) + + t.Run("Pets", func(t *testing.T) { + if len(user.Pets) != len(expect.Pets) { + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + } + + sort.Slice(user.Pets, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + sort.Slice(expect.Pets, func(i, j int) bool { + return expect.Pets[i].ID > expect.Pets[j].ID + }) + + for idx, pet := range user.Pets { + if pet == nil || expect.Pets[idx] == nil { + t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) + } else { + doCheckPet(t, *pet, *expect.Pets[idx], unscoped) + } + } + }) + + t.Run("Toys", func(t *testing.T) { + if len(user.Toys) != len(expect.Toys) { + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + } + + sort.Slice(user.Toys, func(i, j int) bool { + return user.Toys[i].ID > user.Toys[j].ID + }) + + sort.Slice(expect.Toys, func(i, j int) bool { + return expect.Toys[i].ID > expect.Toys[j].ID + }) + + for idx, toy := range user.Toys { + if toy.OwnerType != "users" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) + } + + AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") + } + }) + + t.Run("Company", func(t *testing.T) { + AssertObjEqual(t, 
user.Company, expect.Company, "ID", "Name") + }) + + t.Run("Manager", func(t *testing.T) { + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + db(unscoped).First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + }) + + t.Run("Team", func(t *testing.T) { + if len(user.Team) != len(expect.Team) { + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + } + + sort.Slice(user.Team, func(i, j int) bool { + return user.Team[i].ID > user.Team[j].ID + }) + + sort.Slice(expect.Team, func(i, j int) bool { + return expect.Team[i].ID > expect.Team[j].ID + }) + + for idx, team := range user.Team { + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) + + t.Run("Languages", func(t *testing.T) { + if len(user.Languages) != len(expect.Languages) { + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + } + + sort.Slice(user.Languages, func(i, j int) bool { + return strings.Compare(user.Languages[i].Code, user.Languages[j].Code) > 0 + }) + + sort.Slice(expect.Languages, func(i, j int) bool { + return strings.Compare(expect.Languages[i].Code, expect.Languages[j].Code) > 0 + }) + for idx, language := range user.Languages { + AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") + } + }) + + t.Run("Friends", func(t *testing.T) { + if len(user.Friends) != len(expect.Friends) { + 
t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + } + + sort.Slice(user.Friends, func(i, j int) bool { + return user.Friends[i].ID > user.Friends[j].ID + }) + + sort.Slice(expect.Friends, func(i, j int) bool { + return expect.Friends[i].ID > expect.Friends[j].ID + }) + + for idx, friend := range user.Friends { + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) +} + +func tidbSkip(t *testing.T, reason string) { + if isTiDB() { + t.Skipf("This test case skipped, because of TiDB '%s'", reason) + } +} + +func isTiDB() bool { + return os.Getenv("GORM_DIALECT") == "tidb" +} + +func db(unscoped bool) *gorm.DB { + if unscoped { + return DB.Unscoped() + } else { + return DB + } +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go new file mode 100644 index 00000000..0753dd0b --- /dev/null +++ b/tests/hooks_test.go @@ -0,0 +1,568 @@ +package tests_test + +import ( + "errors" + "reflect" + "strings" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +type Product struct { + gorm.Model + Name string + Code string + Price float64 + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return +} + +func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return +} + +func (s *Product) BeforeSave(tx *gorm.DB) (err error) { + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return +} + +func (s *Product) AfterFind(tx *gorm.DB) (err error) { + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 + return +} + +func (s *Product) AfterCreate(tx *gorm.DB) (err error) { + return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error +} + +func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 + return +} + +func (s *Product) AfterSave(tx *gorm.DB) (err error) { + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return +} + +func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete(tx *gorm.DB) (err error) { + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return +} + +func (s *Product) GetCallTimes() []int64 { 
+ return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) + + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) + } + + p.Price = 200 + DB.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 2 { + t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } + + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Fatalf("Can't find a deleted record") + } + + beforeCallTimes := p.AfterFindCallTimes + if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { + t.Fatalf("Find don't raise error when record not found") + } + + if p.AfterFindCallTimes != beforeCallTimes { + t.Fatalf("AfterFind should not be called") + } +} + 
+func TestCallbacksWithErrors(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Fatalf("An error from before create callbacks happened when create with invalid value") + } + + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Fatalf("Should not save record that have errors") + } + + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Fatalf("An error from after create callbacks happened when create with invalid value") + } + + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) + + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before update callbacks happened when update with invalid value") + } + + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before save callbacks happened when update with invalid value") + } + + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Fatalf("An error from before delete callbacks happened when delete") + } + + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Fatalf("An error from before delete callbacks happened") + } + + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Fatalf("Record should be reverted if get an error in after save callback") + } + + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", 
"after_delete_error").Error; err != nil { + t.Fatalf("Record should be found") + } + + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") + } +} + +type Product2 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product2) BeforeCreate(tx *gorm.DB) (err error) { + if !strings.HasSuffix(s.Name, "_clone") { + newProduft := s + newProduft.Price *= 2 + newProduft.Name += "_clone" + err = tx.Create(&newProduft).Error + } + + if s.Name == "Invalid" { + return errors.New("invalid") + } + + return nil +} + +func (s *Product2) BeforeUpdate(tx *gorm.DB) (err error) { + tx.Statement.Where("owner != ?", "admin") + return +} + +func TestUseDBInHooks(t *testing.T) { + DB.Migrator().DropTable(&Product2{}) + DB.AutoMigrate(&Product2{}) + + product := Product2{Name: "Invalid", Price: 100} + + if err := DB.Create(&product).Error; err == nil { + t.Fatalf("should returns error %v when creating product, but got nil", err) + } + + product2 := Product2{Name: "Nice", Price: 100} + + if err := DB.Create(&product2).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result Product2 + if err := DB.First(&result, "name = ?", "Nice").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + var resultClone Product2 + if err := DB.First(&resultClone, "name = ?", "Nice_clone").Error; err != nil { + t.Fatalf("Failed to find cloned product, got error: %v", err) + } + + result.Price *= 2 + result.Name += "_clone" + AssertObjEqual(t, result, resultClone, "Price", "Name") + + DB.Model(&result).Update("Price", 500) + var result2 Product2 + DB.First(&result2, "name = ?", "Nice") + + if result2.Price != 500 { + t.Errorf("Failed to update product's price, expects: %v, got %v", 500, result2.Price) + } + + product3 := Product2{Name: "Nice2", Price: 600, 
Owner: "admin"} + if err := DB.Create(&product3).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result3 Product2 + if err := DB.First(&result3, "name = ?", "Nice2").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + DB.Model(&result3).Update("Price", 800) + var result4 Product2 + DB.First(&result4, "name = ?", "Nice2") + + if result4.Price != 600 { + t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) + } +} + +type Product3 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { + tx.Statement.SetColumn("Price", s.Price+100) + return nil +} + +func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { + if tx.Statement.Changed() { + tx.Statement.SetColumn("Price", s.Price+10) + } + + if tx.Statement.Changed("Code") { + s.Price += 20 + tx.Statement.SetColumn("Price", s.Price+30) + } + return nil +} + +func TestSetColumn(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + product := Product3{Name: "Product", Price: 0} + DB.Create(&product) + + if product.Price != 100 { + t.Errorf("invalid price after create, got %+v", product) + } + + DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) + + if product.Price != 150 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code not changed, price should not change + DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) + + if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, but not selected, price should not change + DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) + + if product.Name != "Product 
New2" || product.Price != 170 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) + + if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result Product3 + DB.First(&result, product.ID) + + AssertEqual(t, result, product) + + // Select to change Code, but nothing updated, price should not change + DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"}) + + if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Updates(Product3{Code: "L1214"}) + if product.Price != 270 || product.Code != "L1214" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""}) + if product.Name != "Product New4" || product.Price != 320 || product.Code != "" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) + if product.Price != 320 || product.Code != "L1215" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) + if product.Price != 320 || product.Code != "L1216" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result2 Product3 + DB.First(&result2, product.ID) + + AssertEqual(t, result2, product) + + product2 := Product3{Name: "Product", Price: 0} + DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) + + if product2.Price != 0 { + t.Errorf("invalid price after create without hooks, got 
%+v", product2) + } +} + +func TestHooksForSlice(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + products := []*Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products) + + for idx, value := range []int64{200, 300, 400} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + DB.Model(&products).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + products2 := []Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products2) + + for idx, value := range []int64{200, 300, 400} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } + + DB.Model(&products2).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } +} + +type Product4 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string + Item ProductItem +} + +type ProductItem struct { + gorm.Model + Code string + Product4ID uint + AfterFindCallTimes int +} + +func (pi ProductItem) BeforeCreate(*gorm.DB) error { + if pi.Code == "invalid" { + return errors.New("invalid item") + } + return nil +} + +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + +func 
TestFailedToSaveAssociationShouldRollback(t *testing.T) { + DB.Migrator().DropTable(&Product4{}, &ProductItem{}) + DB.AutoMigrate(&Product4{}, &ProductItem{}) + + product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}} + if err := DB.Create(&product).Error; err == nil { + t.Errorf("should got failed to save, but error is nil") + } + + if DB.First(&Product4{}, "name = ?", product.Name).Error == nil { + t.Errorf("should got RecordNotFound, but got nil") + } + + product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}} + if err := DB.Create(&product).Error; err != nil { + t.Errorf("should create product, but got error %v", err) + } + + if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } +} + +type Product5 struct { + gorm.Model + Name string +} + +var beforeUpdateCall int + +func (p *Product5) BeforeUpdate(*gorm.DB) error { + beforeUpdateCall = beforeUpdateCall + 1 + return nil +} + +func TestUpdateCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product5{}) + DB.AutoMigrate(&Product5{}) + + p := Product5{Name: "unique_code"} + DB.Model(&Product5{}).Create(&p) + + err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should be called") + } + + err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + 
t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should not be called") + } + + err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should be called") + } + + err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should not be called") + } +} diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go new file mode 100644 index 00000000..084c2f2c --- /dev/null +++ b/tests/joins_table_test.go @@ -0,0 +1,116 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Person struct { + ID int + Name string + Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt +} + +type Address struct { + ID uint + Name string +} + +type PersonAddress struct { + PersonID int + AddressID int + CreatedAt time.Time + DeletedAt gorm.DeletedAt +} + +func TestOverrideJoinTable(t *testing.T) { + DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) + + if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { + t.Fatalf("Failed to setup join table for person, got error %v", err) + } + + if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { + t.Fatalf("Failed to migrate, got %v", err) + } + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + person := Person{Name: "person", Addresses: []Address{address1, address2}} + DB.Create(&person) + + var addresses1 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, 
len(addresses1)) + } + + if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { + t.Fatalf("Failed to delete address, got error %v", err) + } + + if len(person.Addresses) != 1 { + t.Fatalf("Should have one address left") + } + + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { + t.Fatalf("Should found one address") + } + + var addresses2 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) + } + + if DB.Model(&person).Association("Addresses").Count() != 1 { + t.Fatalf("Should found one address") + } + + var addresses3 []Address + if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) + } + + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Model(&person).Association("Addresses").Clear() + + if DB.Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("Should deleted all addresses") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Unscoped().Model(&person).Association("Addresses").Clear() + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("address should be deleted when clear with unscoped") + } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if 
err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go new file mode 100644 index 00000000..786fc37e --- /dev/null +++ b/tests/joins_test.go @@ -0,0 +1,402 @@ +package tests_test + +import ( + "regexp" + "sort" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestJoins(t *testing.T) { + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + CheckUser(t, user2, user) +} + +func TestJoinsForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-joins-1", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-2", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-3", Config{Company: true, Manager: true, Account: true}), + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + 
sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + CheckUser(t, user, users2[idx]) + } +} + +func TestJoinConds(t *testing.T) { + user := *GetUser("joins-conds", Config{Account: true, Pets: 3}) + DB.Save(&user) + + var users1 []User + DB.Joins("inner join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + if len(users1) != 3 { + t.Errorf("should find two users using left join, but got %v", len(users1)) + } + + var users2 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) + } + + var users3 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) + } + + var users4 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) + } + + var users5 []User + db5 := DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + if db5.Error != nil { + t.Errorf("Should not raise 
error for join where identical fields in different tables. Error: %s", db5.Error.Error()) + } + + var users6 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = @Name", user.Pets[0]).Where("users.name = ?", user.Name).First(&users6) + if len(users6) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users6)) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } + + iv := DB.Table(`table_invoices`).Select(`seller, SUM(total) as total, SUM(paid) as paid, SUM(balance) as balance`).Group(`seller`) + stmt = dryDB.Table(`table_employees`).Select(`id, name, iv.total, iv.paid, iv.balance`).Joins(`LEFT JOIN (?) AS iv ON iv.seller = table_employees.id`, iv).Scan(&user).Statement + if !regexp.MustCompile("SELECT id, name, iv.total, iv.paid, iv.balance FROM .table_employees. LEFT JOIN \\(SELECT seller, SUM\\(total\\) as total, SUM\\(paid\\) as paid, SUM\\(balance\\) as balance FROM .table_invoices. 
GROUP BY .seller.\\) AS iv ON iv.seller = table_employees.id").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } +} + +func TestJoinOn(t *testing.T) { + user := *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) + + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") +} + +func TestJoinsWithSelect(t *testing.T) { + type result struct { + ID uint + PetID uint + Name string + } + + user := *GetUser("joins_with_select", Config{Pets: 2}) + DB.Save(&user) + + var results []result + + DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return results[i].PetID > results[j].PetID + }) + + sort.Slice(results, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { + t.Errorf("Should find all two pets with Join select, got %+v", results) + } +} + +func TestJoinWithOmit(t *testing.T) { + user := *GetUser("joins_with_omit", Config{Pets: 2}) + DB.Save(&user) + + results := make([]*User, 0) + + if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil { + return + } + + if 
len(results) != 2 || results[0].Name != "" || results[1].Name != "" { + t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results) + return + } +} + +func TestJoinCount(t *testing.T) { + companyA := Company{Name: "A"} + companyB := Company{Name: "B"} + DB.Create(&companyA) + DB.Create(&companyB) + + user := User{Name: "kingGo", CompanyID: &companyB.ID} + DB.Create(&user) + + query := DB.Model(&User{}).Joins("Company") + // Bug happens when .Count is called on a query. + // Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 + query.Count(&total) + + var result User + + // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id + if err := query.First(&result, user.ID).Error; err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + if result.ID != user.ID { + t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) + } +} + +func TestJoinWithSoftDeleted(t *testing.T) { + user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) + DB.Create(&user) + + var user1 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID) + if user1.NamedPet == nil || user1.Account.ID == 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user1) + } + + // Account should empty + DB.Delete(&user1.Account) + + var user2 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID) + if user2.NamedPet == nil || user2.Account.ID != 0 { + t.Fatalf("joins Account should not empty:%v", user2) + } + + // NamedPet should empty + DB.Delete(&user1.NamedPet) + + var user3 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID) + if user3.NamedPet != nil || user2.Account.ID != 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user2) + } +} + +func TestInnerJoins(t *testing.T) { + user := *GetUser("inner-joins-1", 
Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + var err error + err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user2, user) + + // inner join and NamedPet is nil + err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, gorm.ErrRecordNotFound) + + // mixed inner join and left join + var user3 User + err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user3, user) +} + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). 
+ Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} + +func TestNestedJoins(t *testing.T) { + users := []User{ + { + 
Name: "nested-joins-1", + Manager: &User{ + Name: "nested-joins-manager-1", + Company: Company{ + Name: "nested-joins-manager-company-1", + }, + NamedPet: &Pet{ + Name: "nested-joins-manager-namepet-1", + Toy: Toy{ + Name: "nested-joins-manager-namepet-toy-1", + }, + }, + }, + NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, + }, + { + Name: "nested-joins-2", + Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}}, + }, + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB. + Joins("Manager"). + Joins("Manager.Company"). + Joins("Manager.NamedPet"). + Joins("Manager.NamedPet.Toy"). + Joins("NamedPet"). + Joins("NamedPet.Toy"). + Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 00000000..997714b9 --- /dev/null +++ 
b/tests/main_test.go @@ -0,0 +1,53 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestExceptionsWithInvalidSql(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + + var columns []string + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + var count1, count2 int64 + DB.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + DB.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("No user should not be deleted by invalid SQL") + } +} + +func TestSetAndGet(t *testing.T) { + if value, ok := DB.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else if value.(string) != "world" { + t.Errorf("Set value should not be changed") + } + + if _, ok := DB.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go new file mode 100644 index 00000000..69f86412 --- /dev/null +++ b/tests/migrate_test.go @@ -0,0 +1,1600 @@ +package tests_test + +import ( + "context" + "fmt" + "math/rand" + "os" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + . 
"gorm.io/gorm/utils/tests" +) + +func TestMigrate(t *testing.T) { + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") + + if err := DB.Migrator().DropTable(allModels...); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(allModels...); err != nil { + t.Fatalf("Failed to auto migrate, got error %v", err) + } + + if tables, err := DB.Migrator().GetTables(); err != nil { + t.Fatalf("Failed to get database all tables, but got error %v", err) + } else { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + hasTable := false + for _, t2 := range tables { + if t2 == t1 { + hasTable = true + break + } + } + if !hasTable { + t.Fatalf("Failed to get table %v when GetTables", t1) + } + } + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + t.Fatalf("Failed to create table for %#v", m) + } + } + + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("ccc") + }).Migrator().CreateTable(&Company{}) + + if !DB.Migrator().HasTable("ccc") { + t.Errorf("failed to create table ccc") + } + + for _, indexes := range [][2]string{ + {"user_speaks", "fk_user_speaks_user"}, + {"user_speaks", "fk_user_speaks_language"}, + {"user_friends", "fk_user_friends_user"}, + {"user_friends", "fk_user_friends_friends"}, + {"accounts", "fk_users_account"}, + {"users", "fk_users_team"}, + {"users", "fk_users_company"}, + } { + if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { + t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) + } + } +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + 
tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } +} + +func TestAutoMigrateSelfReferential(t *testing.T) { + type MigratePerson struct { + ID uint + Name string + ManagerID *uint + Manager *MigratePerson + } + + DB.Migrator().DropTable(&MigratePerson{}) + + if err := DB.AutoMigrate(&MigratePerson{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } + + if !DB.Migrator().HasConstraint("migrate_people", "fk_migrate_people_manager") { + t.Fatalf("Failed to find has one constraint between people and managers") + } +} + +func TestSmartMigrateColumn(t *testing.T) { + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] + + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time `gorm:"precision:4"` + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + NameIgnoreMigration string `gorm:"size:100"` + } + + if err := 
DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + NameIgnoreMigration string `gorm:"size:128;-:migration"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if 
precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + case "name_ignore_migration": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 { + t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length) + } + } + } +} + +func TestMigrateWithColumnComment(t *testing.T) { + type UserWithColumnComment struct { + gorm.Model + Name string `gorm:"size:111;comment:this is a 字段"` + } + + if err := DB.Migrator().DropTable(&UserWithColumnComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithColumnComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + +func TestMigrateWithIndexComment(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type UserWithIndexComment struct { + gorm.Model + Name string `gorm:"size:111;index:,comment:这是一个index"` + } + + if err := DB.Migrator().DropTable(&UserWithIndexComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithIndexComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + +func TestMigrateWithUniqueIndex(t *testing.T) { + type UserWithUniqueIndex struct { + ID int + Name string `gorm:"size:20;index:idx_name,unique"` + Date time.Time `gorm:"index:idx_name,unique"` + UName string `gorm:"uniqueIndex;size:255"` + } + + DB.Migrator().DropTable(&UserWithUniqueIndex{}) + if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") { + t.Errorf("Failed to find created index") + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } + + if err := 
DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } +} + +func TestMigrateTable(t *testing.T) { + type TableStruct struct { + gorm.Model + Name string + } + + DB.Migrator().DropTable(&TableStruct{}) + DB.AutoMigrate(&TableStruct{}) + + if !DB.Migrator().HasTable(&TableStruct{}) { + t.Fatalf("should found created table") + } + + type NewTableStruct struct { + gorm.Model + Name string + } + + if err := DB.Migrator().RenameTable(&TableStruct{}, &NewTableStruct{}); err != nil { + t.Fatalf("Failed to rename table, got error %v", err) + } + + if !DB.Migrator().HasTable("new_table_structs") { + t.Fatal("should found renamed table") + } + + DB.Migrator().DropTable("new_table_structs") + + if DB.Migrator().HasTable(&NewTableStruct{}) { + t.Fatal("should not found dropped table") + } +} + +func TestMigrateWithQuotedIndex(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type QuotedIndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index:AS"` // AS is one of MySQL reserved words + } + + if err := DB.Migrator().DropTable(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + +func TestMigrateIndexes(t *testing.T) { + type IndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index"` + } + + DB.Migrator().DropTable(&IndexStruct{}) + DB.AutoMigrate(&IndexStruct{}) + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Got error when tried to create index: %+v", err) + } + + if 
!DB.Migrator().HasIndex(&IndexStruct{}, "Name") { + t.Fatalf("Failed to find index for user's name") + } + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&IndexStruct{}, "Name") { + t.Fatalf("Should not find index for user's name after delete") + } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Got error when tried to create index: %+v", err) + } + + if err := DB.Migrator().RenameIndex(&IndexStruct{}, "idx_index_structs_name", "idx_users_name_1"); err != nil { + t.Fatalf("no error should happen when rename index, but got %v", err) + } + + if !DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { + t.Fatalf("Should find index for user's name after rename") + } + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "idx_users_name_1"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { + t.Fatalf("Should not find index for user's name after delete") + } +} + +func TestTiDBMigrateColumns(t *testing.T) { + if !isTiDB() { + t.Skip() + } + + // TiDB can't change column constraint and has auto_random feature + type ColumnStruct struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type ColumnStruct2 struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"comment:my code2;default:hello"` + } + + if err := 
DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { + t.Fatalf("no error should happened when alter column, but got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); !ok || length != 100 { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !ok || v != "my age" { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code 
default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + // Code2 string `gorm:"comment:my code2;default:hello"` + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + } + } + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { 
+ t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") + } +} + +func TestMigrateColumns(t *testing.T) { + tidbSkip(t, "use another test case") + + sqlite := DB.Dialector.Name() == "sqlite" + sqlserver := DB.Dialector.Name() == "sqlserver" + + type ColumnStruct struct { + gorm.Model + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type ColumnStruct2 struct { + gorm.Model + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { + t.Fatalf("no error should happened when alter column, but got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, 
expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } + } + } + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if 
!DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") + } +} + +func TestMigrateConstraint(t *testing.T) { + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} + + for _, name := range names { + if !DB.Migrator().HasConstraint(&User{}, name) { + DB.Migrator().CreateConstraint(&User{}, name) + } + + if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { + t.Fatalf("failed to drop constraint %v, got error %v", name, err) + } + + if DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("constraint %v should been deleted", name) + } + + if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { + t.Fatalf("failed to create constraint %v, got error %v", name, err) + } + + if !DB.Migrator().HasConstraint(&User{}, name) { 
+ t.Fatalf("failed to found constraint %v", name) + } + } +} + +type DynamicUser struct { + gorm.Model + Name string + CompanyID string `gorm:"index"` +} + +// To test auto migrate crate indexes for dynamic table name +// https://github.com/go-gorm/gorm/issues/4752 +func TestMigrateIndexesWithDynamicTableName(t *testing.T) { + // Create primary table + if err := DB.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + // Create sub tables + for _, v := range []string{"01", "02", "03"} { + tableName := "dynamic_users_" + v + m := DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table(tableName) + }).Migrator() + + if err := m.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + if !m.HasTable(tableName) { + t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) + } + + if !m.HasIndex(&DynamicUser{}, "CompanyID") { + t.Fatalf("Should have index on %s", "CompanyI.") + } + + if !m.HasIndex(&DynamicUser{}, "DeletedAt") { + t.Fatalf("Should have index on deleted_at.") + } + } +} + +// check column order after migration, flaky test +// https://github.com/go-gorm/gorm/issues/4351 +func TestMigrateColumnOrder(t *testing.T) { + type UserMigrateColumn struct { + ID uint + } + DB.Migrator().DropTable(&UserMigrateColumn{}) + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + F1 string + F2 string + F3 string + F4 string + F5 string + F6 string + F7 string + F8 string + F9 string + F10 string + F11 string + F12 string + F13 string + F14 string + F15 string + F16 string + F17 string + F18 string + F19 string + F20 string + F21 string + F22 string + F23 string + F24 string + F25 string + F26 string + F27 string + F28 string + F29 string + F30 string + F31 string + F32 string + F33 string + F34 string + F35 string + } + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to 
auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn2{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + typ := reflect.Indirect(reflect.ValueOf(&UserMigrateColumn2{})).Type() + numField := typ.NumField() + if numField != len(columnTypes) { + t.Fatalf("column's number not match struct and ddl, %d != %d", numField, len(columnTypes)) + } + namer := schema.NamingStrategy{} + for i := 0; i < numField; i++ { + expectName := namer.ColumnName("", typ.Field(i).Name) + if columnTypes[i].Name() != expectName { + t.Fatalf("column order not match struct and ddl, idx %d: %s != %s", + i, columnTypes[i].Name(), expectName) + } + } +} + +// https://github.com/go-gorm/gorm/issues/5047 +func TestMigrateSerialColumn(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Event struct { + ID uint `gorm:"primarykey"` + UID uint32 + } + + type Event1 struct { + ID uint `gorm:"primarykey"` + UID uint32 `gorm:"not null;autoIncrement"` + } + + type Event2 struct { + ID uint `gorm:"primarykey"` + UID uint16 `gorm:"not null;autoIncrement"` + } + + var err error + err = DB.Migrator().DropTable(&Event{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + // create sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // delete sequence + err = DB.Table("events").AutoMigrate(&Event{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // update sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = DB.Table("events").AutoMigrate(&Event2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + + events := make([]*Event, 0) + DB.Table("events").Find(&events) + + 
AssertEqual(t, 3, len(events)) + for _, v := range events { + AssertEqual(t, v.ID, v.UID) + } +} + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +} + +func TestCurrentTimestamp(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + type CurrentTimestampTest struct { + ID string `gorm:"primary_key"` + TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"` + } + var err error + err = DB.Migrator().DropTable(&CurrentTimestampTest{}) + if err != nil { + t.Errorf("DropTable 
err:%v", err) + } + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) +} + +func TestUniqueColumn(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + type UniqueTest struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique"` + } + + type UniqueTest2 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:NULL"` + } + + type UniqueTest3 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:''"` + } + + type UniqueTest4 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:'123'"` + } + + var err error + err = DB.Migrator().DropTable(&UniqueTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // null -> null + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok := ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // not trigger alert column + AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = 
ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + tidbSkip(t, "can't change column constraint") + + // null -> empty string + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, true, ok) + + // empty string -> 123 + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "123", value) + AssertEqual(t, true, ok) + + // 123 -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) +} + +func findColumnType(dest interface{}, columnName string) ( + foundColumn gorm.ColumnType, err error, +) { + columnTypes, err := DB.Migrator().ColumnTypes(dest) + if err != nil { + err = fmt.Errorf("ColumnTypes err:%v", err) + return + } + + for _, c := range columnTypes { + if c.Name() == columnName { + foundColumn = c + break + } + } + return +} + +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + if err != nil { + t.Errorf("Open err:%v", err) + } + + type Object1 struct{} + type Object2 struct { + Field1 string + } + type Object3 struct { + Field2 string + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + 
t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } +} + +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + 
t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + +func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { + type DiffType struct { + ID uint + Name string `gorm:"type:varchar(20)"` + } + + type DiffType1 struct { + ID uint + Name string `gorm:"type:text"` + } + + var err error + DB.Migrator().DropTable(&DiffType{}) + + err = DB.AutoMigrate(&DiffType{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName())) + + err = DB.Table("diff_types").AutoMigrate(&DiffType1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) +} + +func TestMigrateArrayTypeModel(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ArrayTypeModel struct { + ID uint + Number string `gorm:"type:varchar(51);NOT NULL"` + TextArray []string `gorm:"type:text[];NOT NULL"` + NestedTextArray [][]string `gorm:"type:text[][]"` + NestedIntArray [][]int64 `gorm:"type:integer[3][3]"` + } + + var err error + DB.Migrator().DropTable(&ArrayTypeModel{}) + + err = DB.AutoMigrate(&ArrayTypeModel{}) + AssertEqual(t, nil, err) + + ct, err := findColumnType(&ArrayTypeModel{}, "number") + AssertEqual(t, nil, err) + AssertEqual(t, "varchar", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", 
ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_int_array") + AssertEqual(t, nil, err) + AssertEqual(t, "integer[]", ct.DatabaseTypeName()) +} + +type mockMigrator struct { + gorm.Migrator +} + +func (mm mockMigrator) AlterColumn(dst interface{}, field string) error { + err := mm.Migrator.AlterColumn(dst, field) + if err != nil { + return err + } + return fmt.Errorf("trigger alter column error, field: %s", field) +} + +func TestMigrateDonotAlterColumn(t *testing.T) { + wrapMockMigrator := func(m gorm.Migrator) mockMigrator { + return mockMigrator{ + Migrator: m, + } + } + m := DB.Migrator() + mockM := wrapMockMigrator(m) + + type NotTriggerUpdate struct { + ID uint + F1 uint16 + F2 uint32 + F3 int + F4 int64 + F5 string + F6 float32 + F7 float64 + F8 time.Time + F9 bool + F10 []byte + } + + var err error + err = mockM.DropTable(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) +} + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + 
AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + AssertEqual(t, nil, err) +} + +func TestMigrateDefaultNullString(t *testing.T) { + if DB.Dialector.Name() == "sqlserver" { + // sqlserver driver treats NULL and 'NULL' the same + t.Skip("skip sqlserver") + } + + type NullModel struct { + ID uint + Content string `gorm:"default:null"` + } + + type NullStringModel struct { + ID uint + Content string `gorm:"default:'null'"` + } + + tableName := "null_string_model" + + DB.Migrator().DropTable(tableName) + + err := DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + // default null -> 'null' + err = DB.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err := findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok := columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + // default 'null' -> 'null' + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + err = session.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + // default 'null' -> null + err = DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "") + AssertEqual(t, ok, false) +} + +func TestMigrateMySQLWithCustomizedTypes(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type 
MyTable struct { + Def string `gorm:"size:512;index:idx_def,unique"` + Abc string `gorm:"size:65000000"` + } + + DB.Migrator().DropTable("my_tables") + + sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))" + if err := DB.Exec(sql).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + + if err := session.AutoMigrate(&MyTable{}); err != nil { + t.Errorf("Failed, got error: %v", err) + } +} + +func TestMigrateIgnoreRelations(t *testing.T) { + type RelationModel1 struct { + ID uint + } + type RelationModel2 struct { + ID uint + } + type RelationModel3 struct { + ID uint + RelationModel1ID uint + RelationModel1 *RelationModel1 + RelationModel2ID uint + RelationModel2 *RelationModel2 `gorm:"-:migration"` + } + + var err error + _ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{}) + + tx := DB.Session(&gorm.Session{}) + tx.IgnoreRelationshipsWhenMigrating = true + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should not be existed + _, err = findColumnType(&RelationModel1{}, "id") + if err == nil { + t.Errorf("RelationModel1 should not be migrated") + } + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } + + tx.IgnoreRelationshipsWhenMigrating = false + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // 
RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should be existed + _, err = findColumnType(&RelationModel1{}, "id") + AssertEqual(t, nil, err) + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } +} + +func TestMigrateView(t *testing.T) { + DB.Save(GetUser("joins-args-db", Config{Pets: 2})) + + if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + t.Fatalf("no view should be created, got %v", err) + } + + query := DB.Model(&User{}). + Select("users.id as users_id, users.name as users_name, pets.id as pets_id, pets.name as pets_name"). + Joins("inner join pets on pets.user_id = users.id") + + if err := DB.Migrator().CreateView("users_pets", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + + var count int64 + if err := DB.Table("users_pets").Count(&count).Error; err != nil { + t.Fatalf("should found created view") + } + + if err := DB.Migrator().DropView("users_pets"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } + + query = DB.Model(&User{}).Where("age > ?", 20) + if err := DB.Migrator().CreateView("users_view", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + if err := DB.Migrator().DropView("users_view"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } +} + +func TestMigrateExistingBoolColumnPG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ColumnStruct struct { + gorm.Model + Name string + StringBool string + SmallintBool int `gorm:"type:smallint"` + } + + type ColumnStruct2 struct { + gorm.Model + Name string + StringBool bool // change existing boolean column from string to boolean + SmallintBool bool // change existing boolean 
column from smallint or other to boolean + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "string_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + case "smallint_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + } + } + } +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go new file mode 100644 index 00000000..4a7ab9f6 --- /dev/null +++ b/tests/multi_primary_keys_test.go @@ -0,0 +1,448 @@ +package tests_test + +import ( + "reflect" + "sort" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + Blogs []*Blog `gorm:"many2many:blog_tags"` +} + +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if name := DB.Dialector.Name(); name == "postgres" { + stmt := gorm.Statement{DB: DB} + stmt.Parse(&Blog{}) + stmt.Schema.LookUpField("ID").Unique = true + stmt.Parse(&Tag{}) + stmt.Schema.LookUpField("ID").Unique = true + // postgers only allow unique constraint matching given keys + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + tag3 := &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + + if !compareTags(blog.Tags, []string{"tag1", "tag2", 
"tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if count := DB.Model(&blog).Association("Tags").Count(); count != 3 { + t.Fatalf("Blog should has 3 tags after Append, got %v", count) + } + + var tags []Tag + DB.Model(&blog).Association("Tags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + // Replace + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("Tags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("Tags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("Tags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if 
name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgres due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + tag3 := &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + tag4 := &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + 
DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + // Replace + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgres due to it only allow unique constraint matching given keys") + } + + 
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + tag3 := &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if len(tags) != 0 { + t.Fatalf("Should find 0 tags for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? 
AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + tag4 := &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags for EN Blog") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag4"}) { + t.Fatalf("Should find 1 tags for EN Blog") + } + + // Replace + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? 
AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Fatalf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Fatalf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Fatalf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared") + } +} + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + 
Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +} diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go new file mode 100644 index 00000000..a3a25f7b --- /dev/null +++ b/tests/named_argument_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "database/sql" + "errors" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestNamedArg(t *testing.T) { + type NamedUser struct { + gorm.Model + Name1 string + Name2 string + Name3 string + } + + DB.Migrator().DropTable(&NamedUser{}) + DB.AutoMigrate(&NamedUser{}) + + namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} + DB.Create(&namedUser) + + var result NamedUser + DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) + + AssertEqual(t, result, namedUser) + + var result2 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) + + AssertEqual(t, result2, namedUser) + + var result3 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) + + AssertEqual(t, result3, namedUser) + + var result4 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) + + if err := DB.Exec("UPDATE named_users SET 
name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + namedUser.Name1 = "jinzhu-new" + namedUser.Name2 = "jinzhu-new2" + namedUser.Name3 = "jinzhu-new" + + var result5 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result5, namedUser) + + var result6 NamedUser + if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name + AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result6, namedUser) + + var result7 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } + + DB.Delete(&namedUser) + + var result8 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } +} diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go new file mode 100644 index 00000000..956f3a7e --- /dev/null +++ b/tests/named_polymorphic_test.go @@ -0,0 +1,147 @@ +package tests_test + +import ( + "testing" + + . 
"gorm.io/gorm/utils/tests" +) + +type Hamster struct { + Id int + Name string + PreferredToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_other"` +} + +func TestNamedPolymorphic(t *testing.T) { + DB.Migrator().DropTable(&Hamster{}) + DB.AutoMigrate(&Hamster{}) + + hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} + DB.Save(&hamster) + + hamster2 := Hamster{} + DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + + if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { + t.Errorf("Hamster's preferred toy failed to preload") + } + + if hamster2.OtherToy.ID != hamster.OtherToy.ID || hamster2.OtherToy.Name != hamster.OtherToy.Name { + t.Errorf("Hamster's other toy failed to preload") + } + + // clear to omit Toy.ID in count + hamster2.PreferredToy = Toy{} + hamster2.OtherToy = Toy{} + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's preferred toy count should be 1") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy count should be 1") + } + + // Query + hamsterToy := Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.PreferredToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.OtherToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + // Append + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 2" { + 
t.Errorf("Should update has one polymorphic association with Append") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + // Replace + DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ + Name: "bike 3", + }) + + DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ + Name: "treadmill 3", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + // Clear + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + if DB.Model(&hamster).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + DB.Model(&hamster).Association("PreferredToy").Clear() + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { + t.Errorf("Hamster's 
preferred toy should be cleared with Clear") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy should be still available") + } + + DB.Model(&hamster).Association("OtherToy").Clear() + if DB.Model(&hamster).Association("OtherToy").Count() != 0 { + t.Errorf("Hamster's other toy should be cleared with Clear") + } +} diff --git a/tests/non_std_test.go b/tests/non_std_test.go new file mode 100644 index 00000000..8ae42691 --- /dev/null +++ b/tests/non_std_test.go @@ -0,0 +1,61 @@ +package tests_test + +import ( + "testing" + "time" +) + +type Animal struct { + Counter uint64 `gorm:"primary_key:yes"` + Name string `gorm:"DEFAULT:'galeone'"` + From string // test reserved sql keyword as field name + Age *time.Time + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time +} + +func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { + DB.Migrator().DropTable(&Animal{}) + if err := DB.AutoMigrate(&Animal{}); err != nil { + t.Fatalf("no error should happen when migrate but got %v", err) + } + + animal := Animal{Name: "Ferdinand"} + DB.Save(&animal) + updatedAt1 := animal.UpdatedAt + + DB.Save(&animal).Update("name", "Francis") + if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be updated") + } + + var animals []Animal + DB.Find(&animals) + if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + t.Error("RowsAffected should be correct when do batch update") + } + + animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) + DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched + DB.First(&animal, animal.Counter) + if animal.Name != "galeone" { + t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) + } + + // When changing a field with a default 
value, the change must occur + animal.Name = "amazing horse" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "amazing horse" { + t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) + } + + // When changing a field with a default value with blank value + animal.Name = "" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "" { + t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) + } +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go new file mode 100644 index 00000000..44cac6bf --- /dev/null +++ b/tests/postgres_test.go @@ -0,0 +1,256 @@ +package tests_test + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id 
= $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + +func TestPostgres(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Harumph struct { + gorm.Model + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + Things pq.StringArray `gorm:"type:text[]"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + harumph := Harumph{} + if err := DB.Create(&harumph).Error; err == nil { + t.Fatalf("should failed to create data, name can't be blank") + } + + harumph = Harumph{Name: "jinzhu"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + var result Harumph + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + harumph.Name = "jinzhu1" + if err := DB.Save(&harumph).Error; err != nil { + t.Errorf("Failed to update date, got error 
%v", err) + } + + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } +} + +type Post struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` + Title string + Categories []*Category `gorm:"Many2Many:post_categories"` +} + +type Category struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` + Title string + Posts []*Post `gorm:"Many2Many:post_categories"` +} + +func TestMany2ManyWithDefaultValueUUID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { + t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) + } + + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") + DB.AutoMigrate(&Post{}, &Category{}) + + post := Post{ + Title: "Hello World", + Categories: []*Category{ + {Title: "Coding"}, + {Title: "Golang"}, + }, + } + + if err := DB.Create(&post).Error; err != nil { + t.Errorf("Failed, got error: %v", err) 
+ } +} + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} diff --git a/preload_test.go b/tests/preload_suits_test.go similarity index 77% rename from preload_test.go rename to tests/preload_suits_test.go index 8c56a8ac..b5b6a70f 100644 --- a/preload_test.go +++ b/tests/preload_suits_test.go @@ -1,97 +1,19 @@ -package gorm_test +package tests_test import ( "database/sql" "encoding/json" - "os" "reflect" + "sort" + "sync/atomic" "testing" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) -func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") -} - -func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != 
u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } - - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } - - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } - - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } - - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } -} - -func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) - - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("user2") - DB.Save(user2) - - user3 := getPreloadUser("user3") - DB.Save(user3) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } - - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } - } +func toJSONString(v interface{}) []byte { + r, _ := json.Marshal(v) + return r } func TestNestedPreload1(t *testing.T) { @@ -112,10 +34,8 @@ func TestNestedPreload1(t 
*testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -133,7 +53,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2").Preload("Level2.Level1").First(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -156,10 +76,8 @@ func TestNestedPreload2(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -210,10 +128,8 @@ func TestNestedPreload3(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -255,10 +171,8 @@ func TestNestedPreload4(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { 
t.Error(err) } @@ -303,10 +217,8 @@ func TestNestedPreload5(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -348,10 +260,8 @@ func TestNestedPreload6(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -422,10 +332,8 @@ func TestNestedPreload7(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -478,10 +386,8 @@ func TestNestedPreload8(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -529,9 +435,9 @@ func TestNestedPreload9(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 + Level2ID *uint + Level2_1ID *uint + Level0s []Level0 `json:",omitempty"` } Level2 struct { ID uint @@ -540,7 +446,7 @@ func TestNestedPreload9(t *testing.T) { } Level2_1 struct { ID uint - 
Level1s []Level1 + Level1s []Level1 `json:",omitempty"` Level3ID uint } Level3 struct { @@ -550,12 +456,8 @@ func TestNestedPreload9(t *testing.T) { Level2_1 Level2_1 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}); err != nil { t.Error(err) } @@ -592,8 +494,14 @@ func TestNestedPreload9(t *testing.T) { }, Level2_1: Level2_1{ Level1s: []Level1{ - {Value: "value3-3"}, - {Value: "value4-4"}, + { + Value: "value3-3", + Level0s: []Level0{}, + }, + { + Value: "value4-4", + Level0s: []Level0{}, + }, }, }, } @@ -606,7 +514,7 @@ func TestNestedPreload9(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(got, want) { + if string(toJSONString(got)) != string(toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -619,7 +527,7 @@ type LevelA1 struct { type LevelA2 struct { ID uint Value string - LevelA3s []*LevelA3 + LevelA3s []*LevelA3 `json:",omitempty"` } type LevelA3 struct { @@ -632,11 +540,8 @@ type LevelA3 struct { } func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) - - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelA3{}, &LevelA2{}, &LevelA1{}) + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}); err != nil { t.Error(err) } @@ -656,7 +561,8 @@ func TestNestedPreload10(t *testing.T) { }, }, { - Value: "bar 2", + Value: "bar 2", + LevelA3s: []*LevelA3{}, }, } for _, levelA2 := range want { @@ -670,7 +576,7 @@ func TestNestedPreload10(t *testing.T) { t.Error(err) } - 
if !reflect.DeepEqual(got, want) { + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -691,14 +597,12 @@ type LevelB3 struct { Value string LevelB1ID sql.NullInt64 LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s" json:",omitempty"` } func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelB3{}, &LevelB2{}, &LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}); err != nil { t.Error(err) } @@ -710,6 +614,7 @@ func TestNestedPreload11(t *testing.T) { levelB3 := &LevelB3{ Value: "bar", LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, } if err := DB.Create(levelB3).Error; err != nil { t.Error(err) @@ -722,7 +627,7 @@ func TestNestedPreload11(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(got, want) { + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -747,10 +652,8 @@ type LevelC3 struct { } func TestNestedPreload12(t *testing.T) { - DB.DropTableIfExists(&LevelC2{}) - DB.DropTableIfExists(&LevelC3{}) - DB.DropTableIfExists(&LevelC1{}) - if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelC3{}, &LevelC2{}, &LevelC1{}) + if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}); err != nil { t.Error(err) } @@ -789,8 +692,8 @@ func TestNestedPreload12(t *testing.T) { } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - return + if name := DB.Dialector.Name(); name == "sqlite" 
|| name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } type ( @@ -807,11 +710,10 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable(&Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -898,12 +800,10 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1000,13 +900,9 @@ func TestNestedManyToManyPreload(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2", "level2_level3") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1042,7 +938,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2s.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1066,12 +962,10 @@ func 
TestNestedManyToManyPreload2(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1099,7 +993,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1123,12 +1017,9 @@ func TestNestedManyToManyPreload3(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1162,7 +1053,7 @@ func TestNestedManyToManyPreload3(t *testing.T) { } for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(want).Error; err != nil { t.Error(err) } } @@ -1198,12 +1089,10 @@ func TestNestedManyToManyPreload3ForStruct(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, 
&Level1{}); err != nil { t.Error(err) } @@ -1237,7 +1126,7 @@ func TestNestedManyToManyPreload3ForStruct(t *testing.T) { } for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(want).Error; err != nil { t.Error(err) } } @@ -1278,12 +1167,8 @@ func TestNestedManyToManyPreload4(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.Migrator().DropTable("level1_level2", "level2_level3") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) dummy := Level1{ Value: "Level1", @@ -1298,7 +1183,7 @@ func TestNestedManyToManyPreload4(t *testing.T) { }}, } - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1325,11 +1210,9 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable("levels", &Level2{}, &Level1{}) - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1411,16 +1294,13 @@ func TestNilPointerSlice(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint + Level2ID *uint Level2 *Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1442,7 +1322,7 @@ func TestNilPointerSlice(t *testing.T) { Level2: nil, } if err := DB.Save(&want2).Error; err != nil { - t.Error(err) 
+ t.Fatalf("Got error %v", err) } var got []Level1 @@ -1455,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) } if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { @@ -1484,12 +1364,9 @@ func TestNilPointerSlice2(t *testing.T) { } ) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { + if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)); err != nil { t.Error(err) } @@ -1519,7 +1396,7 @@ func TestPrefixedPreloadDuplication(t *testing.T) { Level3 struct { ID uint Name string - Level4s []*Level4 + Level4s []*Level4 `json:",omitempty"` } Level2 struct { ID uint @@ -1535,12 +1412,9 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } ) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { + if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)); err != nil { t.Error(err) } @@ -1586,12 +1460,52 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } + for _, level1 := range append(got, want...) 
{ + sort.Slice(level1.Level2.Level3.Level4s, func(i, j int) bool { + return level1.Level2.Level3.Level4s[i].ID > level1.Level2.Level3.Level4s[j].ID + }) + } + if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } -func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s"` + } + ) + + DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{}) + + if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + {Name: "l2-1"}, {Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + var called int64 + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { + atomic.AddInt64(&called, 1) + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } } diff --git a/tests/preload_test.go b/tests/preload_test.go new file mode 100644 index 00000000..7304e350 --- /dev/null +++ b/tests/preload_test.go @@ -0,0 +1,446 @@ +package tests_test + +import ( + "encoding/json" + "regexp" + "sort" + "strconv" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestPreloadWithAssociations(t *testing.T) { + user := *GetUser("preload_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + user3 := *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) +} + +func TestNestedPreload(t *testing.T) { + user := *GetUser("nested_preload", Config{Pets: 2}) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + CheckUser(t, *user4, user) +} + +func TestNestedPreloadForSlice(t *testing.T) { + users := []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range 
users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy").Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestPreloadWithConds(t *testing.T) { + users := []User{ + *GetUser("slice_nested_preload_1", Config{Account: true}), + *GetUser("slice_nested_preload_2", Config{Account: false}), + *GetUser("slice_nested_preload_3", Config{Account: true}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Account", clause.Eq{Column: "number", Value: users[0].Account.Number}).Find(&users2, "id IN ?", userIDs) + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for idx, user := range users2[1:2] { + if user.Account.Number != "" { + t.Errorf("No account should found for user %v but got %v", idx+2, user.Account.Number) + } + } + + CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } + + var user4 User + DB.Delete(&users3[0].Account) + + if err := DB.Preload(clause.Associations).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID != 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) + } + + if err := DB.Preload(clause.Associations, func(tx *gorm.DB) *gorm.DB { + return tx.Unscoped() + }).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID == 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, 
user4.Account) + } +} + +func TestNestedPreloadWithConds(t *testing.T) { + users := []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy", "name like ?", `%preload_3`).Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2[0:2] { + for _, pet := range user.Pets { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", idx+1, pet.Name, pet.Toy.Name) + } + } + } + + if len(users2[2].Pets) != 3 { + t.Errorf("Invalid pet toys found for user 3 got %v", len(users2[2].Pets)) + } else { + sort.Slice(users2[2].Pets, func(i, j int) bool { + return users2[2].Pets[i].ID < users2[2].Pets[j].ID + }) + + for _, pet := range users2[2].Pets[0:2] { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", 3, pet.Name, pet.Toy.Name) + } + } + + CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) + } +} + +func TestPreloadEmptyData(t *testing.T) { + user := *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name 
= ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} + +func TestPreloadGoroutine(t *testing.T) { + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + var user2 []User + tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) + + if err := tx.Preload("Team").Find(&user2).Error; err != nil { + t.Error(err) + } + }() + } + wg.Wait() +} + +func TestPreloadWithDiffModel(t *testing.T) { + user := *GetUser("preload_with_diff_model", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result struct { + Something string + User + } + + DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( + "users.*, 'yo' as something").First(&result, "name = ?", user.Name) + + CheckUser(t, user, result.User) +} + +func TestNestedPreloadWithUnscoped(t *testing.T) { + user := *GetUser("nested_preload", Config{Pets: 1}) + pet := user.Pets[0] + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)} + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + DB.Delete(&pet) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + if len(user3.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + if len(user4.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user5 User 
+ DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID) + CheckUserUnscoped(t, user5, user) + + var user6 *User + DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) + CheckUserUnscoped(t, *user6, user) +} + +func TestEmbedPreload(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type EmbeddedAddress struct { + ID int + Name string + CountryID *int + Country *Country + } + type NestedAddress struct { + EmbeddedAddress + } + type Org struct { + ID int + PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID *int + Address *EmbeddedAddress + NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) + DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{}) + + org := Org{ + PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, + VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, + Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, + NestedAddress: NestedAddress{ + EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, + }, + } + if err := DB.Create(&org).Error; err != nil { + t.Errorf("failed to create org, got err: %v", err) + } + + tests := []struct { + name string + preloads map[string][]interface{} + expect Org + }{ + { + name: "address country", + preloads: map[string][]interface{}{"Address.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: 
org.AddressID, + Address: org.Address, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "postal address country", + preloads: map[string][]interface{}{"PostalAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: org.PostalAddress, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "nested address country", + preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: org.NestedAddress, + }, + }, { + name: "associations", + preloads: map[string][]interface{}{ + clause.Associations: {}, + // clause.Associations won’t preload nested associations + "Address.Country": {}, + }, + expect: org, + }, + } + + DB = DB.Debug() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := Org{} + tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) + for name, args := range test.preloads { + tx = tx.Preload(name, args...) 
+ } + if err := tx.Find(&actual).Error; err != nil { + t.Errorf("failed to find org, got err: %v", err) + } + AssertEqual(t, actual, test.expect) + }) + } +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go new file mode 100644 index 00000000..64baa01b --- /dev/null +++ b/tests/prepared_stmt_test.go @@ -0,0 +1,196 @@ +package tests_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestPreparedStmt(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + txCtx := tx.WithContext(ctx) + + user := *GetUser("prepared_stmt", Config{}) + + txCtx.Create(&user) + + var result1 User + if err := txCtx.Find(&result1, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + time.Sleep(time.Second) + + var result2 User + if err := tx.Find(&result2, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + user2 := *GetUser("prepared_stmt2", Config{}) + if err := txCtx.Create(&user2).Error; err == nil { + t.Fatalf("should failed to create with timeout context") + } + + if err := tx.Create(&user2).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + var result3 User + if err := tx.Find(&result3, user2.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } +} + +func TestPreparedStmtFromTransaction(t *testing.T) { + db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if 
err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() +} + +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtError(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + // err prepare + tag := Tag{Locale: "zh"} + tx.Table("users").Find(&tag) + wg.Done() + }() + } + wg.Wait() + + conn, ok := 
tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 0) + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} diff --git a/tests/query_test.go b/tests/query_test.go new file mode 100644 index 00000000..b6bd0736 --- /dev/null +++ b/tests/query_test.go @@ -0,0 +1,1382 @@ +package tests_test + +import ( + "database/sql" + "fmt" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestFind(t *testing.T) { + users := []User{ + *GetUser("find", Config{}), + *GetUser("find", Config{}), + *GetUser("find", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := DB.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + CheckUser(t, first, users[0]) + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := DB.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + CheckUser(t, last, users[2]) + } + }) + + var all []User + if err := DB.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, all[idx], user) + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + first := map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, 
first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + first := map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstPtrMap", func(t *testing.T) { + first := map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstSliceOfMap", func(t *testing.T) { + allMap := []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when 
query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + allMap := []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", 
dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + + var models []User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } +} + +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + user.CreatedAt = time.Time{} + user.UpdatedAt = 
time.Time{} + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } +} + +func TestFindInBatches(t *testing.T) { + users := []User{ + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) + } + + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Fatalf("failed to save users, got error %v", err) + } + + return nil + }); result.Error != nil || result.RowsAffected != 6 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if totalBatch != 6 { + t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) + } + + var count int64 + DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) + if count != 6 { + t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) + } +} + +func TestFindInBatchesWithOffsetLimit(t *testing.T) { + users := []User{ + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", 
Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + } + + DB.Create(&users) + + var ( + sub, results []User + lastBatch int + ) + + // offset limit + if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + results = append(results, sub...) + lastBatch = batch + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + if lastBatch != 3 { + t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) + } + + targetUsers := users[3:8] + for i := 0; i < len(targetUsers); i++ { + AssertEqual(t, results[i], targetUsers[i]) + } + + var sub1 []User + // limit < batchSize + if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub2 []User + // only offset + if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 7 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub3 []User + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + 
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } +} + +func TestFindInBatchesWithError(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + + users := []User{ + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error == nil || result.RowsAffected > 0 { + t.Fatal("expected errors to have occurred, but nothing happened") + } + if totalBatch != 0 { + t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) + } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } +} + +func TestFillSmallerStruct(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + + var simpleUser SimpleUser + if err := DB.Table("users").Where("name = ?", user.Name).First(&simpleUser).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") + + var simpleUser2 SimpleUser + if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil { + 
t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser2, "ID") + + var simpleUsers []SimpleUser + if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUsers[0], "ID") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) + + if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]*User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } +} + +func TestFillSmallerStructWithAllFields(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + var simpleUsers []SimpleUser + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + + result := dryDB.Model(&User{}).Find(&simpleUsers, user.ID) + if 
!regexp.MustCompile("SELECT .users.*id.*users.*name.*users.*updated_at.*users.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]*User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } +} + +func TestNot(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Not(map[string]interface{}{"name": "jinzhu"}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu1").Not("name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ AND NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + 
result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"name": []string{}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .users.\\..deleted_at. 
IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestNotWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Not(map[string]interface{}{"users.name": "jinzhu"}).Find(&User{}) + + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu1").Not("users.name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ AND NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("users.name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = 
dryDB.Not(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestOr(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + var count int64 + result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count) + if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin").Or("role = ?", "admin")).Find(&User{}) + if 
!regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND (.*role.* = .+ OR .*role.* = .+)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestOrWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = 
.+ OR \\(.*users.*name.* AND .*users.*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestPluck(t *testing.T) { + users := []*User{ + GetUser("pluck-user1", Config{}), + GetUser("pluck-user2", Config{}), + GetUser("pluck-user3", Config{}), + } + + DB.Create(&users) + + var names []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + + var names2 []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + + sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] }) + AssertEqual(t, names, names2) + + var ids []int + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { + t.Errorf("got error when pluck id: %v", err) + } + + for idx, name := range names { + if name != users[idx].Name { + t.Errorf("Unexpected result on pluck name, got %+v", names) + } + } + + for idx, id := range ids { + if int(id) != int(users[idx].ID) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := 
DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } +} + +func TestSelect(t *testing.T) { + user := User{Name: "SelectUser1"} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Select("name").Find(&result) + if result.ID != 0 { + t.Errorf("Should not have ID because only selected name, %+v", result.ID) + } + + if user.Name != result.Name { + t.Errorf("Should have user Name when selected it") + } + + var result2 User + DB.Where("name = ?", user.Name).Select("name as name").Find(&result2) + if result2.ID != 0 { + t.Errorf("Should not have ID because only selected name, %+v", result2.ID) + } + + if user.Name != result2.Name { + t.Errorf("Should have user Name when selected it") + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Select("name", "age").Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with strings, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select([]string{"name", "age"}).Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String()) + } + + // SELECT COALESCE(age,'42') FROM users; + r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", 
r.Statement.SQL.String()) + } + + // named arguments + r = dryDB.Table("users").Select("COALESCE(age, @default)", sql.Named("default", 42)).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) + } + + if _, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows(); err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } +} + +func TestOmit(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Omit("name").Find(&result) + if result.ID == 0 { + t.Errorf("Should not have ID because only selected name, %+v", result.ID) + } + + if result.Name != "" || result.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", result.Name, result.Age) + } +} + +func TestOmitWithAllFields(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var userResult User + DB.Session(&gorm.Session{QueryFields: true}).Where("users.name = ?", user.Name).Omit("name").Find(&userResult) + if userResult.ID == 0 { + t.Errorf("Should not have ID because only selected name, %+v", userResult.ID) + } + + if userResult.Name != "" || userResult.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", userResult.Name, userResult.Age) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true, 
QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*birthday" + + ".*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Omit("name, age").Find(&User{}) + if !regexp.MustCompile(userQuery).MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL must include table name and selected fields, got %v", result.Statement.SQL.String()) + } +} + +func TestPluckWithSelect(t *testing.T) { + users := []User{ + {Name: "pluck_with_select_1", Age: 25}, + {Name: "pluck_with_select_2", Age: 26}, + } + + DB.Create(&users) + + var userAges []int + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + if err != nil { + t.Fatalf("got error when pluck user_age: %v", err) + } + + sort.Ints(userAges) + + AssertEqual(t, userAges, []int{26, 27}) +} + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "select_with_variables"}) + + rows, _ := DB.Table("users").Where("name = ?", "select_with_variables").Select("? 
as fake", gorm.Expr("name")).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } else { + columns, _ := rows.Columns() + AssertEqual(t, columns, []string{"fake"}) + } + + rows.Close() +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "select_with_array", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = ?", "select_with_array").First(&user) + + if user.Name != "select_with_array" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} + +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.Migrator().DropTable(&CustomizedTypePrimaryKey{}) + if err := DB.AutoMigrate(&CustomizedTypePrimaryKey{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err != nil { + t.Errorf("No error should returns, but got %v", err) + } + + AssertEqual(t, p, p2) + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + AssertEqual(t, p, p2) +} + +func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.Migrator().DropTable(&AddressByZipCode{}) + if err := DB.AutoMigrate(&AddressByZipCode{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + address := AddressByZipCode{ZipCode: "00501", Address: "Holtsville"} + DB.Create(&address) + + var result AddressByZipCode + DB.First(&result, "00501") + + AssertEqual(t, result, address) +} + +func 
TestSearchWithEmptyChain(t *testing.T) { + user := User{Name: "search_with_empty_chain", Age: 1} + DB.Create(&user) + + var result User + if DB.Where("").Where("").First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty strings") + } + + result = User{} + if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty struct") + } + + result = User{} + if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty map") + } +} + +func TestOrder(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Order("").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order(nil).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc, name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc").Order("name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) 
+ if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + +func TestOrderWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name.*users.*age" + + ".*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Order("users.age desc, users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "users.age desc, users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("users.age desc").Order("users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "ORDER BY users.age desc,users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) 
+ if !regexp.MustCompile(userQuery + "ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + +func TestLimit(t *testing.T) { + users := []User{ + {Name: "LimitUser1", Age: 1}, + {Name: "LimitUser2", Age: 10}, + {Name: "LimitUser3", Age: 20}, + {Name: "LimitUser4", Age: 10}, + {Name: "LimitUser5", Age: 20}, + {Name: "LimitUser6", Age: 20}, + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) + + if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { + t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3)) + } +} + +func TestOffset(t *testing.T) { + for i := 0; i < 20; i++ { + DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) + } + var users1, users2, users3, users4 []User + + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work") + } + + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work without limit.") + } +} + +func TestSearchWithMap(t *testing.T) { + users := []User{ + *GetUser("map_search_user1", Config{}), + *GetUser("map_search_user2", Config{}), + *GetUser("map_search_user3", Config{}), + *GetUser("map_search_user4", Config{Company: true}), + } + + DB.Create(&users) + + var user User + DB.First(&user, map[string]interface{}{"name": users[0].Name}) + CheckUser(t, user, users[0]) + + user = User{} + DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) + 
CheckUser(t, user, users[1]) + + var results []User + DB.Where(map[string]interface{}{"name": users[2].Name}).Find(&results) + if len(results) != 1 { + t.Fatalf("Search all records with inline map") + } + + CheckUser(t, results[0], users[2]) + + var results2 []User + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": nil}) + if len(results2) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&results2, map[string]interface{}{"name": users[0].Name, "company_id": nil}) + if len(results2) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": users[3].CompanyID}) + if len(results2) != 1 { + t.Errorf("Search all records with inline multiple value map") + } +} + +func TestSearchWithStruct(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. 
IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestSubQuery(t *testing.T) { + users := []User{ + {Name: "subquery_1", Age: 10}, + {Name: "subquery_2", Age: 20}, + {Name: "subquery_3", Age: 30}, + {Name: "subquery_4", Age: 40}, + } + + DB.Create(&users) + + if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil { + t.Fatalf("got error: %v", err) + } + + if len(users) != 4 { + t.Errorf("Four users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB. + Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + +func TestSubQueryWithRaw(t *testing.T) { + users := []User{ + {Name: "subquery_raw_1", Age: 10}, + {Name: "subquery_raw_2", Age: 20}, + {Name: "subquery_raw_3", Age: 30}, + {Name: "subquery_raw_4", Age: 40}, + } + DB.Create(&users) + + var count int64 + err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? 
and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_raw%"). + Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } +} + +func TestSubQueryWithHaving(t *testing.T) { + users := []User{ + {Name: "subquery_having_1", Age: 10}, + {Name: "subquery_having_2", Age: 20}, + {Name: "subquery_having_3", Age: 30}, + {Name: "subquery_having_4", Age: 40}, + } + DB.Create(&users) + + var results []User + DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB. 
+ Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results) + + if len(results) != 2 { + t.Errorf("Two user group should be found, instead found %d", len(results)) + } +} + +func TestScanNullValue(t *testing.T) { + user := GetUser("scan_null_value", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var result User + if err := DB.First(&result, "id = ?", user.ID).Error; err != nil { + t.Fatalf("failed to query struct data with null age, got error %v", err) + } + + AssertEqual(t, result, user) + + users := []User{ + *GetUser("scan_null_value_for_slice_1", Config{}), + *GetUser("scan_null_value_for_slice_2", Config{}), + *GetUser("scan_null_value_for_slice_3", Config{}), + } + DB.Create(&users) + + if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var results []User + if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil { + t.Fatalf("failed to query slice data with null age, got error %v", err) + } +} + +func TestQueryWithTableAndConditions(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + + if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* FROM .user. 
" + + if !regexp.MustCompile(userQuery + `WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + +type DoubleInt64 struct { + data int64 +} + +func (t *DoubleInt64) Scan(val interface{}) error { + switch v := val.(type) { + case int64: + t.data = v * 2 + return nil + default: + return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) + } +} + +// https://github.com/go-gorm/gorm/issues/5091 +func TestQueryScannerWithSingleColumn(t *testing.T) { + user := User{Name: "scanner_raw_1", Age: 10} + DB.Create(&user) + + var result1 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( + "age", &result1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result1.data, 20) + + var result2 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( + "age").Scan(&result2).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result2.data, 20) +} + +func TestQueryResetNullValue(t *testing.T) { + type QueryResetItem struct { + ID string `gorm:"type:varchar(5)"` + Name string + } + + type QueryResetNullValue struct { + ID int + Name string `gorm:"default:NULL"` + Flag bool `gorm:"default:NULL"` + Number1 int64 `gorm:"default:NULL"` + Number2 uint64 `gorm:"default:NULL"` + Number3 float64 `gorm:"default:NULL"` + Now *time.Time `gorm:"defalut:NULL"` + Item1Id string + Item1 *QueryResetItem `gorm:"references:ID"` + Item2Id string + Item2 *QueryResetItem `gorm:"references:ID"` + } + + DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) + DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) + + now := time.Now() + q1 := QueryResetNullValue{ + Name: "name", + Flag: true, + Number1: 100, + Number2: 200, + Number3: 300.1, + Now: &now, + Item1: &QueryResetItem{ + ID: "u_1_1", + Name: "item_1_1", + 
}, + Item2: &QueryResetItem{ + ID: "u_1_2", + Name: "item_1_2", + }, + } + + q2 := QueryResetNullValue{ + Item1: &QueryResetItem{ + ID: "u_2_1", + Name: "item_2_1", + }, + Item2: &QueryResetItem{ + ID: "u_2_2", + Name: "item_2_2", + }, + } + + var err error + err = DB.Create(&q1).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + err = DB.Create(&q2).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + var qs []QueryResetNullValue + err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error + if err != nil { + t.Errorf("failed to find:%v", err) + } + + if len(qs) != 2 { + t.Fatalf("find count not equal:%d", len(qs)) + } + + AssertEqual(t, q1, qs[0]) + AssertEqual(t, q2, qs[1]) +} + +func TestQueryError(t *testing.T) { + type P struct{} + var p1 P + err := DB.Take(&p1, 1).Error + AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired) + + var p2 interface{} + + err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{ + Table: clause.CurrentTable, Name: clause.PrimaryKey, + }, Value: 1}).Scan(&p2).Error + AssertEqual(t, err, gorm.ErrModelValueRequired) +} diff --git a/tests/scan_test.go b/tests/scan_test.go new file mode 100644 index 00000000..6f2e9f54 --- /dev/null +++ b/tests/scan_test.go @@ -0,0 +1,242 @@ +package tests_test + +import ( + "reflect" + "sort" + "strings" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + +func TestScan(t *testing.T) { + user1 := User{Name: "ScanUser1", Age: 1} + user2 := User{Name: "ScanUser2", Age: 10} + user3 := User{Name: "ScanUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + ID uint + Name string + Age int + } + + var res result + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + var resPointer *result + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + doubleAgeRes := &result{} + if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + t.Errorf("Scan to pointer of pointer") + } + + if doubleAgeRes.Age != int(res.Age)*2 { + t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) + } + + var results []result + DB.Table("users").Select("name, age").Where("id in ?", 
[]uint{user2.ID, user3.ID}).Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Scan into struct map, got %#v", results) + } + + type ID uint64 + var id ID + DB.Raw("select id from users where id = ?", user2.ID).Scan(&id) + if uint(id) != user2.ID { + t.Errorf("Failed to scan to customized data type") + } + + var resInt interface{} + resInt = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) + } + + var resInt2 interface{} + resInt2 = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) + } + + var resInt3 interface{} + resInt3 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) + } + + var resInt4 interface{} + resInt4 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { + t.Fatalf("Failed to 
query with pointer of value, got error %v", err) + } else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) + } + + var resInt5 interface{} + resInt5 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id IN ?", []uint{user1.ID, user2.ID, user3.ID}).Find(&resInt5).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt5.([]User); len(rus) != 3 { + t.Fatalf("Scan into struct should work, got %+v, len %v", resInt5, len(rus)) + } +} + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1} + user2 := User{Name: "ScanRowsUser2", Age: 10} + user3 := User{Name: "ScanRowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") + } + + var ages int + if err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } +} + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). 
+ Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + + personMatched := false + addressMatched := false + + for _, info := range personAddressInfoList { + if info.Person == nil { + t.Fatalf("Failed, expected not nil, got person nil") + } + if info.Address == nil { + t.Fatalf("Failed, expected not nil, got address nil") + } + if info.Person.ID == person1.ID { + personMatched = true + if info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + if info.Address.ID == address1.ID { + addressMatched = true + if info.Address.Name != address1.Name { + t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name) + } + } + } + + if !personMatched { + t.Errorf("Failed, no person matched") + } + if !addressMatched { + t.Errorf("Failed, no address matched") + } + + personDupField := Person{ID: person1.ID} + if err := DB.Select("people.id, people.*"). + First(&personDupField).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + AssertEqual(t, person1, personDupField) + + user := User{ + Name: "TestScanToEmbedded_1", + Manager: &User{ + Name: "TestScanToEmbedded_1_m1", + Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"}, + }, + } + DB.Create(&user) + + type UserScan struct { + ID uint + Name string + ManagerID *uint + } + var user2 UserScan + err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error + AssertEqual(t, err, nil) +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go new file mode 100644 index 00000000..14121699 --- /dev/null +++ b/tests/scanner_valuer_test.go @@ -0,0 +1,393 @@ +package tests_test + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "regexp" + "strconv" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + 
. "gorm.io/gorm/utils/tests" +) + +func TestScannerValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, + Password: EncryptedData("pass1"), + Bytes: []byte("byte"), + Num: 18, + Strings: StringsSlice{"a", "b", "c"}, + Structs: StructsSlice{ + {"name1", "value1"}, + {"name2", "value2"}, + }, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err) + } + + var result ScannerValuerStruct + + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { + t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) + } + + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } + AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") +} + +func TestScannerValuerWithFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != 
nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + var result ScannerValuerStruct + tx := DB.Where(data).FirstOrCreate(&result) + + if tx.RowsAffected != 1 { + t.Errorf("RowsAffected should be 1 after create some record") + } + + if tx.Error != nil { + t.Errorf("Should not raise any error, but got %v", tx.Error) + } + + AssertObjEqual(t, result, data, "Name", "Gender", "Age") + + if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if result.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } + + var result2 ScannerValuerStruct + if err := DB.First(&result2, result.ID).Error; err != nil { + t.Errorf("got error %v when query with %v", err, result.ID) + } + + AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age") +} + +func TestInvalidValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + if err := DB.Create(&data).Error; err == nil { + t.Errorf("Should failed to create data with invalid data") + } + + data.Password = EncryptedData("pass1") + if err := DB.Create(&data).Error; err != nil { + t.Errorf("Should got no error when creating data, but got %v", err) + } + + if err := 
DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { + t.Errorf("Should failed to update data with invalid data") + } + + if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { + t.Errorf("Should got no error update data with valid data, but got %v", err) + } + + AssertEqual(t, data.Password, EncryptedData("newpass")) +} + +type ScannerValuerStruct struct { + gorm.Model + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Allergen NullString + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct +} + +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } else if s, ok := value.(string); ok { + *data = []byte(s)[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + // needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + // prepend asterisks + return append([]byte("***"), data...), nil +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type StringsSlice []string + +func (l StringsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StringsSlice) Scan(input interface{}) error { + switch value := 
input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type ExampleStruct struct { + Name string + Val string +} + +func (ExampleStruct) GormDataType() string { + return "bytes" +} + +func (s ExampleStruct) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + // for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } +} + +type StructsSlice []ExampleStruct + +func (l StructsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StructsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type Role struct { + Name string `gorm:"size:256"` +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + +type EmptyTime struct { + time.Time +} + +func (t *EmptyTime) Scan(v interface{}) error { + nullTime := sql.NullTime{} + err := nullTime.Scan(v) + t.Time = nullTime.Time + return err +} + +func (t EmptyTime) Value() (driver.Value, error) { + return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil +} + +type NullString struct { + sql.NullString +} + +type Point struct { + X, Y int +} + +func (point Point) GormDataType() string { + 
return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. 
\(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} diff --git a/tests/scopes_test.go b/tests/scopes_test.go new file mode 100644 index 00000000..52c6b37b --- /dev/null +++ b/tests/scopes_test.go @@ -0,0 +1,125 @@ +package tests_test + +import ( + "context" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func NameIn1And2(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) +} + +func NameIn2And3(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) +} + +func NameIn(names []string) func(d *gorm.DB) *gorm.DB { + return func(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + users := []*User{ + GetUser("ScopeUser1", Config{}), + GetUser("ScopeUser2", Config{}), + GetUser("ScopeUser3", Config{}), + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) + } + + DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should found one user's name is 2, but got %v", len(users2)) + } + + DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) + } + + db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("custom_table") + }).Session(&gorm.Session{}) + + db.AutoMigrate(&User{}) + if db.Find(&User{}).Statement.Table != "custom_table" { + t.Errorf("failed to call Scopes") + } + + result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { + return tx.Session(&gorm.Session{}) + }).Find(&users1) + + if result.RowsAffected != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) + } + + var maxId int64 + userTable := func(db *gorm.DB) *gorm.DB { + return db.WithContext(context.Background()).Table("users") + } + if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { + t.Errorf("select max(id)") + } +} + +func TestComplexScopes(t *testing.T) { + tests := []struct { + name string + queryFn func(tx *gorm.DB) *gorm.DB + expected string + }{ + { + name: 
"depth_1", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, + }, { + name: "depth_1_pre_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Where("z = 0").Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, + }, { + name: "depth_2", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, + func(d *gorm.DB) *gorm.DB { + return d. + Or(d.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, + )). + Or("c = 3") + }, + func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) + }) + } +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go new file mode 100644 index 00000000..f1b8a336 --- /dev/null +++ b/tests/serializer_test.go @@ -0,0 +1,229 @@ +package tests_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . 
"gorm.io/gorm/utils/tests" +) + +type SerializerStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString +} + +type SerializerPostgresStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString +} + +func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" } + +func adaptorSerializerModel(s *SerializerStruct) interface{} { + if DB.Dialector.Name() == "postgres" { + sps := SerializerPostgresStruct(*s) + return &sps + } + return s +} + +type Roles []string + +type Job struct { + Title string + Number int + Location string + IsIntern bool +} + +type EncryptedString string + +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = 
EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %#v", dbValue) + } + return nil +} + +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +type CustomSerializer struct { + prefix []byte +} + +func NewCustomSerializer(prefix string) *CustomSerializer { + return &CustomSerializer{prefix: []byte(prefix)} +} + +func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix)) + case string: + err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix))) + default: + err = fmt.Errorf("unsupported data %#v", dbValue) + } + return err +} + +func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil +} + +func TestSerializer(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt := createdAt.Unix() + + data := SerializerStruct{ + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, + CustomSerializerString: "world", + } + + if err := 
DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) + + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } +} + +func TestSerializerZeroValue(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := SerializerStruct{} + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) + + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } +} + +func TestSerializerAssignFirstOrCreate(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + 
+ data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + CustomSerializerString: "world", + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + // update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) +} diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go new file mode 100644 index 00000000..179ae426 --- /dev/null +++ b/tests/soft_delete_test.go @@ -0,0 +1,169 @@ +package tests_test + +import ( + "database/sql" + "encoding/json" + "errors" + "regexp" + "testing" + + "github.com/jinzhu/now" + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestSoftDelete(t *testing.T) { + user := *GetUser("SoftDelete", Config{}) + DB.Save(&user) + + var count int64 + var age uint + + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + if err := DB.Delete(&user).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + if sql.NullTime(user.DeletedAt).Time.IsZero() { + t.Fatalf("user's deleted at is zero") + } + + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() + if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. 
IS NULL`).MatchString(sql) { + t.Errorf("Table with escape character, got %v", sql) + } + + if DB.First(&User{}, "name = ?", user.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + age = 0 + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count = 0 + if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + age = 0 + if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + DB.Unscoped().Delete(&user) + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } +} + +func TestDeletedAtUnMarshal(t *testing.T) { + expected := &gorm.Model{} + b, _ := json.Marshal(expected) + + result := &gorm.Model{} + _ = json.Unmarshal(b, result) + if result.DeletedAt != expected.DeletedAt { + t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) + } +} + +func TestDeletedAtOneOr(t *testing.T) { + actualSQL := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Or("id = ?", 1).Find(&User{}) + }) + + if !regexp.MustCompile(` WHERE id = 1 AND .users.\..deleted_at. 
IS NULL`).MatchString(actualSQL) { + t.Fatalf("invalid sql generated, got %v", actualSQL) + } +} + +func TestSoftDeleteZeroValue(t *testing.T) { + type SoftDeleteBook struct { + ID uint + Name string + Pages uint + DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"` + } + DB.Migrator().DropTable(&SoftDeleteBook{}) + if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil { + t.Fatalf("failed to auto migrate soft delete table") + } + + book := SoftDeleteBook{Name: "jinzhu", Pages: 10} + DB.Save(&book) + + var count int64 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + var pages uint + if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages) + } + + if err := DB.Delete(&book).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + zeroTime, _ := now.Parse("1970-01-01 00:00:01") + if book.DeletedAt.Time.Equal(zeroTime) { + t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt) + } + + if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + pages = 0 + if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 { + t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err) + } + + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count 
= 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + pages = 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages) + } + + DB.Unscoped().Delete(&book) + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go new file mode 100644 index 00000000..022e0495 --- /dev/null +++ b/tests/sql_builder_test.go @@ -0,0 +1,499 @@ +package tests_test + +import ( + "regexp" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +func TestRow(t *testing.T) { + user1 := User{Name: "RowUser1", Age: 1} + user2 := User{Name: "RowUser2", Age: 10} + user3 := User{Name: "RowUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() + + var age int64 + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 10 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } + + table := "gorm.users" + if DB.Dialector.Name() != "mysql" || isTiDB() { + table = "users" // other databases doesn't support select with `database.table` + } + + DB.Table(table).Where(map[string]interface{}{"name": user2.Name}).Update("age", 20) + + row = DB.Table(table+" as u").Where("u.name = ?", user2.Name).Select("age").Row() + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 20 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } +} + +func TestRows(t 
*testing.T) { + user1 := User{Name: "RowsUser1", Age: 1} + user2 := User{Name: "RowsUser2", Age: 10} + user3 := User{Name: "RowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + count := 0 + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } + + if count != 2 { + t.Errorf("Should found two records") + } +} + +func TestRaw(t *testing.T) { + user1 := User{Name: "ExecRawSqlUser1", Age: 1} + user2 := User{Name: "ExecRawSqlUser2", Age: 10} + user3 := User{Name: "ExecRawSqlUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Email string + } + + var results []result + DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&results) + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Raw with scan") + } + + rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() + count := 0 + for rows.Next() { + count++ + } + if count != 1 { + t.Errorf("Raw with Rows should find one record with name 3") + } + + DB.Exec("update users set name=? where name in (?)", "jinzhu-raw", []string{user1.Name, user2.Name, user3.Name}) + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + t.Error("Raw sql to update records") + } + + DB.Exec("update users set age=? where name = ?", gorm.Expr("age * ? 
+ ?", 2, 10), "jinzhu-raw") + + var age int + DB.Raw("select sum(age) from users where name = ?", "jinzhu-raw").Scan(&age) + + if age != ((1+10+20)*2 + 30) { + t.Errorf("Invalid age, got %v", age) + } +} + +func TestRowsWithGroup(t *testing.T) { + users := []User{ + {Name: "having_user_1", Age: 1}, + {Name: "having_user_2", Age: 10}, + {Name: "having_user_1", Age: 20}, + {Name: "having_user_1", Age: 30}, + } + + DB.Create(&users) + + rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN ?", []string{users[0].Name, users[1].Name}).Rows() + if err != nil { + t.Fatalf("got error %v", err) + } + + defer rows.Close() + for rows.Next() { + var name string + var total int64 + rows.Scan(&name, &total) + + if name == users[0].Name && total != 3 { + t.Errorf("Should have one user having name %v", users[0].Name) + } else if name == users[1].Name && total != 1 { + t.Errorf("Should have two users having name %v", users[1].Name) + } + } +} + +func TestQueryRaw(t *testing.T) { + users := []*User{ + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + } + DB.Create(&users) + + var user User + DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) + CheckUser(t, user, *users[1]) +} + +func TestDryRun(t *testing.T) { + user := *GetUser("dry-run", Config{}) + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&user).Statement + if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement + if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 { + t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) + } +} + +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() 
string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + +func TestGroupConditions(t *testing.T) { + type Pizza struct { + ID uint + Name string + Size string + } + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Where( + DB.Where("pizza = ?", "pepperoni").Where(DB.Where("size = ?", "small").Or("size = ?", "medium")), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt := dryRunDB.Exec("WHERE (pizza = ? AND (size = ? OR size = ?)) OR (pizza = ? 
AND size = ?)", "pepperoni", "small", "medium", "hawaiian", "xlarge").Statement + + result := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + expects := DB.Dialector.Explain(execStmt.SQL.String(), execStmt.Vars...) + + if !strings.HasSuffix(result, expects) { + t.Errorf("expects: %v, got %v", expects, result) + } + + stmt2 := dryRunDB.Where( + DB.Scopes(NameIn1And2), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement + + result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...) + expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...) + + if !strings.HasSuffix(result2, expects2) { + t.Errorf("expects: %v, got %v", expects2, result2) + } +} + +func TestCombineStringConditions(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? 
and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Not("a = ? 
or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } +} + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +} + +func TestToSQL(t *testing.T) { + // By default DB.DryRun should false + if DB.DryRun { + t.Fatal("Failed expect DB.DryRun to be false") + } + + if DB.Dialector.Name() == "sqlserver" { + t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") + } + + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) + + // find + sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Limit(10).Order("age desc").Find(&[]User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) + + // after model changed + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be 
false") + } + + if DB.Statement.SQL.String() != "" { + t.Fatal("Failed expect DB.Statement.SQL to be empty") + } + + // first + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + + // last and unscoped + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Unscoped().Where(&User{Name: "bar", Age: 12}).Limit(10).Offset(5).Order("name ASC").Last(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'bar' AND "users"."age" = 12 ORDER BY name ASC,"users"."id" DESC LIMIT 1 OFFSET 5`, sql) + + // create + user := &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Create(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // save + user = &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Save(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // updates + user = &User{Name: "bar", Age: 22} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Updates(user) + }) + assertEqualSQL(t, `UPDATE "users" SET "created_at"='2021-10-18 
00:00:00',"updated_at"='2021-10-18 19:50:09.438',"name"='bar',"age"=22 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // update + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Update("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar',"updated_at"='2021-10-18 19:50:09.438' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumn + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumn("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumns(User{Name: "Foo", Age: 100}) + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // after model changed + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") + } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) +} + +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. +func assertEqualSQL(t *testing.T, expected string, actually string) { + t.Helper() + + // replace SQL quote, convert into postgresql like "" + expected = replaceQuoteInSQL(expected) + actually = replaceQuoteInSQL(actually) + + // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update. 
+ updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) + actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) + expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) + + // ignore RETURNING "id" (only in PostgreSQL) + returningRe := regexp.MustCompile(`(?i)RETURNING "id"`) + actually = returningRe.ReplaceAllString(actually, ``) + expected = returningRe.ReplaceAllString(expected, ``) + + actually = strings.TrimSpace(actually) + expected = strings.TrimSpace(expected) + + if actually != expected { + t.Fatalf("\nexpected: %s\nactually: %s", expected, actually) + } +} + +func replaceQuoteInSQL(sql string) string { + // convert single quote into double quote + sql = strings.ReplaceAll(sql, `'`, `"`) + + // convert dialect special quote into double quote + switch DB.Dialector.Name() { + case "postgres": + sql = strings.ReplaceAll(sql, `"`, `"`) + case "mysql", "sqlite": + sql = strings.ReplaceAll(sql, "`", `"`) + case "sqlserver": + sql = strings.ReplaceAll(sql, `'`, `"`) + } + + return sql +} diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 00000000..fa569d32 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,174 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" + . "gorm.io/gorm/utils/tests" +) + +type UserWithTable struct { + gorm.Model + Name string +} + +func (UserWithTable) TableName() string { + return "gorm.user" +} + +func TestTable(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("`people`").Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user as u").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.people").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. 
(.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} + +func TestTableWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* " + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile(userQuery + "FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. 
(.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + userQueryCharacter := "SELECT .*u.*id.*u.*created_at.*u.*updated_at.*u.*deleted_at.*u.*name.*u.*age.*u.*birthday" + + ".*u.*company_id.*u.*manager_id.*u.*active.* " + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + db, _ := gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }, + }) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh new file mode 100755 index 00000000..ee9e7675 --- /dev/null +++ b/tests/tests_all.sh @@ -0,0 +1,61 @@ +#!/bin/bash -e + +dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb") + +if [[ $(pwd) == *"gorm/tests"* ]]; then + cd .. +fi + +if [ -d tests ] +then + cd tests + go get -u -t ./... + go mod download + go mod tidy + cd .. +fi + +# SqlServer for Mac M1 +if [[ -z $GITHUB_ACTION ]]; then + if [ -d tests ] + then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true + else + docker-compose start + fi + cd .. 
+ fi +fi + + +for dialect in "${dialects[@]}" ; do + if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] + then + echo "testing ${dialect}..." + + if [ "$GORM_VERBOSE" = "" ] + then + GORM_DIALECT=${dialect} go test -race -count=1 ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test -race -count=1 ./... + cd .. + fi + else + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... + cd .. + fi + fi + fi +done diff --git a/tests/tests_test.go b/tests/tests_test.go new file mode 100644 index 00000000..90eb847f --- /dev/null +++ b/tests/tests_test.go @@ -0,0 +1,132 @@ +package tests_test + +import ( + "log" + "math/rand" + "os" + "path/filepath" + "time" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +var DB *gorm.DB +var ( + mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" +) + +func init() { + var err error + if DB, err = OpenTestConnection(); err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } else { + sqlDB, err := DB.DB() + if err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } + + err = sqlDB.Ping() + if err != nil { + log.Printf("failed to ping sqlDB, got error %v", err) + os.Exit(1) + } + + RunMigrations() + if DB.Dialector.Name() == "sqlite" { + DB.Exec("PRAGMA foreign_keys = ON") + } + } +} + +func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") + switch 
os.Getenv("GORM_DIALECT") { + case "mysql": + log.Println("testing mysql...") + if dbDSN == "" { + dbDSN = mysqlDSN + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + case "postgres": + log.Println("testing postgres...") + if dbDSN == "" { + dbDSN = postgresDSN + } + db, err = gorm.Open(postgres.New(postgres.Config{ + DSN: dbDSN, + PreferSimpleProtocol: true, + }), &gorm.Config{}) + case "sqlserver": + // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 + // CREATE DATABASE gorm; + // GO + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // CREATE USER gorm FROM LOGIN gorm; + // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; + // GO + log.Println("testing sqlserver...") + if dbDSN == "" { + dbDSN = sqlserverDSN + } + db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + case "tidb": + log.Println("testing tidb...") + if dbDSN == "" { + dbDSN = tidbDSN + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + default: + log.Println("testing sqlite3...") + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + } + + if err != nil { + return + } + + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + return +} + +func RunMigrations() { + var err error + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + DB.Migrator().DropTable("user_friends", "user_speaks") + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + 
log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } + } +} diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go new file mode 100644 index 00000000..5872da94 --- /dev/null +++ b/tests/transaction_test.go @@ -0,0 +1,399 @@ +package tests_test + +import ( + "context" + "errors" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestTransaction(t *testing.T) { + tx := DB.Begin() + user := *GetUser("transaction", Config{}) + + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err := tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err := DB.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := DB.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err := tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} + +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} 
+ +func TestTransactionWithBlock(t *testing.T) { + assertPanic := func(f func()) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("The code did not panic") + } + }() + f() + } + + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err != nil && err.Error() != "the error message" { + t.Fatalf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transaction-block").Error; err == nil { + t.Fatalf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block-2", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transaction-block-2").Error; err != nil { + t.Fatalf("Should be able to find committed record") + } + + // panic will rollback + assertPanic(func() { + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block-3", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + panic("force panic") + }) + }) + + if err := DB.First(&User{}, "name = ?", "transaction-block-3").Error; err == nil { + t.Fatalf("Should not find record after panic rollback") + } +} + +func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + user := User{Name: "transaction"} + if err := 
tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err == nil { + t.Fatalf("Rollback after commit should raise error") + } +} + +func TestTransactionWithSavePoint(t *testing.T) { + tx := DB.Begin() + + user := *GetUser("transaction-save-point", Config{}) + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.SavePoint("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user1 := *GetUser("transaction-save-point-1", Config{}) + tx.Create(&user1) + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.RollbackTo("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.SavePoint("save_point2").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user2 := *GetUser("transaction-save-point-2", Config{}) + tx.Create(&user2) + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Failed to commit, got error %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestNestedTransactionWithBlock(t *testing.T) { + var ( + user = 
*GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestDisabledNestedTransaction(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + 
+ if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + 
t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +} + +func TestTransactionWithHooks(t *testing.T) { + user := GetUser("tTestTransactionWithHooks", Config{Account: true}) + DB.Create(&user) + + var err error + err = DB.Transaction(func(tx *gorm.DB) error { + return tx.Model(&User{}).Limit(1).Transaction(func(tx2 *gorm.DB) error { + return tx2.Scan(&User{}).Error + }) + }) + + if err != nil { + t.Error(err) + } + + // method with hooks + err = DB.Transaction(func(tx1 *gorm.DB) error { + // callMethod do + tx2 := tx1.Find(&User{}).Session(&gorm.Session{NewDB: true}) + // trx in hooks + return tx2.Transaction(func(tx3 *gorm.DB) error { + return tx3.Where("user_id", user.ID).Delete(&Account{}).Error + }) + }) + + if err != nil { + t.Error(err) + } +} diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go new file mode 100644 index 00000000..4e94cfd5 --- /dev/null +++ b/tests/update_belongs_to_test.go @@ -0,0 +1,59 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdateBelongsTo(t *testing.T) { + user := *GetUser("update-belongs-to", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + + user.Company.Name += "new2" + user.Manager.Name += "new2" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user5 User + DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID) + if user5.Manager.Name != user4.Manager.Name { + t.Errorf("should not update user's manager") + } else { + user.Manager.Name = user4.Manager.Name + } + CheckUser(t, user, user5) +} diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go new file mode 100644 index 00000000..2ca93e2b --- /dev/null +++ b/tests/update_has_many_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdateHasManyAssociations(t *testing.T) { + user := *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Pets").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + + t.Run("Polymorphic", func(t *testing.T) { + user := *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Toys").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) 
+ }) +} diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go new file mode 100644 index 00000000..40af6ae7 --- /dev/null +++ b/tests/update_has_one_test.go @@ -0,0 +1,137 @@ +package tests_test + +import ( + "database/sql" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestUpdateHasOne(t *testing.T) { + user := *GetUser("update-has-one", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Account").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + + CheckUser(t, user2, user3) + lastUpdatedAt := user2.Account.UpdatedAt + time.Sleep(time.Second) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + + if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { + t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user3.Account.UpdatedAt.Format(time.RFC3339)) + } else { + user.Account.UpdatedAt = user4.Account.UpdatedAt + CheckUser(t, user4, user) + } + + t.Run("Polymorphic", func(t *testing.T) { + pet := Pet{Name: "create"} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", 
err) + } + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) + }) + + t.Run("Restriction", func(t *testing.T) { + type CustomizeAccount struct { + gorm.Model + UserID sql.NullInt64 + Number string `gorm:"<-:create"` + Number2 string + } + + type CustomizeUser struct { + gorm.Model + Name string + Account CustomizeAccount `gorm:"foreignkey:UserID"` + } + + DB.Migrator().DropTable(&CustomizeUser{}) + DB.Migrator().DropTable(&CustomizeAccount{}) + + if err := DB.AutoMigrate(&CustomizeUser{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + + number := "number-has-one-associations" + cusUser := CustomizeUser{ + Name: "update-has-one-associations", + Account: CustomizeAccount{ + Number: number, + Number2: number, + }, + } + + if err := DB.Create(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + cusUser.Account.Number += "-update" + cusUser.Account.Number2 += "-update" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var account2 CustomizeAccount + DB.Find(&account2, "user_id = ?", cusUser.ID) + AssertEqual(t, account2.Number, number) + AssertEqual(t, account2.Number2, cusUser.Account.Number2) + }) +} diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go new file 
mode 100644 index 00000000..f1218cc0 --- /dev/null +++ b/tests/update_many2many_test.go @@ -0,0 +1,54 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestUpdateMany2ManyAssociations(t *testing.T) { + user := *GetUser("update-many2many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + for _, lang := range user.Languages { + DB.Create(&lang) + } + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) +} diff --git a/tests/update_test.go b/tests/update_test.go new file mode 100644 index 00000000..36ffa6a0 --- /dev/null +++ b/tests/update_test.go @@ -0,0 +1,807 @@ +package tests_test + +import ( + "errors" + "regexp" + "sort" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdate(t *testing.T) { + var ( + users = []*User{ + GetUser("update-1", Config{}), + GetUser("update-2", Config{}), + GetUser("update-3", Config{}), + } + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedAtChanged := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n + } + + checkOtherData := func(name string) { + var first, last User + if err := DB.Where("id = ?", users[0].ID).First(&first).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + CheckUser(t, first, *users[0]) + + if err := DB.Where("id = ?", users[2].ID).First(&last).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + CheckUser(t, last, *users[2]) + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + lastUpdatedAt = user.UpdatedAt + + if err := DB.Model(user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + checkUpdatedAtChanged("Update", user.UpdatedAt) + checkOtherData("Update") + + var result User + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result, *user) + } + + values := map[string]interface{}{"Active": true, "age": 5} + if res := DB.Model(user).Updates(values); res.Error != nil { + t.Errorf("errors happened when update: %v", res.Error) + } else if res.RowsAffected != 1 { + t.Errorf("rows affected should be 1, but got : %v", 
res.RowsAffected) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + checkUpdatedAtChanged("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") + + var result2 User + if err := DB.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result2, *user) + } + + if err := DB.Model(user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + checkUpdatedAtChanged("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") + + var result3 User + if err := DB.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result3, *user) + } + + user.Active = false + user.Age = 1 + if err := DB.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedAtChanged("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := DB.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result4, *user) + } + + if rowsAffected := DB.Model([]User{result4}).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + if rowsAffected := DB.Model(users).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 3 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } +} + +func 
TestUpdates(t *testing.T) { + users := []*User{ + GetUser("updates_01", Config{}), + GetUser("updates_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[0].UpdatedAt + + // update with map + if res := DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("Failed to update users") + } + + if users[0].Name != "updates_01_newname" || users[0].Age != 100 { + t.Errorf("Record should be updated also with map") + } + + if users[0].UpdatedAt.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("User's updated at should be changed, but got %v, was %v", users[0].UpdatedAt.UnixNano(), lastUpdatedAt) + } + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + // update with struct + time.Sleep(1 * time.Second) + DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) + + var user3 User + if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil { + t.Errorf("User2's name should be updated") + } + + if user2.UpdatedAt.Format(time.RFC1123Z) == user3.UpdatedAt.Format(time.RFC1123Z) { + t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123Z), user3.UpdatedAt.Format(time.RFC1123Z)) + } + + // update with gorm exprs + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + var user4 User + DB.First(&user4, user3.ID) + + user3.Age += 100 + AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") +} + +func TestUpdateColumn(t *testing.T) { + users := []*User{ + GetUser("update_column_01", Config{}), + GetUser("update_column_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := 
users[1].UpdatedAt + + // update with map + DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100}) + if users[1].Name != "update_column_02_newname" || users[1].Age != 100 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + if users[1].Name != "update_column_02_newnew" { + t.Errorf("user 2's name should be updated, but got %v", users[1].Name) + } + + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) + var user3 User + DB.First(&user3, users[1].ID) + + users[1].Age += 50 + CheckUser(t, user3, *users[1]) + + // update with struct + DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200}) + if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user5, user6 User + DB.First(&user5, users[0].ID) + DB.First(&user6, users[1].ID) + CheckUser(t, user5, *users[0]) + CheckUser(t, user6, *users[1]) +} + +func TestBlockGlobalUpdate(t *testing.T) { + if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) + } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } +} + +func 
TestSelectWithUpdate(t *testing.T) { + user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) 
+ + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") + + DB.Model(&result).Select("Name", "Age").Updates(User{Name: "update_with_select"}) + if result.Age != 0 || result.Name != "update_with_select" { + t.Fatalf("Failed to update struct with select, got %+v", result) + } + AssertObjEqual(t, result, user, "UpdatedAt") + + var result3 User + DB.First(&result3, result.ID) + AssertObjEqual(t, result, result3, "Name", "Age", "UpdatedAt") + + DB.Model(&result).Select("Name", "Age", "UpdatedAt").Updates(User{Name: "update_with_select"}) + + if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { + t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) + } + + AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age") +} + +func TestSelectWithUpdateWithMap(t *testing.T) { + user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": 
user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("name", "updated_at").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) + + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + +func TestOmitWithUpdate(t *testing.T) { + user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = 
user2.Languages + result.Friends = user2.Friends + + DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", 
"ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) + + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + + time.Sleep(time.Second) + lastUpdatedAt := result.UpdatedAt + DB.Model(&result).Select("Name").Updates(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if lastUpdatedAt.Format(time.RFC3339Nano) == result2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be changed") + } + + if result2.Name == user.Name || result2.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := 
*GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Omit("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name != user.Name || result2.Age == user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := *GetUser("update_column_skips_association", Config{}) + DB.Create(&user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := uint(100) + user.Account.Number = "new_account_number" + db := DB.Model(&user).UpdateColumns(User{Age: newAge}) + + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected) + } + + // Verify that Age now=`newAge`. 
+ result := &User{} + result.ID = user.ID + DB.Preload("Account").First(result) + + if result.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age) + } + + if result.Account.Number != user.Account.Number { + t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number) + } +} + +func TestUpdatesWithBlankValues(t *testing.T) { + user := *GetUser("updates_with_blank_value", Config{}) + DB.Save(&user) + + var user2 User + user2.ID = user.ID + DB.Model(&user2).Updates(&User{Age: 100}) + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Name || result.Age != 100 { + t.Errorf("user's name should not be updated") + } +} + +func TestUpdatesTableWithIgnoredValues(t *testing.T) { + type ElementWithIgnoredField struct { + Id int64 + Value string + IgnoredField int64 `gorm:"-"` + } + DB.Migrator().DropTable(&ElementWithIgnoredField{}) + DB.AutoMigrate(&ElementWithIgnoredField{}) + + elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} + DB.Save(&elem) + + DB.Model(&ElementWithIgnoredField{}). + Where("id = ?", elem.Id). 
+ Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) + + var result ElementWithIgnoredField + if err := DB.First(&result, elem.Id).Error; err != nil { + t.Errorf("error getting an element from database: %s", err.Error()) + } + + if result.IgnoredField != 0 { + t.Errorf("element's ignored field should not be updated") + } +} + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} + +func TestIdempotentSave(t *testing.T) { + create := Company{ + Name: "company_idempotent", + } + DB.Create(&create) + + var company Company + if err := DB.Find(&company, "id = ?", create.ID).Error; err != nil { + t.Fatalf("failed to find created company, got err: %v", err) + } + + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } +} + +func TestSave(t *testing.T) { + user := 
*GetUser("save", Config{}) + DB.Create(&user) + + if err := DB.First(&User{}, "name = ?", "save").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user.Name = "save2" + DB.Save(&user) + + var result User + if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID { + t.Fatalf("failed to find updated user") + } + + user2 := *GetUser("save2", Config{}) + DB.Create(&user2) + + time.Sleep(time.Second) + user1UpdatedAt := result.UpdatedAt + user2UpdatedAt := user2.UpdatedAt + users := []*User{&result, &user2} + DB.Save(&users) + + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + DB.First(&result) + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed after reload, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + DB.First(&user2) + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user2's updated at should be changed after reload, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Save(&user).Statement + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } + + dryDB = DB.Session(&gorm.Session{DryRun: true}) + stmt = dryDB.Unscoped().Save(&user).Statement + if !regexp.MustCompile(`WHERE .id. 
= [^ ]+$`).MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } + + user3 := *GetUser("save3", Config{}) + DB.Create(&user3) + + if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user3.Name = "save3_" + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } + + var result2 User + if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { + t.Fatalf("failed to find updated user, got %v", err) + } + + if err := DB.Model(User{Model: user3.Model}).Save(&struct { + gorm.Model + Placeholder string + Name string + }{ + Model: user3.Model, + Placeholder: "placeholder", + Name: "save3__", + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } + + var result3 User + if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { + t.Fatalf("failed to find updated user") + } +} + +func TestSaveWithPrimaryValue(t *testing.T) { + lang := Language{Code: "save", Name: "save"} + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should create language, rows affected: %v", result.RowsAffected) + } + + var result Language + DB.First(&result, "code = ?", "save") + AssertEqual(t, result, lang) + + lang.Name = "save name2" + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should update language") + } + + var result2 Language + DB.First(&result2, "code = ?", "save") + AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != 
lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } +} + +// only sqlite, postgres support returning +func TestUpdateReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("update-returning-1", Config{}), + GetUser("update-returning-2", Config{}), + GetUser("update-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) + if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { + t.Errorf("failed to return updated data, got %v", results) + } + + if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if results[1].Age-results[0].Age != 100 { + t.Errorf("failed to return updated age column") + } +} + +func TestUpdateWithDiffSchema(t *testing.T) { + user := GetUser("update-diff-schema-1", Config{}) + DB.Create(&user) + + type UserTemp struct { + Name string + } + + err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error + AssertEqual(t, err, 
nil) + AssertEqual(t, "update-diff-schema-2", user.Name) +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go new file mode 100644 index 00000000..e84dc14a --- /dev/null +++ b/tests/upsert_test.go @@ -0,0 +1,328 @@ +package tests_test + +import ( + "regexp" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +func TestUpsert(t *testing.T) { + lang := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + lang2 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var langs []Language + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(&lang3).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } + + lang = Language{Code: "upsert", Name: "Upsert-Newname"} + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var result Language + if err := 
DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { + t.Fatalf("failed to upsert, got name %v", result.Name) + } + + if name := DB.Dialector.Name(); name != "sqlserver" { + type RestrictedLanguage struct { + Code string `gorm:"primarykey"` + Name string + Lang string `gorm:"<-:create"` + } + + r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + user := *GetUser("upsert_on_conflict", Config{}) + user.Age = 20 + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error %v", err) + } + + var user2 User + DB.First(&user2, user.ID) + user2.Age = 30 + time.Sleep(time.Second) + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { + t.Fatalf("failed to onconflict create user, got error %v", err) + } else { + var user3 User + DB.First(&user3, user.ID) + if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { + t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) + } + } +} + +func TestUpsertSlice(t *testing.T) { + langs := []Language{ + {Code: "upsert-slice1", Name: "Upsert-slice1"}, + {Code: "upsert-slice2", Name: "Upsert-slice2"}, + {Code: "upsert-slice3", Name: "Upsert-slice3"}, + } + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + + var langs2 []Language + if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs2) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs2) + } + + 
DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + var langs3 []Language + if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs3) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs3) + } + + for idx, lang := range langs { + lang.Name = lang.Name + "_new" + langs[idx] = lang + } + + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.AssignmentColumns([]string{"name"}), + }).Create(&langs).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + for _, lang := range langs { + var results []Language + if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(results) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if results[0].Name != lang.Name { + t.Errorf("should update name on conflict, but got name %+v", results[0].Name) + } + } +} + +func TestUpsertWithSave(t *testing.T) { + langs := []Language{ + {Code: "upsert-save-1", Name: "Upsert-save-1"}, + {Code: "upsert-save-2", Name: "Upsert-save-2"}, + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + } + + for idx, lang := range langs { + lang.Name += "_new" + langs[idx] = lang + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to upsert, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", 
err) + } else { + AssertEqual(t, result, lang) + } + } + + lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + + lang.Name += "_new" + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result2 Language + if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result2, lang) + } +} + +func TestFindOrInitialize(t *testing.T) { + var user1, user2, user3, user4, user5, user6 User + if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) + if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) + if user3.Name != "find or init 2" || user3.ID != 0 { + t.Errorf("user should be initialized with inline search value") + } + + DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and attrs") + } + + DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value 
and assign attrs") + } + + DB.Save(&User{Name: "find or init", Age: 33}) + DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { + t.Errorf("user should be found with FirstOrInit") + } + + DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } +} + +func TestFindOrCreate(t *testing.T) { + var user1, user2, user3, user4, user5, user6, user7, user8 User + if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) + if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) + if user3.Name != "find or create 2" || user3.ID == 0 { + t.Errorf("user should be created with inline search value") + } + + DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) + if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and attrs") + } + + updatedAt1 := user4.UpdatedAt + DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + + if user4.Age != 55 { + t.Errorf("Failed to set change to 55, got %v", 
user4.Age) + } + + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateAt should be changed when update values with assign") + } + + DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) + if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) + if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Find(&user7) + if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) + if err := DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).Error; err != nil { + t.Errorf("has many association should be saved") + } + + if err := DB.Where("number = ?", "1231231231").First(&Account{}).Error; err != nil { + t.Errorf("belongs to association should be saved") + } +} + +func TestUpdateWithMissWhere(t *testing.T) { + type User struct { + ID uint `gorm:"column:id;<-:create"` + Name string `gorm:"column:name"` + } + user := User{ID: 1, Name: "king"} + tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) + + if err := tx.Error; err != nil { + t.Fatalf("failed to update user,missing where condition,err=%+v", err) + } + + if !regexp.MustCompile("WHERE .id. 
= [^ ]+$").MatchString(tx.Statement.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String()) + } +} diff --git a/update_test.go b/update_test.go deleted file mode 100644 index 3ce64ce3..00000000 --- a/update_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T) { - product1 := Product{Code: "product1code"} - product2 := Product{Code: "product2code"} - - DB.Save(&product1).Save(&product2).Update("code", "product2newcode") - - if product2.Code != "product2newcode" { - t.Errorf("Record should be updated") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt1 := product1.UpdatedAt - - if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { - t.Errorf("Product1 should not be updated") - } - - if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") - - var product4 Product - DB.First(&product4, product1.Id) - if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { - t.Errorf("Product1's code should be updated") - } - - if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { - t.Errorf("Product should not be changed to 789") - } - - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update with CamelCase") - } - - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when 
update_column with CamelCase") - } - - var products []Product - DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { - t.Error("RowsAffected should be correct when do batch update") - } - - DB.First(&product4, product4.Id) - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("Update with expression") - } - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Update with expression should update UpdatedAt") - } -} - -func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - DB.Save(&animal) - updatedAt1 := animal.UpdatedAt - - DB.Save(&animal).Update("name", "Francis") - - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - - var animals []Animal - DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { - t.Error("RowsAffected should be correct when do batch update") - } - - animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) - DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) - if animal.Name != "galeone" { - t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name) - } - - // When changing a field with a default value, the change must occur - animal.Name = "amazing horse" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "amazing horse" { - t.Errorf("Update a filed with a default value should occur. 
But got %v\n", animal.Name) - } - - // When changing a field with a default value with blank value - animal.Name = "" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) - } -} - -func TestUpdates(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 10} - DB.Save(&product1).Save(&product2) - DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) - if product1.Code != "product1newcode" || product1.Price != 100 { - t.Errorf("Record should be updated also with map") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { - t.Errorf("Product2 should not be updated") - } - - if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { - t.Errorf("Product1 should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) - if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("product2's code should be updated") - } - - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100 { - t.Errorf("Updates with expression") - } - // product4's UpdatedAt will be reset when updating 
- if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Updates with expression should update UpdatedAt") - } -} - -func TestUpdateColumn(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 20} - DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) - if product2.Code != "product2newcode" || product2.Price != 100 { - t.Errorf("product 2 should be updated with update column") - } - - var product3 Product - DB.First(&product3, product1.Id) - if product3.Code != "product1code" || product3.Price != 10 { - t.Errorf("product 1 should not be updated") - } - - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - DB.Model(product2).UpdateColumn("code", "update_column_new") - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated with update column") - } - - DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("UpdateColumn with expression") - } - if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateColumn with expression should not update UpdatedAt") - } -} - -func TestSelectWithUpdate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: 
"new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestSelectWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestOmitWithUpdate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships that not omited") - } -} - -func TestOmitWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships not omited") - } -} - -func TestSelectWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } -} - -func TestOmitWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should omit name column when update user") - } -} - -func TestUpdateColumnsSkipsAssociations(t *testing.T) { - user := getPreparedUser("update_columns_user", "special_role") - user.Age = 99 - address1 := "first street" - user.BillingAddress = Address{Address1: address1} - DB.Save(user) - - // Update a single field of the user and verify that the changed address is not stored. 
- newAge := int64(100) - user.BillingAddress.Address1 = "second street" - db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) - } - - // Verify that Age now=`newAge`. - freshUser := &User{Id: user.Id} - DB.First(freshUser) - if freshUser.Age != newAge { - t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) - } - - // Verify that user's BillingAddress.Address1 is not changed and is still "first street". - DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) - if freshUser.BillingAddress.Address1 != address1 { - t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) - } -} - -func TestUpdatesWithBlankValues(t *testing.T) { - product := Product{Code: "product1", Price: 10} - DB.Save(&product) - - DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) - - var product1 Product - DB.First(&product1, product.Id) - - if product1.Code != "product1" || product1.Price != 100 { - t.Errorf("product's code should not be updated") - } -} - -type ElementWithIgnoredField struct { - Id int64 - Value string - IgnoredField int64 `sql:"-"` -} - -func (e ElementWithIgnoredField) TableName() string { - return "element_with_ignored_field" -} - -func TestUpdatesTableWithIgnoredValues(t *testing.T) { - elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} - DB.Save(&elem) - - DB.Table(elem.TableName()). - Where("id = ?", elem.Id). - // DB.Model(&ElementWithIgnoredField{Id: elem.Id}). 
- Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) - - var elem1 ElementWithIgnoredField - err := DB.First(&elem1, elem.Id).Error - if err != nil { - t.Errorf("error getting an element from database: %s", err.Error()) - } - - if elem1.IgnoredField != 0 { - t.Errorf("element's ignored field should not be updated") - } -} - -func TestUpdateDecodeVirtualAttributes(t *testing.T) { - var user = User{ - Name: "jinzhu", - IgnoreMe: 88, - } - - DB.Save(&user) - - DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) - - if user.IgnoreMe != 100 { - t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") - } -} diff --git a/utils.go b/utils.go deleted file mode 100644 index ba1f08ab..00000000 --- a/utils.go +++ /dev/null @@ -1,264 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) 
-} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - -// SQL expression -type expr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? 
+ ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for _,_ = range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); 
ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! 
- // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go new file mode 100644 index 00000000..a2d9c33d --- /dev/null +++ b/utils/tests/dummy_dialecter.go @@ -0,0 +1,100 @@ +package tests + +import ( + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +type DummyDialector struct { + TranslatedErr error +} + +func (DummyDialector) Name() string { + return "dummy" +} + +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + + return nil +} + +func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + +func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { + return nil +} + +func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (DummyDialector) QuoteTo(writer clause.Writer, str string) { + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '`': + continuousBacktick++ + if continuousBacktick == 2 { + 
writer.WriteString("``") + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteByte('`') + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteByte('`') + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString("``") + } + + writer.WriteByte(v) + } + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString("``") + } + writer.WriteByte('`') +} + +func (DummyDialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + +func (DummyDialector) DataTypeOf(*schema.Field) string { + return "" +} + +func (d DummyDialector) Translate(err error) error { + return d.TranslatedErr +} diff --git a/utils/tests/models.go b/utils/tests/models.go new file mode 100644 index 00000000..ec1651a3 --- /dev/null +++ b/utils/tests/models.go @@ -0,0 +1,96 @@ +package tests + +import ( + "database/sql" + "time" + + "gorm.io/gorm" +) + +// User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) +// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) +// He speaks many languages (many to many) and has many friends (many to many - single-table) +// His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) +type User struct { + gorm.Model + Name string + Age uint + Birthday *time.Time + Account Account + Pets []*Pet + NamedPet *Pet + Toys []Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company Company + ManagerID *uint + Manager *User + Team []User `gorm:"foreignkey:ManagerID"` + Languages []Language `gorm:"many2many:UserSpeak;"` + 
Friends []*User `gorm:"many2many:user_friends;"` + Active bool +} + +type Account struct { + gorm.Model + UserID sql.NullInt64 + Number string +} + +type Pet struct { + gorm.Model + UserID *uint + Name string + Toy Toy `gorm:"polymorphic:Owner;"` +} + +type Toy struct { + gorm.Model + Name string + OwnerID string + OwnerType string +} + +type Company struct { + ID int + Name string +} + +type Language struct { + Code string `gorm:"primarykey"` + Name string +} + +type Coupon struct { + ID int `gorm:"primarykey; size:255"` + AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` +} + +type CouponProduct struct { + CouponId int `gorm:"primarykey;size:255"` + ProductId string `gorm:"primarykey;size:255"` + Desc string +} + +type Order struct { + gorm.Model + Num string + Coupon *Coupon + CouponID string +} + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} + +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go new file mode 100644 index 00000000..49d01f2e --- /dev/null +++ b/utils/tests/utils.go @@ -0,0 +1,128 @@ +package tests + +import ( + "database/sql/driver" + "fmt" + "go/ast" + "reflect" + "testing" + "time" + + "gorm.io/gorm/utils" +) + +func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { + for _, name := range names { + rv := reflect.Indirect(reflect.ValueOf(r)) + ev := reflect.Indirect(reflect.ValueOf(e)) + if rv.IsValid() != ev.IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) + return + } + got := rv.FieldByName(name).Interface() + expect := ev.FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} + +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { 
+ isEqual := func() { + if curTime, ok := got.(time.Time); ok { + format := "2006-01-02T15:04:05Z07:00" + + if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) + } + } else if fmt.Sprint(got) != fmt.Sprint(expect) { + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) + } + } + + if fmt.Sprint(got) == fmt.Sprint(expect) { + return + } + + if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), 
expect, got) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(expect).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + + if exported { + return + } + } + } + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } else { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + } +} + +func Now() *time.Time { + now := time.Now() + return &now +} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..ddbca60a --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,150 @@ +package utils + +import ( + "database/sql/driver" + "fmt" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + "unicode" +) + +var gormSourceDir string + +func init() { + _, file, _, _ := runtime.Caller(0) + // compatible solution to get gorm source directory with various operating systems + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return filepath.ToSlash(s) + "/" +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + // the second caller 
usually from gorm internal, so set i start from 2 + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + return file + ":" + strconv.FormatInt(int64(line), 10) + } + } + + return "" +} + +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' +} + +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if val != "" && !strings.EqualFold(val, "false") { + return true + } + } + return false +} + +func ToStringKey(values ...interface{}) string { + results := make([]string, len(values)) + + for idx, value := range values { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + + switch v := value.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + } + } + + return strings.Join(results, "_") +} + +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return 
strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +} + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..71eef964 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,137 @@ +package utils + +import ( + "database/sql" + "database/sql/driver" + "errors" + "math" + "strings" + "testing" + "time" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +} + +func TestCheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + if out := CheckTruth(test.v); out != test.out { + t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) + } + }) + } +} + +func TestToStringKey(t *testing.T) { + cases := []struct { + values []interface{} + key string + }{ + 
{[]interface{}{"a"}, "a"}, + {[]interface{}{1, 2, 3}, "1_2_3"}, + {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, + } + for _, c := range cases { + if key := ToStringKey(c.values...); key != c.key { + t.Errorf("%v: expected %v, got %v", c.values, c.key, key) + } + } +} + +func TestContains(t *testing.T) { + containsTests := []struct { + name string + elems []string + elem string + out bool + }{ + {"exists", []string{"1", "2", "3"}, "1", true}, + {"not exists", []string{"1", "2", "3"}, "4", false}, + } + for _, test := range containsTests { + t.Run(test.name, func(t *testing.T) { + if out := Contains(test.elems, test.elem); test.out != out { + t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) + } + }) + } +} + +type ModifyAt sql.NullTime + +// Value return a Unix time. +func (n ModifyAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time.Unix(), nil +} + +func TestAssertEqual(t *testing.T) { + now := time.Now() + assertEqualTests := []struct { + name string + src, dst interface{} + out bool + }{ + {"error equal", errors.New("1"), errors.New("1"), true}, + {"error not equal", errors.New("1"), errors.New("2"), false}, + {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, + {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + } + for _, test := range assertEqualTests { + t.Run(test.name, func(t *testing.T) { + if out := AssertEqual(test.src, test.dst); test.out != out { + t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) + } + }) + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + in interface{} + out string + }{ + {"int", math.MaxInt64, "9223372036854775807"}, + {"int8", int8(math.MaxInt8), "127"}, + {"int16", int16(math.MaxInt16), "32767"}, + {"int32", 
int32(math.MaxInt32), "2147483647"}, + {"int64", int64(math.MaxInt64), "9223372036854775807"}, + {"uint", uint(math.MaxUint64), "18446744073709551615"}, + {"uint8", uint8(math.MaxUint8), "255"}, + {"uint16", uint16(math.MaxUint16), "65535"}, + {"uint32", uint32(math.MaxUint32), "4294967295"}, + {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, + {"string", "abc", "abc"}, + {"other", true, ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if out := ToString(test.in); test.out != out { + t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) + } + }) + } +} diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go new file mode 100644 index 00000000..450cbe2a --- /dev/null +++ b/utils/utils_unix_test.go @@ -0,0 +1,38 @@ +//go:build unix +// +build unix + +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/Users/name/go/pkg/mod/gorm.io/", + }, + { + file: "/go/work/proj/gorm/utils/utils.go", + want: "/go/work/proj/gorm/", + }, + { + file: "/go/work/proj/gorm_alias/utils/utils.go", + want: "/go/work/proj/gorm_alias/", + }, + { + file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go new file mode 100644 index 00000000..8b1c519d --- /dev/null +++ b/utils/utils_windows_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/Users/name/go/pkg/mod/gorm.io/`, + }, + { + file: 
`C:/go/work/proj/gorm/utils/utils.go`, + want: `C:/go/work/proj/gorm/`, + }, + { + file: `C:/go/work/proj/gorm_alias/utils/utils.go`, + want: `C:/go/work/proj/gorm_alias/`, + }, + { + file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 07f5b17f..00000000 --- a/utils_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", - "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index ff6fb17c..00000000 --- a/wercker.yml +++ /dev/null @@ -1,53 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - id: mariadb:10.0 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on 
build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...