Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SQLite3] Use generic TLP Where oracle #942

Merged
merged 4 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/sqlancer/common/gen/PartitionGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package sqlancer.common.gen;

import sqlancer.common.ast.newast.Expression;
import sqlancer.common.schema.AbstractTableColumn;

public interface PartitionGenerator<E extends Expression<C>, C extends AbstractTableColumn<?, ?>> {

/**
* Negates a predicate (i.e., uses a NOT operator).
*
* @param predicate
* the boolean predicate.
*
* @return the negated predicate.
*/
E negatePredicate(E predicate);

/**
* Checks if an expression evaluates to NULL (i.e., implements the IS NULL operator).
*
* @param expr
* the expression
*
* @return an expression that checks whether the expression evaluates to NULL.
*/
E isNull(E expr);
}
28 changes: 28 additions & 0 deletions src/sqlancer/common/gen/TLPWhereGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package sqlancer.common.gen;

import java.util.List;

import sqlancer.common.ast.newast.Expression;
import sqlancer.common.ast.newast.Join;
import sqlancer.common.ast.newast.Select;
import sqlancer.common.schema.AbstractTable;
import sqlancer.common.schema.AbstractTableColumn;
import sqlancer.common.schema.AbstractTables;

public interface TLPWhereGenerator<S extends Select<J, E, T, C>, J extends Join<E, T, C>, E extends Expression<C>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>>
extends PartitionGenerator<E, C> {

TLPWhereGenerator<S, J, E, T, C> setTablesAndColumns(AbstractTables<T, C> tables);

E generateBooleanExpression();

S generateSelect();

List<J> getRandomJoinClauses();

List<E> getTableRefs();

List<E> generateFetchColumns(boolean shouldCreateDummy);

List<E> generateOrderBys();
}
84 changes: 84 additions & 0 deletions src/sqlancer/common/oracle/TLPWhereOracle.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package sqlancer.common.oracle;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import sqlancer.ComparatorHelper;
import sqlancer.Randomly;
import sqlancer.SQLGlobalState;
import sqlancer.common.ast.newast.Expression;
import sqlancer.common.ast.newast.Join;
import sqlancer.common.ast.newast.Select;
import sqlancer.common.gen.TLPWhereGenerator;
import sqlancer.common.query.ExpectedErrors;
import sqlancer.common.schema.AbstractSchema;
import sqlancer.common.schema.AbstractTable;
import sqlancer.common.schema.AbstractTableColumn;
import sqlancer.common.schema.AbstractTables;

public class TLPWhereOracle<Z extends Select<J, E, T, C>, J extends Join<E, T, C>, E extends Expression<C>, S extends AbstractSchema<?, T>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>, G extends SQLGlobalState<?, S>>
implements TestOracle<G> {

private final G state;

private TLPWhereGenerator<Z, J, E, T, C> gen;
private final ExpectedErrors errors;

private String generatedQueryString;

public TLPWhereOracle(G state, TLPWhereGenerator<Z, J, E, T, C> gen, ExpectedErrors expectedErrors) {
if (state == null || gen == null || expectedErrors == null) {
throw new IllegalArgumentException("Null variables used to initialize test oracle.");
}
this.state = state;
this.gen = gen;
this.errors = expectedErrors;
}

@Override
public void check() throws SQLException {
S s = state.getSchema();
AbstractTables<T, C> targetTables = TestOracleUtils.getRandomTableNonEmptyTables(s);
gen = gen.setTablesAndColumns(targetTables);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but I was a bit confused by this. The set somewhat suggests that we set a value in the existing generator, but actually, we create a new generator. Can we reflect this in the naming of the methods? E.g., newWithTablesandColumns or copyWithTablesAndColumns or so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup i can do that after the other oracles have been created


Select<J, E, T, C> select = gen.generateSelect();

boolean shouldCreateDummy = true;
select.setFetchColumns(gen.generateFetchColumns(shouldCreateDummy));
select.setJoinClauses(gen.getRandomJoinClauses());
select.setFromList(gen.getTableRefs());
select.setWhereClause(null);

String originalQueryString = select.asString();
generatedQueryString = originalQueryString;
List<String> firstResultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors,
state);

boolean orderBy = Randomly.getBooleanWithSmallProbability();
if (orderBy) {
select.setOrderByClauses(gen.generateOrderBys());
}

TestOracleUtils.PredicateVariants<E, C> predicates = TestOracleUtils.initializeTernaryPredicateVariants(gen,
gen.generateBooleanExpression());
select.setWhereClause(predicates.predicate);
String firstQueryString = select.asString();
select.setWhereClause(predicates.negatedPredicate);
String secondQueryString = select.asString();
select.setWhereClause(predicates.isNullPredicate);
String thirdQueryString = select.asString();

List<String> combinedString = new ArrayList<>();
List<String> secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString,
thirdQueryString, combinedString, !orderBy, state, errors);

