Skip to content

Commit

Permalink
Fix interpretation of filename from model archive URL (#2416)
Browse files Browse the repository at this point in the history
Co-authored-by: Naman Nandan <namannan@amazon.com>
  • Loading branch information
namannandan and Naman Nandan committed Jun 15, 2023
1 parent c2cdcfb commit 679b33d
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
Expand Down Expand Up @@ -55,7 +54,7 @@ public static ModelArchive downloadModel(
throw new ModelNotFoundException("empty url");
}

String marFileName = FilenameUtils.getName(url);
String marFileName = ArchiveUtils.getFilenameFromUrl(url);
File modelLocation = new File(modelStore, marFileName);
try {
ArchiveUtils.downloadArchive(
Expand Down Expand Up @@ -165,7 +164,7 @@ public void validate() throws InvalidModelException {

public static void removeModel(String modelStore, String marURL) {
if (ArchiveUtils.isValidURL(marURL)) {
String marFileName = FilenameUtils.getName(marURL);
String marFileName = ArchiveUtils.getFilenameFromUrl(marURL);
File modelLocation = new File(modelStore, marFileName);
FileUtils.deleteQuietly(modelLocation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileAlreadyExistsException;
Expand All @@ -16,6 +17,7 @@
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.InvalidModelException;
import org.pytorch.serve.archive.s3.HttpUtils;
Expand Down Expand Up @@ -90,6 +92,15 @@ public static boolean isValidURL(String url) {
return VALID_URL_PATTERN.matcher(url).matches();
}

/**
 * Extracts the trailing file name from an archive location.
 *
 * <p>The input is first parsed as a {@link URL}. When parsing succeeds, only the URL's path
 * component is handed to {@link FilenameUtils#getName(String)}, so a query string (for example
 * the one on a presigned S3 URL) does not end up in the name. When the input cannot be parsed as
 * a URL (for example a bare file name or a local file path), the raw string is handed to
 * {@code FilenameUtils.getName} instead.
 *
 * @param url a URL, a file path, or a bare file name identifying the archive
 * @return whatever {@code FilenameUtils.getName} yields for the selected string
 */
public static String getFilenameFromUrl(String url) {
    String nameSource;
    try {
        nameSource = new URL(url).getPath();
    } catch (MalformedURLException ignored) {
        // Not a parseable URL: treat the input itself as the path to take the name from.
        nameSource = url;
    }
    return FilenameUtils.getName(nameSource);
}

public static boolean downloadArchive(
List<String> allowedUrls,
File location,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import java.nio.file.Files;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
Expand Down Expand Up @@ -53,7 +52,7 @@ public static WorkflowArchive downloadWorkflow(
throw new WorkflowNotFoundException("Workflow store has not been configured.");
}

String warFileName = FilenameUtils.getName(url);
String warFileName = ArchiveUtils.getFilenameFromUrl(url);
File workflowLocation = new File(workflowStore, warFileName);

try {
Expand Down Expand Up @@ -144,7 +143,7 @@ public void validate() throws InvalidWorkflowException {

public static void removeWorkflow(String workflowStore, String warURL) {
if (ArchiveUtils.isValidURL(warURL)) {
String warFileName = FilenameUtils.getName(warURL);
String warFileName = ArchiveUtils.getFilenameFromUrl(warURL);
File workflowLocation = new File(workflowStore, warFileName);
FileUtils.deleteQuietly(workflowLocation);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.pytorch.serve.archive.utils;

import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@code ArchiveUtils.getFilenameFromUrl}, covering bare file names, local file paths,
 * plain URLs, presigned S3 URLs carrying a query string, and an input ending in a slash.
 */
public class ArchiveUtilsTest {
    /** File name expected from every input that points at the resnet-18 archive. */
    private static final String MAR_FILENAME = "resnet-18.mar";

    /** A bare file name comes back as-is. */
    @Test
    public void testGetFilenameFromUrlWithFilename() {
        assertFilename("resnet-18.mar", MAR_FILENAME);
    }

    /** An absolute file path is reduced to its last segment. */
    @Test
    public void testGetFilenameFromUrlWithFilepath() {
        assertFilename("/home/ubuntu/model_store/resnet-18.mar", MAR_FILENAME);
    }

    /** An https URL is reduced to the last segment of its path. */
    @Test
    public void testGetFilenameFromUrlWithUrl() {
        assertFilename("https://torchserve.pytorch.org/mar_files/resnet-18.mar", MAR_FILENAME);
    }

    /** The query string of a presigned S3 URL must not leak into the file name. */
    @Test
    public void testGetFilenameFromUrlWithS3PresignedUrl() {
        String presignedUrl =
                "https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar?"
                        + "response-content-disposition=inline&X-Amz-Security-Token=%2Ftoken%2F"
                        + "&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host"
                        + "&X-Amz-Expires=43200&X-Amz-Credential=%2Fcredential%2F"
                        + "&X-Amz-Signature=%2Fsignature%2F";
        assertFilename(presignedUrl, MAR_FILENAME);
    }

    /** An input that ends with a slash has no trailing name, so an empty string is expected. */
    @Test
    public void testGetFilenameFromUrlWithInvalidUrl() {
        assertFilename("resnet-18.mar/", "");
    }

    /** Asserts that {@code ArchiveUtils.getFilenameFromUrl(input)} equals {@code expected}. */
    private static void assertFilename(String input, String expected) {
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(input), expected);
    }
}
3 changes: 2 additions & 1 deletion frontend/archive/testng.xml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
<!DOCTYPE suite SYSTEM "https://testng.org/testng-1.0.dtd" >

<suite name="ModelArchiverSuite" verbose="1" >
<test name="TorchServe">
<classes>
<class name="org.pytorch.serve.archive.CoverageTest"/>
<class name="org.pytorch.serve.archive.model.ModelArchiveTest"/>
<class name="org.pytorch.serve.archive.model.ModelConfigTest"/>
<class name="org.pytorch.serve.archive.utils.ArchiveUtilsTest"/>
<class name="org.pytorch.serve.archive.workflow.WorkFlowArchiveTest"/>
</classes>
</test>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.http.InvalidModelVersionException;
Expand Down Expand Up @@ -183,7 +183,7 @@ public static StatusResponse handleRegister(
s3SseKms);
} catch (FileAlreadyExistsException e) {
throw new InternalServerException(
"Model file already exists " + FilenameUtils.getName(modelUrl), e);
"Model file already exists " + ArchiveUtils.getFilenameFromUrl(modelUrl), e);
} catch (IOException | InterruptedException e) {
throw new InternalServerException("Failed to save model: " + modelUrl, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelConfig;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.WorkerCommands;
Expand Down Expand Up @@ -130,7 +130,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {

JsonObject modelInfo = new JsonObject();
modelInfo.addProperty(DEFAULT_VERSION, isDefaultVersion);
modelInfo.addProperty(MAR_NAME, FilenameUtils.getName(getModelUrl()));
modelInfo.addProperty(MAR_NAME, ArchiveUtils.getFilenameFromUrl(getModelUrl()));
modelInfo.addProperty(MIN_WORKERS, getMinWorkers());
modelInfo.addProperty(MAX_WORKERS, getMaxWorkers());
modelInfo.addProperty(BATCH_SIZE, getBatchSize());
Expand Down

0 comments on commit 679b33d

Please sign in to comment.