Skip to content

Commit

Permalink
database: ensure that concurrent vulnerability/feature versions inser…
Browse files Browse the repository at this point in the history
…tions work fine
  • Loading branch information
Quentin-M authored and jzelinskie committed Feb 24, 2016
1 parent 74fc5b3 commit bd17dfb
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 18 deletions.
146 changes: 146 additions & 0 deletions database/pgsql/complex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package pgsql

import (
"fmt"
"math/rand"
"runtime"
"strconv"
"sync"
"testing"
"time"

"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
"github.com/coreos/clair/utils/types"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
)

const (
numVulnerabilities = 100
numFeatureVersions = 100
)

func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("TestRaceAffects", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()

// Insert the Feature on which we'll work.
feature := database.Feature{
Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"},
Name: "TestRaceAffecturesFeature1",
}
_, err = datastore.insertFeature(feature)
if err != nil {
t.Error(err)
return
}

// Initialize random generator and enforce max procs.
rand.Seed(time.Now().UnixNano())
runtime.GOMAXPROCS(runtime.NumCPU())

// Generate FeatureVersions.
featureVersions := make([]database.FeatureVersion, numFeatureVersions)
for i := 0; i < numFeatureVersions; i++ {
version := rand.Intn(numFeatureVersions)

featureVersions[i] = database.FeatureVersion{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
}
}

// Generate vulnerabilities.
// They are mapped by fixed version, which will make verification really easy afterwards.
vulnerabilities := make(map[int][]database.Vulnerability)
for i := 0; i < numVulnerabilities; i++ {
version := rand.Intn(numFeatureVersions) + 1

// if _, ok := vulnerabilities[version]; !ok {
// vulnerabilities[version] = make([]database.Vulnerability)
// }

vulnerability := database.Vulnerability{
Name: uuid.New(),
Namespace: feature.Namespace,
FixedIn: []database.FeatureVersion{
database.FeatureVersion{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
},
},
Severity: types.Unknown,
}

vulnerabilities[version] = append(vulnerabilities[version], vulnerability)
}

// Insert featureversions and vulnerabilities in parallel.
var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
for _, vulnerabilitiesM := range vulnerabilities {
for _, vulnerability := range vulnerabilitiesM {
err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability})
assert.Nil(t, err)
}
}
fmt.Println("finished to insert vulnerabilities")
}()

go func() {
defer wg.Done()
for i := 0; i < len(featureVersions); i++ {
featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i])
assert.Nil(t, err)
}
fmt.Println("finished to insert featureVersions")
}()

wg.Wait()

// Verify consistency now.
var actualAffectedNames []string
var expectedAffectedNames []string

for _, featureVersion := range featureVersions {
featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String())

// Get actual affects.
rows, err := datastore.Query(getQuery("s_complextest_featureversion_affects"),
featureVersion.ID)
assert.Nil(t, err)
defer rows.Close()

var vulnName string
for rows.Next() {
err = rows.Scan(&vulnName)
if !assert.Nil(t, err) {
continue
}
actualAffectedNames = append(actualAffectedNames, vulnName)
}
if assert.Nil(t, rows.Err()) {
rows.Close()
}

// Get expected affects.
for i := numVulnerabilities; i > featureVersionVersion; i-- {
for _, vulnerability := range vulnerabilities[i] {
expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name)
}
}

assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
}

// TODO(Quentin-M): May be worth having a test for updates as well.
}
5 changes: 3 additions & 2 deletions database/pgsql/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.set_tx_serializable", err)
return 0, handleError("insertFeatureVersion.l_vulnerability_affects_featureversion", err)
}

// Find or create FeatureVersion.
Expand Down Expand Up @@ -162,6 +162,7 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea

// Insert into Vulnerability_Affects_FeatureVersion.
for _, affect := range affects {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), affect.vulnerabilityID,
featureVersion.ID, affect.fixedInID)
if err != nil {
Expand Down
10 changes: 9 additions & 1 deletion database/pgsql/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ var queries map[string]string
func init() {
queries = make(map[string]string)

queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE`
queries["l_vulnerability_affects_featureversion"] = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE`

// keyvalue.go
queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2`
Expand Down Expand Up @@ -180,6 +180,14 @@ func init() {

queries["f_featureversion_by_feature"] = `
SELECT id, version FROM FeatureVersion WHERE feature_id = $1`

// complex_test.go
queries["s_complextest_featureversion_affects"] = `
SELECT v.name
FROM FeatureVersion fv
LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id
JOIN Vulnerability v ON vaf.vulnerability_id = v.id
WHERE featureversion_id = $1`
}

func getQuery(name string) string {
Expand Down
37 changes: 22 additions & 15 deletions database/pgsql/vulnerability.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
if err != nil {
tx.Rollback()
return handleError("insertFeatureVersion.set_tx_serializable", err)
return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err)
}

if existingVulnerability.ID == 0 {
Expand Down Expand Up @@ -315,36 +315,43 @@ func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability
}

func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error {
// Find every FeatureVersions of the Feature we want to affect.
// Find every FeatureVersions of the Feature that the vulnerability affects.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return handleError("f_featureversion_by_feature", err)
}
defer rows.Close()

var featureVersionID int
var featureVersionVersion types.Version
var affecteds []database.FeatureVersion
for rows.Next() {
err := rows.Scan(&featureVersionID, &featureVersionVersion)
var affected database.FeatureVersion

err := rows.Scan(&affected.ID, &affected.Version)
if err != nil {
return handleError("f_featureversion_by_feature.Scan()", err)
}

if featureVersionVersion.Compare(fixedInVersion) < 0 {
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, featureVersionID,
fixedInID)
if err != nil {
return handleError("i_vulnerability_affects_featureversion", err)
}
if affected.Version.Compare(fixedInVersion) < 0 {
// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
// thus, this FeatureVersion is affected by it.
affecteds = append(affecteds, affected)
}
}
if err = rows.Err(); err != nil {
return handleError("f_featureversion_by_feature.Rows()", err)
}
rows.Close()

// Insert into Vulnerability_Affects_FeatureVersion.
for _, affected := range affecteds {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID,
affected.ID, fixedInID)
if err != nil {
return handleError("i_vulnerability_affects_featureversion", err)
}
}

return nil
}

0 comments on commit bd17dfb

Please sign in to comment.