Skip to content

Commit

Permalink
Add support for Collectors.groupingBy; required some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
msridhar committed Mar 26, 2024
1 parent ff1b191 commit d317d0e
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* THE SOFTWARE.
*/

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
Expand Down Expand Up @@ -125,11 +126,26 @@ class StreamNullabilityPropagator extends BaseNoOpHandler {
private final Map<MethodInvocationTree, Tree> observableCallToInnerMethodOrLambda =
new LinkedHashMap<>();

// Maps collect calls in the observable call chain to the relevant inner methods or lambdas.
@AutoValue
abstract static class CollectRecordAndInnerMethod {

static CollectRecordAndInnerMethod create(
CollectLikeMethodRecord collectlikeMethodRecord, Tree innerMethodOrLambda) {
return new AutoValue_StreamNullabilityPropagator_CollectRecordAndInnerMethod(
collectlikeMethodRecord, innerMethodOrLambda);
}

abstract CollectLikeMethodRecord getCollectLikeMethodRecord();

abstract Tree getInnerMethodOrLambda();
}

// Maps collect calls in the observable call chain to the relevant (collect record, inner method
// or lambda) pairs.
// We need a Multimap here since there may be multiple relevant methods / lambdas.
// E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => {l1,l2}
private final Multimap<MethodInvocationTree, Tree> collectCallToInnerMethodsOrLambdas =
LinkedHashMultimap.create();
// E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => (record for toMap, {l1,l2})
private final Multimap<MethodInvocationTree, CollectRecordAndInnerMethod>
collectCallToRecordsAndInnerMethodsOrLambdas = LinkedHashMultimap.create();

// Map from map or collect method (or lambda) to corresponding previous filter method (e.g.
// B.apply => A.filter for the map example above, or l1 => A.filter and l2 => A.filter for the
Expand Down Expand Up @@ -183,7 +199,7 @@ public void onMatchTopLevelClass(
this.filterMethodOrLambdaSet.clear();
this.observableOuterCallInChain.clear();
this.observableCallToInnerMethodOrLambda.clear();
this.collectCallToInnerMethodsOrLambdas.clear();
this.collectCallToRecordsAndInnerMethodsOrLambdas.clear();
this.mapOrCollectRecordToFilterMap.clear();
this.filterToNSMap.clear();
this.bodyToMethodOrLambda.clear();
Expand Down Expand Up @@ -237,10 +253,14 @@ public void onMatchMethodInvocation(
observableCallToInnerMethodOrLambda.put(tree, argTree);
}
} else {
CollectLikeMethodRecord collectlikeMethodRecord =
streamType.getCollectlikeMethodRecord(methodSymbol);
if (collectlikeMethodRecord != null && methodSymbol.getParameters().length() == 1) {
handleCollectCall(tree, collectlikeMethodRecord);
if (methodSymbol.getParameters().length() == 1) {
for (CollectLikeMethodRecord collectlikeMethodRecord :
streamType.getCollectlikeMethodRecords(methodSymbol)) {
boolean handled = handleCollectCall(tree, collectlikeMethodRecord);
if (handled) {
break;
}
}
}
}
}
Expand All @@ -249,13 +269,15 @@ public void onMatchMethodInvocation(

/**
* Handles a call to a collect-like method. If the argument to the method is supported, updates
* the {@link #collectCallToInnerMethodsOrLambdas} map appropriately.
* the {@link #collectCallToRecordsAndInnerMethodsOrLambdas} map appropriately.
*
* @param collectInvocationTree The MethodInvocationTree representing the call to the collect-like
* method.
* @param collectlikeMethodRecord The record representing the collect-like method.
* @return true if the argument to the collect method was a call to the factory method in the
* record, false otherwise.
*/
private void handleCollectCall(
private boolean handleCollectCall(
MethodInvocationTree collectInvocationTree, CollectLikeMethodRecord collectlikeMethodRecord) {
ExpressionTree argTree = collectInvocationTree.getArguments().get(0);
if (argTree instanceof MethodInvocationTree) {
Expand All @@ -282,14 +304,21 @@ private void handleCollectCall(
handleMapOrCollectAnonClassBody(
collectlikeMethodRecord,
anonClassBody,
t -> collectCallToInnerMethodsOrLambdas.put(collectInvocationTree, t));
t ->
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, t)));
}
} else if (factoryMethodArg instanceof LambdaExpressionTree) {
collectCallToInnerMethodsOrLambdas.put(collectInvocationTree, factoryMethodArg);
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, factoryMethodArg));
}
}
return true;
}
}
return false;
}

