
Commit 2482979

Authored by Niris Okram (okram999) and Niris Okram
change: Add example to import custom model in SageMaker JumpStart private model hub (#4824)
* Add the sample for importing a custom model to a JumpStart private model hub
* Rename the file and update it to align with PR guidelines
* Rename the folder to align with PR guidelines
* Update the description in the notebook
* Switch to torch.inference_mode() from torch.no_grad() to improve performance

Co-authored-by: Niris Okram <niris@amazon.com>
1 parent 2d519b5 commit 2482979
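
The last bullet in the commit message swaps torch.no_grad() for torch.inference_mode() in the entry-point script. A minimal illustration of the difference, not taken from the commit itself:

import torch

model = torch.nn.Linear(2, 1)
x = torch.randn(4, 2)

# no_grad() disables gradient tracking for the enclosed operations.
with torch.no_grad():
    out_no_grad = model(x)

# inference_mode() additionally skips autograd's view and version-counter
# bookkeeping, which can make inference-only code slightly faster; tensors
# created here cannot later be used in autograd.
with torch.inference_mode():
    out_inference = model(x)

print(out_no_grad.requires_grad, out_inference.requires_grad)  # False False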

File tree

2 files changed: +1813 -0 lines changed

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import os
import json


# Define the ANN model
class ANN(nn.Module):
    def __init__(self):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x


def model_fn(model_dir):
    """Load the PyTorch model from the model_dir."""
    model = ANN()
    model.load_state_dict(torch.load(os.path.join(model_dir, "model.pth")))
    model.eval()
    return model


def input_fn(request_body, content_type):
    """Process the incoming request body."""
    if content_type == "application/json":
        data = json.loads(request_body)
        return torch.tensor(data["features"], dtype=torch.float32)
    else:
        raise ValueError(f"Unsupported content type: {content_type}")


def output_fn(prediction, accept):
    """Format the model's prediction for the response."""
    if accept == "application/json":
        return {"output": prediction.item()}
    else:
        raise ValueError(f"Unsupported accept type: {accept}")


def predict_fn(input_data, model):
    """Perform the prediction using the loaded model."""
    with torch.inference_mode():
        output = model(input_data)
        y_pred = (output > 0.5).float()
        return y_pred
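
A quick local smoke test of the handler chain above. The module name inference is an assumption (the file name for this diff is not shown in this excerpt), and the weights are untrained, so the output is purely illustrative:

import json
import os
import tempfile

import torch

from inference import ANN, model_fn, input_fn, predict_fn, output_fn  # assumed module name

with tempfile.TemporaryDirectory() as model_dir:
    # Write an (untrained) state dict where model_fn expects to find it.
    torch.save(ANN().state_dict(), os.path.join(model_dir, "model.pth"))

    model = model_fn(model_dir)
    features = input_fn(json.dumps({"features": [[0.5, 1.2]]}), "application/json")
    prediction = predict_fn(features, model)
    print(output_fn(prediction, "application/json"))  # e.g. {'output': 0.0} or {'output': 1.0}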

build_and_train_models/sm-jumpstart_private_model_hub_import/sm-jumpstart_private_model_hub_import.ipynb

Lines changed: 1757 additions & 0 deletions
Large diffs are not rendered by default.
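
The notebook diff itself is not rendered here. For context only, this is one common way an entry-point script with model_fn/input_fn/predict_fn/output_fn is wired into a SageMaker endpoint; the S3 URI, IAM role, entry-point file name, and framework versions below are placeholders, and this is not necessarily the flow the notebook uses for the private hub import:

from sagemaker.pytorch import PyTorchModel
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Placeholders: the tarball is assumed to contain model.pth alongside the script.
pytorch_model = PyTorchModel(
    model_data="s3://example-bucket/custom-model/model.tar.gz",   # hypothetical S3 URI
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",   # hypothetical execution role
    entry_point="inference.py",                                   # assumed name of the script above
    framework_version="2.0",
    py_version="py310",
)

predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    serializer=JSONSerializer(),      # matches input_fn's application/json handling
    deserializer=JSONDeserializer(),
)

print(predictor.predict({"features": [[0.5, 1.2]]}))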

0 commit comments
