-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
SceneGraphImageAttribute.java
197 lines (155 loc) · 5.96 KB
/
SceneGraphImageAttribute.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
package edu.stanford.nlp.scenegraph.image;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonArrayBuilder;
import javax.json.JsonObject;
import javax.json.JsonObjectBuilder;
import javax.json.JsonString;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
/**
 * An attribute annotation on a scene-graph image: a (subject, predicate, attribute) triple,
 * optionally tied to a region of the image, with optional token-level glosses for the subject
 * and the attribute.
 *
 * <p>Instances are read with {@link #fromJSONObject} and written with {@link #toJSONObject}.
 * Equality is defined by region identity plus the lemmatized attribute and subject glosses.
 */
public class SceneGraphImageAttribute {

  /** Attribute string; JSON key {@code "attribute"}. */
  public String attribute;
  /** Value of the JSON key {@code "object"}. */
  public String object;
  /** Predicate of the triple; JSON key {@code "predicate"}. Defaults to {@code "is"}. */
  public String predicate = "is";
  /** Region this attribute belongs to, or {@code null} if it is not tied to a region. */
  public SceneGraphImageRegion region;
  /** Image object acting as the subject; serialized as its index in the image's object list. */
  public SceneGraphImageObject subject;
  /**
   * Raw text of the triple. Index 0 is used as the subject text and index 2 as the attribute
   * text when no gloss is present (presumably index 1 is the predicate — TODO confirm).
   */
  public String[] text;
  /** Optional token-level gloss of the attribute; {@code null} when absent. */
  public List<CoreLabel> attributeGloss;
  /** Optional token-level gloss of the subject; {@code null} when absent. */
  public List<CoreLabel> subjectGloss;
  /** The image that owns this attribute. */
  public SceneGraphImage image;

  /**
   * Builds an attribute from its JSON representation.
   *
   * <p>The optional {@code "region"} value is a 1-based index into {@code img.regions}. The
   * {@code "subject"} value is a 0-based index into {@code img.objects}. The
   * {@code "attributeGloss"} and {@code "subjectGloss"} arrays are optional.
   *
   * @param img the image whose regions and objects the indices refer to
   * @param obj the JSON object to read
   * @return the parsed attribute, bound to {@code img}
   */
  @SuppressWarnings("unchecked")
  public static SceneGraphImageAttribute fromJSONObject(SceneGraphImage img, JsonObject obj) {
    SceneGraphImageAttribute attr = new SceneGraphImageAttribute();
    attr.image = img;
    attr.attribute = obj.getString("attribute");
    attr.object = obj.getString("object");
    attr.predicate = obj.getString("predicate");

    // A missing key and an explicit JSON null both mean "no region". obj.get() returns
    // JsonValue.NULL (not Java null) for the latter, and getInt() would then throw.
    if (obj.containsKey("region") && ! obj.isNull("region")) {
      // Regions are serialized 1-based (see toJSONObject).
      int regionId = obj.getInt("region") - 1;
      attr.region = img.regions.get(regionId);
    }

    int subjectId = obj.getInt("subject");
    attr.subject = img.objects.get(subjectId);

    List<String> textList = SceneGraphImageUtils.getJsonStringList(obj, "text");
    attr.text = textList.toArray(new String[textList.size()]);

    if (obj.containsKey("attributeGloss")) {
      attr.attributeGloss = parseGloss(obj, "attributeGloss");
    }
    if (obj.containsKey("subjectGloss")) {
      attr.subjectGloss = parseGloss(obj, "subjectGloss");
    }
    return attr;
  }

  /** Parses the JSON string array under {@code key} into a list of labels. */
  @SuppressWarnings("unchecked")
  private static List<CoreLabel> parseGloss(JsonObject obj, String key) {
    List<String> strs = SceneGraphImageUtils.getJsonStringList(obj, key);
    List<CoreLabel> gloss = Generics.newArrayList(strs.size());
    for (String str : strs) {
      gloss.add(SceneGraphImageUtils.labelFromString(str));
    }
    return gloss;
  }

  /**
   * Serializes this attribute to JSON in the format read by {@link #fromJSONObject}.
   *
   * <p>The region (if any) is written as its 1-based index in {@code img.regions}. The subject
   * is written as its 0-based index in {@code img.objects}. For each gloss that is present,
   * both the serialized labels and the space-joined lemma gloss are written.
   *
   * @param img the image whose region and object lists define the written indices
   * @return the JSON representation of this attribute
   */
  @SuppressWarnings("unchecked")
  public JsonObject toJSONObject(SceneGraphImage img) {
    JsonObjectBuilder obj = Json.createObjectBuilder();
    obj.add("attribute", this.attribute);
    obj.add("object", this.object);
    obj.add("predicate", this.predicate);
    if (this.region != null) {
      obj.add("region", img.regions.indexOf(this.region) + 1);
    }
    obj.add("subject", img.objects.indexOf(this.subject));

    JsonArrayBuilder text = Json.createArrayBuilder();
    for (String word : this.text) {
      text.add(word);
    }
    obj.add("text", text.build());

    if (this.attributeGloss != null) {
      obj.add("attributeGloss", glossToJson(this.attributeGloss));
      obj.add("attributeLemmaGloss", attributeLemmaGloss());
    }
    if (this.subjectGloss != null) {
      obj.add("subjectGloss", glossToJson(this.subjectGloss));
      obj.add("subjectLemmaGloss", subjectLemmaGloss());
    }
    return obj.build();
  }

