This repository has been archived by the owner on May 3, 2022. It is now read-only.
/
AbstractOryxResource.java
182 lines (159 loc) · 6.51 KB
/
AbstractOryxResource.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
/*
* Copyright (c) 2014, Cloudera and Intel, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/
package com.cloudera.oryx.app.serving;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;

import com.google.common.base.Preconditions;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.FileCleanerCleanup;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.api.TopicProducer;
import com.cloudera.oryx.api.serving.OryxResource;
import com.cloudera.oryx.api.serving.OryxServingException;
import com.cloudera.oryx.api.serving.ServingModel;
/**
 * Superclass of all Serving Layer application endpoints.
 * <p>
 * Provides access to the current {@link ServingModel} (gated on how much of it has loaded),
 * sending of input messages, multipart request parsing, and small helpers for validating
 * requests and reading (possibly compressed) uploaded content.
 */
public abstract class AbstractOryxResource extends OryxResource {

  private static final Logger log = LoggerFactory.getLogger(AbstractOryxResource.class);

  /**
   * Lazily created factory used only by the Commons FileUpload fallback in
   * {@link #parseMultipart(HttpServletRequest)}; shared by all instances of this class.
   */
  private static final AtomicReference<DiskFileItemFactory> sharedFileItemFactory =
      new AtomicReference<>();

  @Context
  private ServletContext servletContext;

  /**
   * Set once this instance has observed a model whose loaded fraction is at or above
   * {@code oryx.serving.min-model-load-fraction}. It is never reset afterwards.
   */
  private boolean hasLoadedEnough;

  /**
   * Sends {@code message} through the producer returned by {@code getInputProducer()}.
   * The record key is the hex form of {@code message.hashCode()}.
   *
   * @param message message to send
   */
  protected final void sendInput(String message) {
    // NOTE(review): assumes the input producer is a String/String producer; the cast is unchecked
    @SuppressWarnings("unchecked")
    TopicProducer<String,String> inputProducer = (TopicProducer<String,String>) getInputProducer();
    inputProducer.send(Integer.toHexString(message.hashCode()), message);
  }

  /**
   * @return whether the serving model manager reports that this Serving Layer is read-only
   */
  protected final boolean isReadOnly() {
    return getServingModelManager().isReadOnly();
  }

  /**
   * Returns the current serving model, but only once enough of it has loaded.
   * <p>
   * Until this instance has seen a model whose {@link ServingModel#getFractionLoaded()} is at
   * least {@code oryx.serving.min-model-load-fraction}, each call re-reads that setting, logs
   * the current fraction and re-checks it. Once the threshold has been reached, later calls
   * return the model without re-checking.
   *
   * @return the current model, never {@code null}
   * @throws OryxServingException with status {@code SERVICE_UNAVAILABLE} if there is no model
   *  yet, or not enough of it has loaded
   * @throws IllegalArgumentException if the configured minimum fraction is outside [0, 1]
   * @throws NullPointerException if the threshold was reached earlier but the model manager
   *  now returns no model
   */
  protected final ServingModel getServingModel() throws OryxServingException {
    ServingModel servingModel = getServingModelManager().getModel();
    if (!hasLoadedEnough && servingModel != null) {
      double minModelLoadFraction = getServingModelManager().getConfig()
          .getDouble("oryx.serving.min-model-load-fraction");
      Preconditions.checkArgument(minModelLoadFraction >= 0.0 && minModelLoadFraction <= 1.0,
                                  "oryx.serving.min-model-load-fraction must be in [0,1]: %s",
                                  minModelLoadFraction);
      float fractionLoaded = servingModel.getFractionLoaded();
      log.info("Model loaded fraction: {}", fractionLoaded);
      if (fractionLoaded >= minModelLoadFraction) {
        hasLoadedEnough = true;
      }
    }
    if (!hasLoadedEnough) {
      throw new OryxServingException(Response.Status.SERVICE_UNAVAILABLE);
    }
    return Objects.requireNonNull(servingModel);
  }

  /**
   * Parses a multipart request into its parts.
   * <p>
   * Uses the container's {@link HttpServletRequest#getParts()} when supported. If the
   * container throws {@link UnsupportedOperationException}, it falls back to Commons FileUpload.
   *
   * @param request incoming request
   * @return the non-empty collection of parts
   * @throws OryxServingException with status {@code BAD_REQUEST} if parsing fails or the
   *  request has no parts
   */
  protected final Collection<Part> parseMultipart(HttpServletRequest request) throws OryxServingException {
    Collection<Part> parts;
    try {
      try {
        // Prefer container's standard JavaEE multipart parsing:
        parts = request.getParts();
      } catch (UnsupportedOperationException uoe) {
        // Grizzly (used in tests) doesn't support this; fall back until it does
        parts = parseMultipartWithCommonsFileUpload(request);
      }
    } catch (IOException | ServletException e) {
      throw new OryxServingException(Response.Status.BAD_REQUEST, e.getMessage());
    }
    check(!parts.isEmpty(), "No parts");
    return parts;
  }

