Skip to content

Commit

Permalink
Refactor PyProvider helper utilities
Browse files Browse the repository at this point in the history
PyProvider manages how Java code creates and accesses the fields of the legacy "py" struct provider (which we'll migrate away from soonish). This CL refactors default value logic into PyProvider, and adds a builder so it's easier to construct. It also slightly refactors some shared-library logic in PyCommon. More cleanup (hopefully) to come.

Work toward bazelbuild#7010, yak shaving for bazelbuild#6583 and bazelbuild#1446.

RELNOTES: None
PiperOrigin-RevId: 229401739
  • Loading branch information
brandjon authored and weixiao-huang committed Jan 31, 2019
1 parent 6f03e7f commit 72472b6
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import com.google.devtools.build.lib.rules.cpp.CppFileTypes;
import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
import com.google.devtools.build.lib.syntax.EvalException;
import com.google.devtools.build.lib.syntax.SkylarkType;
import com.google.devtools.build.lib.syntax.Type;
import com.google.devtools.build.lib.util.FileType;
import com.google.devtools.build.lib.util.OS;
Expand Down Expand Up @@ -79,8 +78,22 @@ public void collectMetadataArtifacts(Iterable<Artifact> artifacts,

private final NestedSet<Artifact> transitivePythonSources;

private final PythonVersion sourcesVersion;
private final boolean usesSharedLibraries;

/**
* The Python major version for which this target is being built, as per the {@code
* python_version} attribute or the configuration.
*
* <p>This is always either {@code PY2} or {@code PY3}.
*/
private final PythonVersion version;

/**
* The level of compatibility with Python major versions, as per the {@code srcs_version}
* attribute.
*/
private final PythonVersion sourcesVersion;

private Map<PathFragment, Artifact> convertedFiles;

private NestedSet<Artifact> filesToBuild = null;
Expand All @@ -90,6 +103,7 @@ public PyCommon(RuleContext ruleContext) {
this.sourcesVersion = getSrcsVersionAttr(ruleContext);
this.version = ruleContext.getFragment(PythonConfiguration.class).getPythonVersion();
this.transitivePythonSources = collectTransitivePythonSources();
this.usesSharedLibraries = checkForSharedLibraries();
checkSourceIsCompatible(this.version, this.sourcesVersion, ruleContext.getLabel());
validatePythonVersionAttr(ruleContext);
}
Expand Down Expand Up @@ -155,7 +169,11 @@ public void addCommonTransitiveInfoProviders(
/* reportedToActualSources= */ NestedSetBuilder.create(Order.STABLE_ORDER)))
.addSkylarkTransitiveInfo(
PyProvider.PROVIDER_NAME,
PyProvider.create(this.transitivePythonSources, usesSharedLibraries(), imports))
PyProvider.builder()
.setTransitiveSources(transitivePythonSources)
.setUsesSharedLibraries(usesSharedLibraries)
.setImports(imports)
.build())
// Python targets are not really compilable. The best we can do is make sure that all
// generated source files are ready.
.addOutputGroup(OutputGroupInfo.FILES_TO_COMPILE, transitivePythonSources)
Expand Down Expand Up @@ -454,40 +472,50 @@ public NestedSet<Artifact> getFilesToBuild() {
return filesToBuild;
}