  /** Converts a gloss to a JSON array of serialized labels. */
  private static JsonArray glossToJson(List<CoreLabel> gloss) {
    JsonArrayBuilder arr = Json.createArrayBuilder();
    for (CoreLabel lbl : gloss) {
      arr.add(SceneGraphImageUtils.labelToString(lbl));
    }
    return arr.build();
  }

  /**
   * Returns a copy of this attribute.
   *
   * <p>The text array and glosses are copied; each label is copied via the {@code CoreLabel}
   * copy constructor. The region, subject and image references are shared with the original.
   */
  @Override
  public SceneGraphImageAttribute clone() {
    SceneGraphImageAttribute attr = new SceneGraphImageAttribute();
    attr.attribute = this.attribute;
    attr.object = this.object;
    attr.predicate = this.predicate;
    attr.region = this.region;
    attr.subject = this.subject;
    attr.image = this.image;
    // A default-constructed instance has no text; copy null instead of throwing.
    attr.text = this.text == null ? null : Arrays.copyOf(this.text, this.text.length);
    attr.subjectGloss = copyGloss(this.subjectGloss);
    attr.attributeGloss = copyGloss(this.attributeGloss);
    return attr;
  }

  /** Copies each label of {@code gloss} into a new list, or returns {@code null} for null input. */
  private static List<CoreLabel> copyGloss(List<CoreLabel> gloss) {
    if (gloss == null) {
      return null;
    }
    List<CoreLabel> copy = Generics.newArrayList(gloss.size());
    for (CoreLabel lbl : gloss) {
      copy.add(new CoreLabel(lbl));
    }
    return copy;
  }

  /**
   * Returns the subject gloss words joined by single spaces. Returns {@code text[0]} when
   * there is no subject gloss.
   */
  public String subjectGloss() {
    if (this.subjectGloss == null) {
      return this.text[0];
    }
    return StringUtils.join(this.subjectGloss.stream().map(CoreLabel::word), " ");
  }

  /**
   * Returns the subject gloss lemmas joined by single spaces, using the word where a lemma
   * is missing. Returns {@code text[0]} when there is no subject gloss.
   */
  public String subjectLemmaGloss() {
    if (this.subjectGloss == null) {
      return this.text[0];
    }
    return joinLemmas(this.subjectGloss);
  }

  /**
   * Returns the attribute gloss words joined by single spaces. Returns {@code text[2]} when
   * there is no attribute gloss.
   */
  public String attributeGloss() {
    if (this.attributeGloss == null) {
      return this.text[2];
    }
    return StringUtils.join(this.attributeGloss.stream().map(CoreLabel::word), " ");
  }

  /**
   * Returns the attribute gloss lemmas joined by single spaces, using the word where a lemma
   * is missing. Returns {@code text[2]} when there is no attribute gloss.
   */
  public String attributeLemmaGloss() {
    if (this.attributeGloss == null) {
      return this.text[2];
    }
    return joinLemmas(this.attributeGloss);
  }

  /** Joins each label's lemma (or its word, if the lemma is null) with single spaces. */
  private static String joinLemmas(List<CoreLabel> gloss) {
    return StringUtils.join(gloss.stream().map(x -> x.lemma() == null ? x.word() : x.lemma()), " ");
  }

  /**
   * Prints {@code text[0]}, the literal word "is", and {@code text[2]}, tab-separated, then a
   * line separator.
   *
   * <p>NOTE(review): this always prints "is" rather than {@link #predicate}; confirm intended.
   *
   * @param out the stream to print to
   */
  public void print(PrintStream out) {
    out.printf("%s\tis\t%s%n", this.text[0], this.text[2]);
  }

  /**
   * Two attributes are equal when they refer to the same region instance and have equal
   * lemmatized attribute and subject glosses. Region comparison uses reference equality, so
   * two null regions also match.
   */
  @Override
  public boolean equals(Object otherObj) {
    if (this == otherObj) {
      return true;
    }
    // instanceof is false for null, so this also rejects null.
    if ( ! (otherObj instanceof SceneGraphImageAttribute)) {
      return false;
    }
    SceneGraphImageAttribute other = (SceneGraphImageAttribute) otherObj;
    return other.region == this.region
        && other.attributeLemmaGloss().equals(attributeLemmaGloss())
        && other.subjectLemmaGloss().equals(subjectLemmaGloss());
  }

  /**
   * Hashes the same components that {@link #equals} compares.
   *
   * <p>The region contributes its index in {@code image.regions}. This agrees with the
   * identity comparison in {@code equals} provided equal attributes share the same image.
   */
  @Override
  public int hashCode() {
    int[] arr = {this.image.regions.indexOf(this.region),
        attributeLemmaGloss().hashCode(), subjectLemmaGloss().hashCode()};
    // Hash the array contents. arr.hashCode() would be the array's identity hash, so equal
    // attributes would hash differently and break HashSet/HashMap lookups.
    return Arrays.hashCode(arr);
  }
}