ComparatorHelper.assumeResultSetsAreEqual(firstResultSet, secondResultSet, originalQueryString, combinedString,
state);
}

@Override
public String getLastQueryString() {
return generatedQueryString;
}
}
33 changes: 33 additions & 0 deletions src/sqlancer/common/oracle/TestOracleUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import sqlancer.IgnoreMeException;
import sqlancer.Randomly;
import sqlancer.common.ast.newast.Expression;
import sqlancer.common.gen.PartitionGenerator;
import sqlancer.common.schema.AbstractSchema;
import sqlancer.common.schema.AbstractTable;
import sqlancer.common.schema.AbstractTableColumn;
Expand All @@ -12,11 +14,42 @@ public final class TestOracleUtils {
private TestOracleUtils() {
}

public static final class PredicateVariants<E extends Expression<C>, C extends AbstractTableColumn<?, ?>> {
public E predicate;
public E negatedPredicate;
public E isNullPredicate;

PredicateVariants(E predicate, E negatedPredicate, E isNullPredicate) {
this.predicate = predicate;
this.negatedPredicate = negatedPredicate;
this.isNullPredicate = isNullPredicate;
}
}

public static <T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>> AbstractTables<T, C> getRandomTableNonEmptyTables(
AbstractSchema<?, T> schema) {
if (schema.getDatabaseTables().isEmpty()) {
throw new IgnoreMeException();
}
return new AbstractTables<>(Randomly.nonEmptySubset(schema.getDatabaseTables()));
}

public static <E extends Expression<C>, T extends AbstractTable<C, ?, ?>, C extends AbstractTableColumn<?, ?>> PredicateVariants<E, C> initializeTernaryPredicateVariants(
PartitionGenerator<E, C> gen, E predicate) {
if (gen == null) {
throw new IllegalStateException();
}
if (predicate == null) {
throw new IllegalStateException();
}
E negatedPredicate = gen.negatePredicate(predicate);
if (negatedPredicate == null) {
throw new IllegalStateException();
}
E isNullPredicate = gen.isNull(predicate);
if (isNullPredicate == null) {
throw new IllegalStateException();
}
return new PredicateVariants<>(predicate, negatedPredicate, isNullPredicate);
}
}
17 changes: 16 additions & 1 deletion src/sqlancer/sqlite3/gen/SQLite3ExpressionGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sqlancer.Randomly;
import sqlancer.common.gen.ExpressionGenerator;
import sqlancer.common.gen.NoRECGenerator;
import sqlancer.common.gen.TLPWhereGenerator;
import sqlancer.common.schema.AbstractTables;
import sqlancer.sqlite3.SQLite3GlobalState;
import sqlancer.sqlite3.ast.SQLite3Aggregate;
Expand Down Expand Up @@ -50,7 +51,8 @@
import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table;

