Skip to content

Commit

Permalink
Use AutoValue for Category
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 321069159
  • Loading branch information
flamearrow authored and tflite-support-robot committed Jul 17, 2020
1 parent 25e46f7 commit 556594f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 36 deletions.
59 changes: 59 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
workspace(name = "org_tensorflow_lite_support")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external")

http_archive(
name = "io_bazel_rules_closure",
Expand Down Expand Up @@ -170,6 +171,64 @@ http_archive(
build_file = "//third_party:libyuv.BUILD",
)

java_import_external(
name = "com_google_auto_value",
jar_sha256 = "fd811b92bb59ae8a4cf7eb9dedd208300f4ea2b6275d726e4df52d8334aaae9d",
jar_urls = [
"https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar",
"https://repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar",
],
licenses = ["notice"], # Apache 2.0
generated_rule_name = "processor",
exports = ["@com_google_auto_value_annotations"],
extra_build_file_content = "\n".join([
"java_plugin(",
" name = \"AutoAnnotationProcessor\",",
" output_licenses = [\"unencumbered\"],",
" processor_class = \"com.google.auto.value.processor.AutoAnnotationProcessor\",",
" tags = [\"annotation=com.google.auto.value.AutoAnnotation;genclass=${package}.AutoAnnotation_${outerclasses}${classname}_${methodname}\"],",
" deps = [\":processor\"],",
")",
"",
"java_plugin(",
" name = \"AutoOneOfProcessor\",",
" output_licenses = [\"unencumbered\"],",
" processor_class = \"com.google.auto.value.processor.AutoOneOfProcessor\",",
" tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoOneOf_${outerclasses}${classname}\"],",
" deps = [\":processor\"],",
")",
"",
"java_plugin(",
" name = \"AutoValueProcessor\",",
" output_licenses = [\"unencumbered\"],",
" processor_class = \"com.google.auto.value.processor.AutoValueProcessor\",",
" tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoValue_${outerclasses}${classname}\"],",
" deps = [\":processor\"],",
")",
"",
"java_library(",
" name = \"com_google_auto_value\",",
" exported_plugins = [",
" \":AutoAnnotationProcessor\",",
" \":AutoOneOfProcessor\",",
" \":AutoValueProcessor\",",
" ],",
" exports = [\"@com_google_auto_value_annotations\"],",
")",
]),
)

java_import_external(
name = "com_google_auto_value_annotations",
jar_sha256 = "d095936c432f2afc671beaab67433e7cef50bba4a861b77b9c46561b801fae69",
jar_urls = [
"https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar",
"https://repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar",
],
licenses = ["notice"], # Apache 2.0
neverlink = True,
default_visibility = ["@com_google_auto_value//:__pkg__"],
)

load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

Expand Down
1 change: 1 addition & 0 deletions tensorflow_lite_support/java/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ android_library(
javacopts = JAVACOPTS,
manifest = "AndroidManifest.xml",
deps = [
"@com_google_auto_value",
"@org_checkerframework_qual",
"@org_tensorflow//tensorflow/lite/java:tensorflowlite",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,23 @@

package org.tensorflow.lite.support.label;

import java.util.Objects;
import com.google.auto.value.AutoValue;
import org.tensorflow.lite.annotations.UsedByReflection;

/**
* Category is a util class, contains a label and a float value. Typically it's used as result of
* classification tasks.
*/
public final class Category {
private final String label;
private final float score;

/** Constructs a Category. */
public Category(String label, float score) {
this.label = label;
this.score = score;
}

/** Gets the reference of category's label. */
public String getLabel() {
return label;
}
@AutoValue
@UsedByReflection("ClassifierJNI")
public abstract class Category {

/** Gets the score of the category. */
public float getScore() {
return score;
@UsedByReflection("ClassifierJNI")
public static Category create(String label, double score) {
return new AutoValue_Category(label, score);
}

@Override
public boolean equals(Object o) {
if (o instanceof Category) {
Category other = (Category) o;
return (other.getLabel().equals(this.label) && other.getScore() == this.score);
}
return false;
}
public abstract String getLabel();

@Override
public int hashCode() {
return Objects.hash(label, score);
}

@Override
public String toString() {
return "<Category \"" + label + "\" (score=" + score + ")>";
}
public abstract double getScore();
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public List<Category> getCategoryList() {
List<Category> result = new ArrayList<>();
int i = 0;
for (String label : labels) {
result.add(new Category(label, data[i]));
result.add(Category.create(label, data[i]));
i += 1;
}
return result;
Expand Down

0 comments on commit 556594f

Please sign in to comment.