  /**
   * Fallback multipart parser built on Commons FileUpload. Each item becomes a
   * {@code FileItemPart}.
   *
   * @param request incoming request
   * @return parsed parts, possibly empty
   * @throws IOException if Commons FileUpload fails to parse the request; the original
   *  {@link FileUploadException} is kept as the cause
   */
  private Collection<Part> parseMultipartWithCommonsFileUpload(HttpServletRequest request) throws IOException {
    if (sharedFileItemFactory.get() == null) {
      // Not a big deal if two threads actually set this up; only the first one is kept
      DiskFileItemFactory fileItemFactory = new DiskFileItemFactory(
          1 << 16, (File) servletContext.getAttribute("javax.servlet.context.tempdir"));
      fileItemFactory.setFileCleaningTracker(
          FileCleanerCleanup.getFileCleaningTracker(servletContext));
      sharedFileItemFactory.compareAndSet(null, fileItemFactory);
    }
    try {
      return new ServletFileUpload(sharedFileItemFactory.get()).parseRequest(request)
          .stream().map(FileItemPart::new).collect(Collectors.toList());
    } catch (FileUploadException e) {
      // Keep the cause so the underlying parse failure remains diagnosable
      throw new IOException(e.getMessage(), e);
    }
  }

  /**
   * Throws an {@link OryxServingException} with the given status and message unless
   * {@code condition} holds.
   *
   * @param condition condition that must be true
   * @param status status to report on failure
   * @param errorMessage message to report on failure
   * @throws OryxServingException if {@code condition} is false
   */
  protected static void check(boolean condition,
                              Response.Status status,
                              String errorMessage) throws OryxServingException {
    if (!condition) {
      throw new OryxServingException(status, errorMessage);
    }
  }

  /**
   * Like {@link #check(boolean, Response.Status, String)} with status {@code BAD_REQUEST}.
   *
   * @param condition condition that must be true
   * @param errorMessage message to report on failure
   * @throws OryxServingException if {@code condition} is false
   */
  protected static void check(boolean condition,
                              String errorMessage) throws OryxServingException {
    check(condition, Response.Status.BAD_REQUEST, errorMessage);
  }

  /**
   * Like {@link #check(boolean, Response.Status, String)} with status {@code NOT_FOUND}.
   *
   * @param condition condition that must be true, typically "the entity exists"
   * @param entity used as the error message on failure
   * @throws OryxServingException if {@code condition} is false
   */
  protected static void checkExists(boolean condition,
                                    String entity) throws OryxServingException {
    check(condition, Response.Status.NOT_FOUND, entity);
  }

  /**
   * @throws OryxServingException with status {@code FORBIDDEN} if the Serving Layer is
   *  read-only
   */
  protected void checkNotReadOnly() throws OryxServingException {
    check(!isReadOnly(), Response.Status.FORBIDDEN, "Serving Layer is read-only");
  }

  /**
   * @param in byte stream, decoded as UTF-8
   * @return a {@link BufferedReader} over {@code in}
   */
  protected static BufferedReader maybeBuffer(InputStream in) {
    return maybeBuffer(new InputStreamReader(in, StandardCharsets.UTF_8));
  }

  /**
   * @param reader reader to wrap
   * @return {@code reader} itself if it is already a {@link BufferedReader}, otherwise a new
   *  {@link BufferedReader} wrapping it
   */
  protected static BufferedReader maybeBuffer(Reader reader) {
    return reader instanceof BufferedReader ? (BufferedReader) reader : new BufferedReader(reader);
  }

  /**
   * Opens a part's content, transparently decompressing it based on its content type.
   * <p>
   * The media type is compared case-insensitively and without any parameters, such as
   * {@code ; charset=...}:
   * <ul>
   *   <li>{@code application/zip}, {@code application/x-zip-compressed}: the content of the
   *   first non-directory entry in the archive (empty if there is none)</li>
   *   <li>{@code application/gzip}, {@code application/x-gzip}: the gunzipped content</li>
   *   <li>anything else, or no content type: the raw part content</li>
   * </ul>
   *
   * @param item uploaded part
   * @return stream over the (possibly decompressed) content; the caller must close it
   * @throws IOException if the part can't be read or the compressed header is invalid; in
   *  that case the part's stream has already been closed
   */
  protected static InputStream maybeDecompress(Part item) throws IOException {
    InputStream in = item.getInputStream();
    String contentType = item.getContentType();
    if (contentType == null) {
      return in;
    }
    try {
      switch (normalizeMediaType(contentType)) {
        case "application/zip":
        case "application/x-zip-compressed": // sent by some browsers for .zip files
          return openFirstZipEntry(in);
        case "application/gzip":
        case "application/x-gzip":
          return new GZIPInputStream(in);
        default:
          return in;
      }
    } catch (IOException | RuntimeException e) {
      // Don't leak the part's stream if the decompressor couldn't be set up
      try {
        in.close();
      } catch (IOException suppressed) {
        e.addSuppressed(suppressed);
      }
      throw e;
    }
  }

  /**
   * Strips media-type parameters and lower-cases the rest, e.g.
   * {@code "Application/GZIP; x=y"} becomes {@code "application/gzip"}.
   */
  private static String normalizeMediaType(String contentType) {
    int semicolon = contentType.indexOf(';');
    String mediaType = semicolon >= 0 ? contentType.substring(0, semicolon) : contentType;
    return mediaType.trim().toLowerCase(Locale.ROOT);
  }

  /**
   * Wraps {@code in} as a zip stream positioned at its first non-directory entry.
   * A {@link ZipInputStream} returns no data until {@code getNextEntry()} has been called,
   * so the stream must be positioned before it is returned.
   */
  private static InputStream openFirstZipEntry(InputStream in) throws IOException {
    ZipInputStream zipIn = new ZipInputStream(in);
    ZipEntry entry;
    do {
      entry = zipIn.getNextEntry();
    } while (entry != null && entry.isDirectory());
    return zipIn;
  }

}