public class SQLite3ExpressionGenerator implements ExpressionGenerator<SQLite3Expression>,
NoRECGenerator<SQLite3Select, Join, SQLite3Expression, SQLite3Table, SQLite3Column> {
NoRECGenerator<SQLite3Select, Join, SQLite3Expression, SQLite3Table, SQLite3Column>,
TLPWhereGenerator<SQLite3Select, Join, SQLite3Expression, SQLite3Table, SQLite3Column> {

private SQLite3RowValue rw;
private final SQLite3GlobalState globalState;
Expand Down Expand Up @@ -133,6 +135,7 @@ public static SQLite3Expression getRandomLiteralValue(SQLite3GlobalState globalS
return new SQLite3ExpressionGenerator(globalState).getRandomLiteralValueInternal(globalState.getRandomly());
}

@Override
public List<SQLite3Expression> generateOrderBys() {
List<SQLite3Expression> expressions = new ArrayList<>();
for (int i = 0; i < Randomly.smallNumber() + 1; i++) {
Expand Down Expand Up @@ -747,6 +750,18 @@ public List<SQLite3Expression> getTableRefs() {
return tableRefs;
}

@Override
public List<SQLite3Expression> generateFetchColumns(boolean shouldCreateDummy) {
List<SQLite3Expression> columns = new ArrayList<>();
if (shouldCreateDummy && Randomly.getBoolean()) {
columns.add(new SQLite3ColumnName(SQLite3Column.createDummy("*"), null));
} else {
columns = Randomly.nonEmptySubset(this.columns).stream().map(c -> new SQLite3ColumnName(c, null))
.collect(Collectors.toList());
}
return columns;
}

@Override
public String generateOptimizedQueryString(SQLite3Select select, SQLite3Expression whereCondition,
boolean shouldUseAggregate) {
Expand Down
49 changes: 18 additions & 31 deletions src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPWhereOracle.java
Original file line number Diff line number Diff line change
@@ -1,50 +1,37 @@
package sqlancer.sqlite3.oracle.tlp;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import sqlancer.ComparatorHelper;
import sqlancer.Randomly;
import sqlancer.common.oracle.TLPWhereOracle;
import sqlancer.common.oracle.TestOracle;
import sqlancer.common.query.ExpectedErrors;
import sqlancer.sqlite3.SQLite3Errors;
import sqlancer.sqlite3.SQLite3GlobalState;
import sqlancer.sqlite3.SQLite3Visitor;
import sqlancer.sqlite3.ast.SQLite3Expression;
import sqlancer.sqlite3.ast.SQLite3Select;
import sqlancer.sqlite3.gen.SQLite3ExpressionGenerator;
import sqlancer.sqlite3.schema.SQLite3Schema;
import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column;
import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table;

public class SQLite3TLPWhereOracle extends SQLite3TLPBase {
public class SQLite3TLPWhereOracle implements TestOracle<SQLite3GlobalState> {

private String generatedQueryString;
private final TLPWhereOracle<SQLite3Select, SQLite3Expression.Join, SQLite3Expression, SQLite3Schema, SQLite3Table, SQLite3Column, SQLite3GlobalState> oracle;

public SQLite3TLPWhereOracle(SQLite3GlobalState state) {
super(state);
SQLite3ExpressionGenerator gen = new SQLite3ExpressionGenerator(state);
ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(SQLite3Errors.getExpectedExpressionErrors())
.build();
this.oracle = new TLPWhereOracle<>(state, gen, expectedErrors);
}

@Override
public void check() throws SQLException {
super.check();
select.setWhereClause(null);
String originalQueryString = SQLite3Visitor.asString(select);
generatedQueryString = originalQueryString;
List<String> resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state);

boolean orderBy = Randomly.getBooleanWithSmallProbability();
if (orderBy) {
select.setOrderByClauses(gen.generateOrderBys());
}
select.setWhereClause(predicate);
String firstQueryString = SQLite3Visitor.asString(select);
select.setWhereClause(negatedPredicate);
String secondQueryString = SQLite3Visitor.asString(select);
select.setWhereClause(isNullPredicate);
String thirdQueryString = SQLite3Visitor.asString(select);
List<String> combinedString = new ArrayList<>();
List<String> secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString,
thirdQueryString, combinedString, !orderBy, state, errors);
ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString,
state);
oracle.check();
}

@Override
public String getLastQueryString() {
return generatedQueryString;
return oracle.getLastQueryString();
}

}
Loading