Skip to content

Commit

Permalink
Propagate more nullability info to lambdas known to be invoked synchr…
Browse files Browse the repository at this point in the history
…onously (#952)

Fixes #941

We propagate full nullability info from the enclosing context to
callbacks passed to `Map.forEach`, `Iterable.forEach`, `List.removeIf`,
and all methods on `java.util.stream.Stream`
  • Loading branch information
msridhar committed Apr 30, 2024
1 parent 719b167 commit edcfe49
Show file tree
Hide file tree
Showing 12 changed files with 350 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ default ImmutableList<StreamTypeRecord> customStreamNullabilitySpecs() {
* </code>
* </ul>
*/
final class MethodRef {
public final class MethodRef {

public final String enclosingClass;

Expand Down
3 changes: 2 additions & 1 deletion nullaway/src/main/java/com/uber/nullaway/NullAway.java
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ private void updateEnvironmentMapping(TreePath treePath, VisitorState state) {
// 2. we keep info on all locals rather than just effectively final ones for simplicity
EnclosingEnvironmentNullness.instance(state.context)
.addEnvironmentMapping(
treePath.getLeaf(), analysis.getNullnessInfoBeforeNewContext(treePath, state, handler));
treePath.getLeaf(),
analysis.getNullnessInfoBeforeNestedMethodNode(treePath, state, handler));
}

private Symbol.MethodSymbol getSymbolOfSuperConstructor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,23 @@ public Set<Element> getNonnullStaticFieldsBefore(TreePath path, Context context)
}

/**
* Get nullness info for local variables (and final fields) before some node
* Get nullness info for local variables (and final fields) before some node represented a nested
* method (lambda or anonymous class)
*
* @param path tree path to some AST node within a method / lambda / initializer
* @param pathToNestedMethodNode tree path to some AST node representing a nested method
* @param state visitor state
* @return nullness info for local variables just before the node
* @param handler handler instance
* @return nullness info for local variables just before the leaf of the tree path
*/
public NullnessStore getNullnessInfoBeforeNewContext(
TreePath path, VisitorState state, Handler handler) {
NullnessStore store = dataFlow.resultBefore(path, state.context, nullnessPropagation);
public NullnessStore getNullnessInfoBeforeNestedMethodNode(
TreePath pathToNestedMethodNode, VisitorState state, Handler handler) {
NullnessStore store =
dataFlow.resultBefore(pathToNestedMethodNode, state.context, nullnessPropagation);
if (store == null) {
return NullnessStore.empty();
}
Predicate<AccessPath> handlerPredicate =
handler.getAccessPathPredicateForNestedMethod(pathToNestedMethodNode, state);
return store.filterAccessPaths(
(ap) -> {
boolean allAPNonRootElementsAreFinalFields = true;
Expand All @@ -243,7 +248,7 @@ public NullnessStore getNullnessInfoBeforeNewContext(
&& e.getModifiers().contains(Modifier.FINAL));
}

return handler.includeApInfoInSavedContext(ap, state);
return handlerPredicate.test(ap);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.uber.nullaway.handlers;

import com.google.errorprone.VisitorState;
import com.sun.source.util.TreePath;
import com.uber.nullaway.dataflow.AccessPath;
import java.util.function.Predicate;

/**
* {@link java.util.function.Predicate}s over {@link com.uber.nullaway.dataflow.AccessPath}s useful
* in defining handlers.
*/
public class AccessPathPredicates {

/**
* An AccessPath predicate that always returns false. Used to optimize {@link
* CompositeHandler#getAccessPathPredicateForNestedMethod(TreePath, VisitorState)}
*/
static final Predicate<AccessPath> FALSE_AP_PREDICATE = ap -> false;

/**
* An AccessPath predicate that always returns true. Used to optimize {@link
* CompositeHandler#getAccessPathPredicateForNestedMethod(TreePath, VisitorState)}
*/
static final Predicate<AccessPath> TRUE_AP_PREDICATE = ap -> true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.util.Context;
Expand All @@ -44,6 +45,7 @@
import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST;
import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode;
Expand Down Expand Up @@ -199,8 +201,9 @@ public Optional<ErrorMessage> onExpressionDereference(
}

@Override
public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) {
return false;
public Predicate<AccessPath> getAccessPathPredicateForNestedMethod(
TreePath path, VisitorState state) {
return AccessPathPredicates.FALSE_AP_PREDICATE;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

package com.uber.nullaway.handlers;

import static com.uber.nullaway.handlers.AccessPathPredicates.FALSE_AP_PREDICATE;
import static com.uber.nullaway.handlers.AccessPathPredicates.TRUE_AP_PREDICATE;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.VisitorState;
Expand All @@ -32,6 +35,7 @@
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.util.Context;
Expand All @@ -45,6 +49,7 @@
import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST;
import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode;
Expand Down Expand Up @@ -253,12 +258,24 @@ public Optional<ErrorMessage> onExpressionDereference(
}

@Override
public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) {
boolean shouldFilter = false;
public Predicate<AccessPath> getAccessPathPredicateForNestedMethod(
TreePath path, VisitorState state) {
Predicate<AccessPath> filter = FALSE_AP_PREDICATE;
for (Handler h : handlers) {
shouldFilter |= h.includeApInfoInSavedContext(accessPath, state);
Predicate<AccessPath> curFilter = h.getAccessPathPredicateForNestedMethod(path, state);
// here we do some optimization, to try to avoid unnecessarily returning a deeply nested
// Predicate object (which would be more costly to test)
if (curFilter != FALSE_AP_PREDICATE) {
if (curFilter == TRUE_AP_PREDICATE) {
return curFilter;
} else if (filter == FALSE_AP_PREDICATE) {
filter = curFilter;
} else {
filter = filter.or(curFilter);
}
}
}
return shouldFilter;
return filter;
}

@Override
Expand Down
15 changes: 9 additions & 6 deletions nullaway/src/main/java/com/uber/nullaway/handlers/Handler.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.util.Context;
Expand All @@ -45,6 +46,7 @@
import com.uber.nullaway.dataflow.cfg.NullAwayCFGBuilder;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.checkerframework.nullaway.dataflow.cfg.UnderlyingAST;
import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode;
Expand Down Expand Up @@ -327,15 +329,16 @@ Optional<ErrorMessage> onExpressionDereference(
ExpressionTree expr, ExpressionTree baseExpr, VisitorState state);

/**
* Called when the store access paths are filtered for local variable information before an
* expression.
* Called when determining which access path nullability information should be preserved when
* analyzing a nested method, i.e., a lambda expression or a method in an anonymous or local
* class.
*
* @param accessPath The access path that needs to be checked if filtered.
* @param path The tree path to the node for the nested method.
* @param state The current visitor state.
* @return true if the nullability information for this accesspath should be treated as part of
* the surrounding context when processing a lambda expression or anonymous class declaration.
* @return A predicate that determines which access paths should be preserved when analyzing the
* nested method.
*/
boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state);
Predicate<AccessPath> getAccessPathPredicateForNestedMethod(TreePath path, VisitorState state);

/**
* Called during dataflow analysis initialization to register structurally immutable types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public static Handler buildDefault(Config config) {
handlerListBuilder.add(new GrpcHandler());
handlerListBuilder.add(new RequiresNonNullHandler());
handlerListBuilder.add(new EnsuresNonNullHandler());
handlerListBuilder.add(new SynchronousCallbackHandler());
if (config.serializationIsActive() && config.getSerializationConfig().fieldInitInfoEnabled) {
handlerListBuilder.add(
new FieldInitializationSerializationHandler(config.getSerializationConfig()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import javax.lang.model.element.AnnotationMirror;
import javax.lang.model.element.Element;
Expand Down Expand Up @@ -164,17 +165,18 @@ private boolean isOptionalContentNullable(
}

@Override
public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState state) {

if (accessPath.getElements().size() == 1) {
final Element e = accessPath.getRoot();
if (e != null) {
return e.getKind().equals(ElementKind.LOCAL_VARIABLE)
&& accessPath.getElements().get(0).getJavaElement()
instanceof OptionalContentVariableElement;
public Predicate<AccessPath> getAccessPathPredicateForNestedMethod(
TreePath path, VisitorState state) {
return ap -> {
if (ap.getElements().size() == 1) {
final Element e = ap.getRoot();
if (e != null) {
return e.getKind().equals(ElementKind.LOCAL_VARIABLE)
&& ap.getElements().get(0).getJavaElement() instanceof OptionalContentVariableElement;
}
}
}
return false;
return false;
};
}

private void handleTestAssertions(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.uber.nullaway.handlers;

import static com.uber.nullaway.handlers.AccessPathPredicates.FALSE_AP_PREDICATE;
import static com.uber.nullaway.handlers.AccessPathPredicates.TRUE_AP_PREDICATE;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.VisitorState;
import com.google.errorprone.suppliers.Supplier;
import com.google.errorprone.suppliers.Suppliers;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.uber.nullaway.LibraryModels.MethodRef;
import com.uber.nullaway.dataflow.AccessPath;
import java.util.function.Predicate;

public class SynchronousCallbackHandler extends BaseNoOpHandler {

/**
* Maps method name to full information about the corresponding methods and what parameter is the
* relevant callback. We key on method name to quickly eliminate most cases when doing a lookup.
*/
private static final ImmutableMap<String, ImmutableMap<MethodRef, Integer>>
METHOD_NAME_TO_SIG_AND_PARAM_INDEX =
ImmutableMap.of(
"forEach",
ImmutableMap.of(
MethodRef.methodRef(
"java.util.Map",
"forEach(java.util.function.BiConsumer<? super K,? super V>)"),
0,
MethodRef.methodRef(
"java.lang.Iterable", "forEach(java.util.function.Consumer<? super T>)"),
0),
"removeIf",
ImmutableMap.of(
MethodRef.methodRef(
"java.util.Collection", "removeIf(java.util.function.Predicate<? super E>)"),
0));

private static final Supplier<Type> STREAM_TYPE_SUPPLIER =
Suppliers.typeFromString("java.util.stream.Stream");

@Override
public Predicate<AccessPath> getAccessPathPredicateForNestedMethod(
TreePath path, VisitorState state) {
Tree leafNode = path.getLeaf();
Preconditions.checkArgument(
leafNode instanceof ClassTree || leafNode instanceof LambdaExpressionTree,
"Unexpected leaf type: %s",
leafNode.getClass());
Tree parentNode = path.getParentPath().getLeaf();
if (parentNode instanceof MethodInvocationTree) {
MethodInvocationTree methodInvocationTree = (MethodInvocationTree) parentNode;
Symbol.MethodSymbol symbol = ASTHelpers.getSymbol(methodInvocationTree);
if (symbol == null) {
return FALSE_AP_PREDICATE;
}
Type ownerType = symbol.owner.type;
if (ASTHelpers.isSameType(ownerType, STREAM_TYPE_SUPPLIER.get(state), state)) {
// preserve access paths for all callbacks passed to stream methods
return TRUE_AP_PREDICATE;
}
String invokedMethodName = symbol.getSimpleName().toString();
if (METHOD_NAME_TO_SIG_AND_PARAM_INDEX.containsKey(invokedMethodName)) {
ImmutableMap<MethodRef, Integer> entriesForMethodName =
METHOD_NAME_TO_SIG_AND_PARAM_INDEX.get(invokedMethodName);
for (MethodRef methodRef : entriesForMethodName.keySet()) {
if (symbol.toString().equals(methodRef.fullMethodSig)
&& ASTHelpers.isSubtype(
ownerType, state.getTypeFromString(methodRef.enclosingClass), state)) {
int parameterIndex = -1;
for (int i = 0; i < methodInvocationTree.getArguments().size(); i++) {
if (methodInvocationTree.getArguments().get(i) == leafNode) {
parameterIndex = i;
break;
}
}
if (parameterIndex == entriesForMethodName.get(methodRef)) {
return TRUE_AP_PREDICATE;
}
}
}
}
}
return FALSE_AP_PREDICATE;
}
}

0 comments on commit edcfe49

Please sign in to comment.