Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding layers (based on keras) supporting multiple outputs #65

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package org.tensorflow.framework.layers;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.family.TType;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Base class for neural-network layers. A layer owns weights (via {@link Module}) and transforms
 * a list of input operands into a list of output operands.
 *
 * <p>Lifecycle: {@link #build(List)} must be called (and must set {@code built}) before
 * {@link #apply(List)} may be invoked.
 */
public abstract class Layer<T extends TType> extends Module<T> implements LayerFunction<T> {
  private final boolean trainable;
  private final boolean dynamic;
  private final DataType<T> dtype;

  // Graph-wiring bookkeeping; populated externally when layers are connected.
  public List<Node<T>> inboundNodes;
  public List<Node<T>> outboundNodes;

  // Set to true by build(...) once the layer's weights exist.
  protected boolean built;

  public Layer(Ops tf, String name, boolean trainable, boolean dynamic, DataType<T> dtype) {
    super(tf, name, dtype);
    this.trainable = trainable;
    this.dynamic = dynamic;
    this.dtype = dtype;
  }

  /**
   * Builds this layer (adds layer weights). NOTE: implementations MUST set {@code built} to true.
   *
   * <p>{@code this.built = true}
   *
   * @param inputShapes the shapes of the inputs this layer will be called with
   */
  public abstract void build(List<Shape> inputShapes);

  /**
   * Computes the output shapes this layer produces for the given input shapes.
   *
   * @param inputShapes the shapes of the inputs
   * @return the shapes of the corresponding outputs
   */
  public abstract List<Shape> computeOutputShapes(List<Shape> inputShapes);

  /** Performs the layer's actual computation. Called by {@link #apply(List)}. */
  protected abstract List<Operand<T>> call(List<Operand<T>> inputs);

  /** Varargs convenience overload of {@link #apply(List)}. */
  @SafeVarargs
  public final List<Operand<T>> apply(Operand<T>... inputs) {
    return apply(Arrays.asList(inputs));
  }

  /**
   * Applies this layer to the given inputs, verifying that the produced output shapes match the
   * shapes predicted by {@link #computeOutputShapes(List)}.
   *
   * @param inputs the input operands
   * @return the output operands
   * @throws IllegalStateException if the layer is not built, if a dynamic layer is used in graph
   *     mode, or if an output shape does not match its expected shape
   */
  @Override
  public final List<Operand<T>> apply(List<Operand<T>> inputs) {
    if (!isBuilt()) throw new IllegalStateException("Cannot call a layer until it is built.");

    if (isDynamic() && tf.scope().env().isGraph())
      throw new IllegalStateException("Dynamic layers can only be used in eager mode.");

    List<Shape> expectedOutputShapes = computeOutputShapes(getShapes(inputs));
    List<Operand<T>> outputs = call(inputs);

    // The number of outputs need not equal the number of inputs; guard before indexing.
    if (expectedOutputShapes.size() != outputs.size()) {
      throw new IllegalStateException(
          "Expected " + expectedOutputShapes.size() + " outputs but got " + outputs.size());
    }

    // BUG FIX: compare shapes with equals() — the previous reference comparison (!=) almost
    // always failed/passed spuriously — and iterate over the outputs rather than the inputs.
    for (int i = 0; i < outputs.size(); i++) {
      Shape actual = outputs.get(i).asOutput().shape();
      if (!expectedOutputShapes.get(i).equals(actual)) {
        throw new IllegalStateException(
            "Shape "
                + actual
                + " at output "
                + i
                + " does not match expected shape "
                + expectedOutputShapes.get(i));
      }
    }

    return outputs;
  }

  @Override
  public Iterable<Module<T>> getDirectSubmodules() {
    return Collections::emptyIterator;
  }

  /**
   * Returns a list of all trainable and non-trainable weights (in that order).
   *
   * @return all the weights of this layer (concatenation of getTrainableWeights() and
   *     getNonTrainableWeights())
   */
  public List<Variable<T>> getWeights() {
    List<Variable<T>> weights = getTrainableWeights();
    weights.addAll(getNonTrainableWeights());
    return weights;
  }

  /**
   * List of variables to be included in backpropagation.
   *
   * @return all trainable weights of this layer
   */
  public List<Variable<T>> getTrainableWeights() {
    return getModuleWeights().stream()
        .filter(w -> w.trainable)
        .map(w -> w.variable)
        .collect(Collectors.toList());
  }

  /**
   * List of variables to be excluded from backpropagation.
   *
   * @return all non-trainable weights of this layer
   */
  public List<Variable<T>> getNonTrainableWeights() {
    return getModuleWeights().stream()
        .filter(w -> !w.trainable)
        .map(w -> w.variable)
        .collect(Collectors.toList());
  }

  /** Extracts the static shape of each operand. */
  private <U extends TType> List<Shape> getShapes(List<Operand<U>> operands) {
    return operands.stream().map(op -> op.asOutput().shape()).collect(Collectors.toList());
  }

  public boolean isTrainable() {
    return trainable;
  }

  public boolean isDynamic() {
    return dynamic;
  }

  public boolean isBuilt() {
    return built;
  }

  public DataType<T> getDtype() {
    return dtype;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.tensorflow.framework.layers;

import org.tensorflow.Operand;
import org.tensorflow.types.family.TType;

import java.util.List;
import java.util.function.Function;

/**
 * A function that maps a list of input operands to a list of output operands, as implemented by
 * {@code Layer.apply}. Specializes {@link Function} so implementors and lambdas get a
 * non-generic signature.
 *
 * @param <T> the tensor type flowing through the function
 */
@FunctionalInterface
public interface LayerFunction<T extends TType>
    extends Function<List<Operand<T>>, List<Operand<T>>> {
  /**
   * Applies this function to the given inputs.
   *
   * @param inputs the input operands
   * @return the output operands
   */
  List<Operand<T>> apply(List<Operand<T>> inputs);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.tensorflow.framework.layers;

import org.tensorflow.DataType;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.family.TType;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * Base class for anything that owns weights and may contain nested submodules (e.g. layers).
 * Weight lookups aggregate this module's own weights with those of all transitive submodules.
 */
public abstract class Module<T extends TType> {
  protected final Ops tf;
  private final String name;
  private final List<ModuleVariable<T>> weights;

  /**
   * @param tf the ops builder; scoped to this module's name
   * @param name this module's name
   * @param dtype currently unused here; kept for signature compatibility with subclasses that
   *     forward it (NOTE(review): consider storing or dropping it — TODO confirm intent)
   */
  public Module(Ops tf, String name, DataType<T> dtype) {
    this.tf = tf.withName(name);
    this.name = name;
    this.weights = new LinkedList<>();
  }

  /** Returns only the immediate child modules (no recursion). */
  public abstract Iterable<Module<T>> getDirectSubmodules();

  /**
   * Returns the child modules of this module.
   *
   * @param recurse if true, returns all transitive submodules; otherwise only direct children
   * @return the submodules (this module itself is not included)
   */
  public Iterable<Module<T>> getSubmodules(boolean recurse) {
    if (!recurse) return getDirectSubmodules();

    List<Module<T>> submodules = new ArrayList<>();

    // BUG FIX: include each direct child itself, not only its descendants. Previously leaf
    // children were dropped entirely and the recursion always produced an empty result.
    for (Module<T> module : getDirectSubmodules()) {
      submodules.add(module);
      module.getSubmodules(true).forEach(submodules::add);
    }

    return submodules;
  }

  /**
   * Creates a new variable and registers it as a weight of this module.
   *
   * @param name the weight's name
   * @param trainable whether the weight participates in backpropagation
   * @param shape the variable's shape
   * @param dtype the variable's data type
   * @return the newly created variable
   */
  public Variable<T> addWeight(String name, boolean trainable, Shape shape, DataType<T> dtype) {
    ModuleVariable<T> moduleVariable =
        new ModuleVariable<>(name, tf.variable(shape, dtype), trainable);
    this.weights.add(moduleVariable);

    return moduleVariable.variable;
  }

  /**
   * Returns this module's own weights followed by the weights of all transitive submodules.
   */
  List<ModuleVariable<T>> getModuleWeights() {
    // BUG FIX: getSubmodules(true) does not include this module, so its own weights must be
    // added explicitly — previously a leaf module always reported zero weights.
    List<ModuleVariable<T>> allWeights = new ArrayList<>(weights);
    StreamSupport.stream(getSubmodules(true).spliterator(), false)
        .flatMap(module -> module.weights.stream())
        .forEach(allWeights::add);
    return allWeights;
  }

  public String getName() {
    return name;
  }
}

/** Pairs a variable with its name and trainability flag, for bookkeeping inside {@link Module}. */
class ModuleVariable<T extends TType> {
  String name;
  boolean trainable;
  Variable<T> variable;

  public ModuleVariable(String name, Variable<T> variable, boolean trainable) {
    // BUG FIX: the name parameter was previously accepted but never stored.
    this.name = name;
    this.trainable = trainable;
    this.variable = variable;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.tensorflow.framework.layers;

import org.tensorflow.Operand;
import org.tensorflow.types.family.TType;

import java.util.List;

/**
 * Records how a layer invocation is wired into a graph of layers: which layer produced a set of
 * outputs, and (eventually) where its input tensors came from.
 */
public class Node<T extends TType> {
  /** The Layer that takes input tensors and turns them into output tensors */
  private Layer<T> outboundLayer;

  /** The layers from which input tensors originate */
  private List<Layer<T>> inboundLayers;

  /**
   * A list of integers, the same length as `inboundLayers`. `nodeIndices[i]` is the origin of
   * inputTensors[i]
   */
  private List<Integer> nodeIndices;

  /** The layer this node records an invocation of. */
  private Layer<T> layer;

  /** The outputs produced by that invocation. */
  private List<Operand<T>> outputs;

  /**
   * Creates a node for a single layer invocation.
   *
   * @param layer the invoked layer
   * @param outputs the operands it produced
   */
  public Node(Layer<T> layer, List<Operand<T>> outputs) {
    this.layer = layer;
    this.outputs = outputs;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.tensorflow.framework.layers;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.family.TType;

import java.util.Arrays;
import java.util.List;

/**
 * A layer that chains a fixed list of layers, feeding each layer's outputs into the next.
 */
public class Sequential<T extends TType> extends Layer<T> {
  private final List<Layer<T>> layers;

  @SafeVarargs
  public Sequential(Ops tf, DataType<T> dtype, Layer<T>... layers) {
    // NOTE(review): dynamic=true forbids graph-mode use for the whole chain — confirm intended.
    super(tf, "Sequential", true, true, dtype);
    this.layers = Arrays.asList(layers);
  }

  /**
   * Builds each child layer in order, threading the computed output shapes of one layer into
   * the next as its input shapes.
   */
  @Override
  public void build(List<Shape> inputShapes) {
    List<Shape> shapes = inputShapes;

    for (Layer<T> layer : layers) {
      layer.build(shapes);
      shapes = layer.computeOutputShapes(shapes);
    }

    // BUG FIX: the Layer contract requires build() to set built; it was never set before,
    // so apply() always threw "Cannot call a layer until it is built."
    this.built = true;
  }

  /** Output shapes are obtained by folding the input shapes through every child layer. */
  @Override
  public List<Shape> computeOutputShapes(List<Shape> inputShapes) {
    List<Shape> shapes = inputShapes;
    for (Layer<T> layer : layers) {
      shapes = layer.computeOutputShapes(shapes);
    }

    return shapes;
  }

  @Override
  protected List<Operand<T>> call(List<Operand<T>> inputs) {
    List<Operand<T>> outputs = inputs;
    for (Layer<T> layer : layers) {
      // BUG FIX: chain the previous layer's outputs; the original passed the ORIGINAL inputs
      // to every layer, so only the last layer's transform of the raw inputs was returned.
      outputs = layer.call(outputs);
    }

    return outputs;
  }

  @Override
  @SuppressWarnings("unchecked")
  public Iterable<Module<T>> getDirectSubmodules() {
    // Safe for read-only iteration: every Layer<T> is a Module<T>.
    return (List<Module<T>>) (List<?>) layers;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.tensorflow.framework.layers;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TFloat32;

import java.util.Collections;
import java.util.List;

/**
 * Minimal fully-connected test layer: output = input · kernel + bias.
 */
class Dense extends Layer<TFloat32> {
  private final int units;
  // BUG FIX: kernel/bias were static, so every Dense instance shared (and clobbered) the same
  // weights. They must be per-instance.
  public Variable<TFloat32> kernel;
  public Variable<TFloat32> bias;

  public Dense(Ops tf, int units) {
    // BUG FIX: the original passed an extra argument (1) that no Layer constructor accepts.
    super(tf, "dense", true, false, TFloat32.DTYPE);
    this.units = units;
  }

  /**
   * Creates the kernel (features x units) and bias (units) weights.
   *
   * <p>NOTE(review): assumes the input shape is (batch, features) with a known last dimension,
   * and relies on Shape.size(i)/numDimensions() from tensorflow-tools — TODO confirm.
   */
  @Override
  public void build(List<Shape> inputShapes) {
    Shape inputShape = inputShapes.get(0);
    long inputDim = inputShape.size(inputShape.numDimensions() - 1);
    // BUG FIX: kernel was previously given the whole input shape, which cannot matMul with the
    // input to produce the (batch, units) shape declared by computeOutputShapes; bias likewise.
    kernel = addWeight("KERNEL", true, Shape.of(inputDim, units), TFloat32.DTYPE);
    bias = addWeight("BIAS", true, Shape.of(units), TFloat32.DTYPE);
    this.built = true;
  }

  /** Output shape is the input shape with its last dimension replaced by {@code units}. */
  @Override
  public List<Shape> computeOutputShapes(List<Shape> inputShapes) {
    return Collections.singletonList(inputShapes.get(0).replaceLast(units));
  }

  @Override
  public List<Operand<TFloat32>> call(List<Operand<TFloat32>> inputs) {
    Operand<TFloat32> input = inputs.get(0);
    return Collections.singletonList(tf.math.add(tf.linalg.matMul(input, kernel), bias));
  }

  @Override
  public Iterable<Module<TFloat32>> getDirectSubmodules() {
    return Collections::emptyIterator;
  }
}

// Placeholder for layer tests; the Dense helper above is its test fixture.
public class LayerTest {}
24 changes: 24 additions & 0 deletions tensorflow-tools/src/main/java/org/tensorflow/tools/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,30 @@ public Shape prepend(long firstDimension) {
return Shape.of(newDimensions);
}

public boolean isKnown(int i) {
return dimensionSizes[i] != -1;
}

public void assertKnown(int i) {
if (!isKnown(i)) {
throw new IllegalStateException("Dimension " + i + " in shape needs to be known.");
}
}

  /**
   * Returns a new shape with the first dimension replaced by {@code dim}.
   *
   * @param dim the new size of the first dimension
   * @return a new Shape; this shape is unchanged
   */
  public Shape replaceFirst(long dim) {
    return replace(0, dim);
  }

  /**
   * Returns a new shape with the last dimension replaced by {@code dim}.
   *
   * @param dim the new size of the last dimension
   * @return a new Shape; this shape is unchanged
   */
  public Shape replaceLast(long dim) {
    return replace(dimensionSizes.length - 1, dim);
  }

public Shape replace(int i, long dim) {
Shape newShape = new Shape(Arrays.copyOf(dimensionSizes, dimensionSizes.length));
newShape.dimensionSizes[i] = dim;
return newShape;
}

private static long computeSize(long[] dimensionSizes) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
Expand Down