private void buildObservableCallChain(MethodInvocationTree tree) {
Expand Down Expand Up @@ -326,19 +355,15 @@ private void handleChainFromFilter(
mapOrCollectRecordToFilterMap.put(
observableCallToInnerMethodOrLambda.get(outerCallInChain), record);
}
} else if (collectCallToInnerMethodsOrLambdas.containsKey(outerCallInChain)) {
Symbol.MethodSymbol collectMethod = ASTHelpers.getSymbol(outerCallInChain);
CollectLikeMethodRecord collectlikeMethodRecord =
streamType.getCollectlikeMethodRecord(collectMethod);
if (collectlikeMethodRecord != null) {
// Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas
for (Tree innerMethodOrLambda :
collectCallToInnerMethodsOrLambdas.get(outerCallInChain)) {
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
collectlikeMethodRecord, filterMethodOrLambda);
mapOrCollectRecordToFilterMap.put(innerMethodOrLambda, record);
}
} else if (collectCallToRecordsAndInnerMethodsOrLambdas.containsKey(outerCallInChain)) {
// Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas
for (CollectRecordAndInnerMethod collectRecordAndInnerMethod :
collectCallToRecordsAndInnerMethodsOrLambdas.get(outerCallInChain)) {
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
collectRecordAndInnerMethod.getCollectLikeMethodRecord(), filterMethodOrLambda);
mapOrCollectRecordToFilterMap.put(
collectRecordAndInnerMethod.getInnerMethodOrLambda(), record);
}
}
} while (outerCallInChain != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ public static StreamNullabilityPropagator getJavaStreamNullabilityPropagator() {
ImmutableSet.of(0, 1),
"apply",
ImmutableSet.of(0))
.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"java.util.stream.Collectors",
"<T,K>groupingBy(java.util.function.Function<? super T,? extends K>)",
ImmutableSet.of(0),
"apply",
ImmutableSet.of(0))
// List of methods of java.util.stream.Stream through which we just propagate the
// nullability information of the last call, e.g. m() in
// Observable.filter(...).m().map(...) means the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.predicates.TypePredicate;
import com.google.errorprone.predicates.type.DescendantOf;
Expand All @@ -48,7 +49,7 @@ public class StreamModelBuilder {
private ImmutableSet.Builder<String> filterMethodSimpleNames;
private ImmutableMap.Builder<String, MapLikeMethodRecord> mapMethodSigToRecord;
private ImmutableMap.Builder<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord;
private ImmutableMap.Builder<String, CollectLikeMethodRecord> collectMethodSigToRecord;
private ImmutableMultimap.Builder<String, CollectLikeMethodRecord> collectMethodSigToRecords;
private ImmutableSet.Builder<String> passthroughMethodSigs;
private ImmutableSet.Builder<String> passthroughMethodSimpleNames;

Expand All @@ -75,7 +76,7 @@ private void finalizeOpenStreamTypeRecord() {
filterMethodSimpleNames.build(),
mapMethodSigToRecord.build(),
mapMethodSimpleNameToRecord.build(),
collectMethodSigToRecord.build(),
collectMethodSigToRecords.build(),
passthroughMethodSigs.build(),
passthroughMethodSimpleNames.build()));
}
Expand Down Expand Up @@ -109,7 +110,7 @@ private void initializeBuilders() {
this.filterMethodSimpleNames = ImmutableSet.builder();
this.mapMethodSigToRecord = ImmutableMap.builder();
this.mapMethodSimpleNameToRecord = ImmutableMap.builder();
this.collectMethodSigToRecord = ImmutableMap.builder();
this.collectMethodSigToRecords = ImmutableMultimap.builder();
this.passthroughMethodSigs = ImmutableSet.builder();
this.passthroughMethodSimpleNames = ImmutableSet.builder();
}
Expand Down Expand Up @@ -200,7 +201,7 @@ public StreamModelBuilder withCollectMethodFromSignature(
ImmutableSet<Integer> argsToCollectorFactoryMethod,
String innerMethodName,
ImmutableSet<Integer> argsFromStream) {
this.collectMethodSigToRecord.put(
this.collectMethodSigToRecords.put(
collectMethodSig,
CollectLikeMethodRecord.create(
collectorFactoryMethodClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
*/
import static com.uber.nullaway.NullabilityUtil.castToNonNull;

import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.VisitorState;
import com.google.errorprone.predicates.TypePredicate;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import javax.annotation.Nullable;

/** An immutable model describing a class from a stream-based API such as RxJava. */
public class StreamTypeRecord {
Expand All @@ -47,7 +48,7 @@ public class StreamTypeRecord {
private final ImmutableMap<String, MapLikeMethodRecord> mapMethodSigToRecord;
private final ImmutableMap<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord;

private final ImmutableMap<String, CollectLikeMethodRecord> collectMethodSigToRecord;
private final ImmutableMultimap<String, CollectLikeMethodRecord> collectMethodSigToRecords;

// List of methods of java.util.stream.Stream through which we just propagate the nullability
// information of the last call, e.g. m() in Observable.filter(...).m().map(...) means the
Expand All @@ -64,15 +65,15 @@ public StreamTypeRecord(
ImmutableSet<String> filterMethodSimpleNames,
ImmutableMap<String, MapLikeMethodRecord> mapMethodSigToRecord,
ImmutableMap<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord,
ImmutableMap<String, CollectLikeMethodRecord> collectMethodSigToRecord,
ImmutableMultimap<String, CollectLikeMethodRecord> collectMethodSigToRecords,
ImmutableSet<String> passthroughMethodSigs,
ImmutableSet<String> passthroughMethodSimpleNames) {
this.typePredicate = typePredicate;
this.filterMethodSigs = filterMethodSigs;
this.filterMethodSimpleNames = filterMethodSimpleNames;
this.mapMethodSigToRecord = mapMethodSigToRecord;
this.mapMethodSimpleNameToRecord = mapMethodSimpleNameToRecord;
this.collectMethodSigToRecord = collectMethodSigToRecord;
this.collectMethodSigToRecords = collectMethodSigToRecords;
this.passthroughMethodSigs = passthroughMethodSigs;
this.passthroughMethodSimpleNames = passthroughMethodSimpleNames;
}
Expand Down Expand Up @@ -101,9 +102,9 @@ record =
return record;
}

@Nullable
public CollectLikeMethodRecord getCollectlikeMethodRecord(Symbol.MethodSymbol methodSymbol) {
return collectMethodSigToRecord.get(methodSymbol.toString());
public ImmutableCollection<CollectLikeMethodRecord> getCollectlikeMethodRecords(
Symbol.MethodSymbol methodSymbol) {
return collectMethodSigToRecords.get(methodSymbol.toString());
}

public boolean isPassthroughMethod(Symbol.MethodSymbol methodSymbol) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,37 @@ public void streamSupportCollectorsToMap() {
.doTest();
}

@Test
public void streamSupportCollectorsGroupingBy() {
defaultCompilationHelper
.addSourceLines(
"Test.java",
"package com.uber;",
"import java.util.*;",
"import java.util.stream.*;",
"import java.util.function.Function;",
"import javax.annotation.Nullable;",
"class Test {",
" static class Foo {",
" @Nullable String bar;",
" String baz = \"baz\";",
" }",
" Map<Integer, List<Foo>> testNegative() {",
" List<Foo> foos = new ArrayList<>();",
" return foos.stream()",
" .filter(foo -> foo.bar != null)",
" .collect(Collectors.groupingBy(foo -> foo.bar.length()));",
" }",
" Map<Integer, List<Foo>> testPositive1() {",
" List<Foo> foos = new ArrayList<>();",
" return foos.stream()",
" // BUG: Diagnostic contains: dereferenced expression foo.bar is @Nullable",
" .collect(Collectors.groupingBy(foo -> foo.bar.length()));",
" }",
"}")
.doTest();
}

@Test
public void supportObjectsIsNull() {
defaultCompilationHelper
Expand Down

0 comments on commit d317d0e

Please sign in to comment.