public boolean usesSharedLibraries() {
try {
return checkForSharedLibraries(Iterables.concat(
/**
* Returns true if any of this target's {@code deps} or {@code data} deps has a shared library
* file (e.g. a {@code .so}) in its transitive dependency closure.
*
* <p>For targets with the py provider, we consult the {@code uses_shared_libraries} field. For
* targets without this provider, we look for {@link CppFileTypes#SHARED_LIBRARY}-type files in
* the filesToBuild.
*/
private boolean checkForSharedLibraries() {
Iterable<? extends TransitiveInfoCollection> targets;
// For all rule types that use PyCommon, deps is a required attribute but not data.
if (ruleContext.attributes().has("data")) {
targets =
Iterables.concat(
ruleContext.getPrerequisites("deps", Mode.TARGET),
ruleContext.getPrerequisites("data", Mode.DONT_CHECK)));
ruleContext.getPrerequisites("data", Mode.DONT_CHECK));
} else {
targets = ruleContext.getPrerequisites("deps", Mode.TARGET);
}
try {
for (TransitiveInfoCollection target : targets) {
if (checkForSharedLibraries(target)) {
return true;
}
}
return false;
} catch (EvalException e) {
ruleContext.ruleError(e.getMessage());
return false;
}
}


/**
* Returns true if this target has an .so file in its transitive dependency closure.
*/
public static boolean checkForSharedLibraries(Iterable<TransitiveInfoCollection> deps)
throws EvalException{
for (TransitiveInfoCollection dep : deps) {
Object providerObject = dep.get(PyProvider.PROVIDER_NAME);
if (providerObject != null) {
SkylarkType.checkType(providerObject, StructImpl.class, null);
StructImpl provider = (StructImpl) providerObject;
Boolean isUsingSharedLibrary =
provider.getValue(PyProvider.USES_SHARED_LIBRARIES, Boolean.class);
if (Boolean.TRUE.equals(isUsingSharedLibrary)) {
return true;
}
} else if (FileType.contains(
dep.getProvider(FileProvider.class).getFilesToBuild(), CppFileTypes.SHARED_LIBRARY)) {
return true;
}
private static boolean checkForSharedLibraries(TransitiveInfoCollection target)
throws EvalException {
if (PyProvider.hasProvider(target)) {
return PyProvider.getUsesSharedLibraries(PyProvider.getProvider(target));
} else {
NestedSet<Artifact> files = target.getProvider(FileProvider.class).getFilesToBuild();
return FileType.contains(files, CppFileTypes.SHARED_LIBRARY);
}
}

return false;
public boolean usesSharedLibraries() {
return usesSharedLibraries;
}

private static String buildMultipleMainMatchesErrorText(boolean explicit, String proposedMainName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

package com.google.devtools.build.lib.rules.python;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
import com.google.devtools.build.lib.packages.StructImpl;
import com.google.devtools.build.lib.packages.StructProvider;
import com.google.devtools.build.lib.syntax.EvalException;
Expand All @@ -28,7 +30,7 @@
/**
* Static helper class for managing the "py" struct provider returned and consumed by Python rules.
*/
// TODO(brandjon): Replace this with a real provider.
// TODO(#7010): Replace this with a real provider.
public class PyProvider {

// Disable construction.
Expand All @@ -38,8 +40,8 @@ private PyProvider() {}
public static final String PROVIDER_NAME = "py";

/**
* Name of field holding a depset of transitive sources (i.e., .py files in srcs and in srcs of
* transitive deps).
* Name of field holding a postorder depset of transitive sources (i.e., .py files in srcs and in
* srcs of transitive deps).
*/
public static final String TRANSITIVE_SOURCES = "transitive_sources";

Expand All @@ -52,22 +54,22 @@ private PyProvider() {}
* Name of field holding a depset of import paths added by the transitive deps (including this
* target).
*/
// TODO(brandjon): Make this a pre-order depset, since higher-level targets should get precedence
// on PYTHONPATH.
// TODO(brandjon): Add assertions that this depset and transitive_sources have an order compatible
// with the one expected by the rules.
public static final String IMPORTS = "imports";

/** Constructs a provider instance with the given field values. */
public static StructImpl create(
NestedSet<Artifact> transitiveSources,
boolean usesSharedLibraries,
NestedSet<String> imports) {
return StructProvider.STRUCT.create(
ImmutableMap.of(
PyProvider.TRANSITIVE_SOURCES,
SkylarkNestedSet.of(Artifact.class, transitiveSources),
PyProvider.USES_SHARED_LIBRARIES,
usesSharedLibraries,
PyProvider.IMPORTS,
SkylarkNestedSet.of(String.class, imports)),
"No such attribute '%s'");
private static final ImmutableMap<String, Object> DEFAULTS;

static {
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
// TRANSITIVE_SOURCES is mandatory
builder.put(USES_SHARED_LIBRARIES, false);
builder.put(
IMPORTS,
SkylarkNestedSet.of(String.class, NestedSetBuilder.<String>compileOrder().build()));
DEFAULTS = builder.build();
}

/** Returns whether a given dependency has the py provider. */
Expand All @@ -92,8 +94,11 @@ public static StructImpl getProvider(TransitiveInfoCollection dep) throws EvalEx
private static Object getValue(StructImpl info, String fieldName) throws EvalException {
Object fieldValue = info.getValue(fieldName);
if (fieldValue == null) {
throw new EvalException(
/*location=*/ null, String.format("'py' provider missing '%s' field", fieldName));
fieldValue = DEFAULTS.get(fieldName);
if (fieldValue == null) {
throw new EvalException(
/*location=*/ null, String.format("'py' provider missing '%s' field", fieldName));
}
}
return fieldValue;
}
Expand All @@ -119,9 +124,9 @@ public static NestedSet<Artifact> getTransitiveSources(StructImpl info) throws E
}

/**
* Casts and returns the uses-shared-libraries field.
* Casts and returns the uses-shared-libraries field (or its default value).
*
* @throws EvalException if the field does not exist or is not a boolean
* @throws EvalException if the field exists and is not a boolean
*/
public static boolean getUsesSharedLibraries(StructImpl info) throws EvalException {
Object fieldValue = getValue(info, USES_SHARED_LIBRARIES);
Expand All @@ -136,9 +141,9 @@ public static boolean getUsesSharedLibraries(StructImpl info) throws EvalExcepti
}

/**
* Casts and returns the imports field.
* Casts and returns the imports field (or its default value).
*
* @throws EvalException if the field does not exist or is not a depset of strings
* @throws EvalException if the field exists and is not a depset of strings
*/
public static NestedSet<String> getImports(StructImpl info) throws EvalException {
Object fieldValue = getValue(info, IMPORTS);
Expand All @@ -154,4 +159,47 @@ public static NestedSet<String> getImports(StructImpl info) throws EvalException
EvalUtils.getDataTypeNameFromClass(fieldValue.getClass()));
return castValue.getSet(String.class);
}

public static Builder builder() {
return new Builder();
}

/** Builder for a py provider struct. */
public static class Builder {
SkylarkNestedSet transitiveSources = null;
Boolean usesSharedLibraries = null;
SkylarkNestedSet imports = null;

// Use the static builder() method instead.
private Builder() {}

public Builder setTransitiveSources(NestedSet<Artifact> transitiveSources) {
this.transitiveSources = SkylarkNestedSet.of(Artifact.class, transitiveSources);
return this;
}

public Builder setUsesSharedLibraries(boolean usesSharedLibraries) {
this.usesSharedLibraries = usesSharedLibraries;
return this;
}

public Builder setImports(NestedSet<String> imports) {
this.imports = SkylarkNestedSet.of(String.class, imports);
return this;
}

private static void put(
ImmutableMap.Builder<String, Object> fields, String fieldName, Object value) {
fields.put(fieldName, value != null ? value : DEFAULTS.get(fieldName));
}

public StructImpl build() {
ImmutableMap.Builder<String, Object> fields = ImmutableMap.builder();
Preconditions.checkNotNull(transitiveSources, "setTransitiveSources is required");
put(fields, TRANSITIVE_SOURCES, transitiveSources);
put(fields, USES_SHARED_LIBRARIES, usesSharedLibraries);
put(fields, IMPORTS, imports);
return StructProvider.STRUCT.create(fields.build(), "No such attribute '%s'");
}
}
}
Loading

0 comments on commit 72472b6

Please sign in to comment.