# In [6]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io

# Define the same model architecture used during training
class SimpleModel(nn.Module):
    """Single linear layer over flattened images (same architecture as training).

    The parameters are deliberately named ``weight`` and ``bias`` so the
    ``state_dict`` saved during training (``model.pth``) loads unchanged.

    Args:
        in_features: length of one flattened input row. Defaults to
            ``224 * 224 * 3``, the size produced by flattening a 3x224x224 image.
        num_classes: number of output logits. Defaults to 2.
    """

    def __init__(self, in_features=224 * 224 * 3, num_classes=2):
        super().__init__()
        # torch.Tensor(...)/torch.empty(...) leave memory uninitialised; an
        # un-loaded model would otherwise emit garbage (possibly NaN/inf).
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        self.bias = nn.Parameter(torch.empty(num_classes))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``weight`` and ``bias`` exactly like ``nn.Linear`` does."""
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        fan_in = self.weight.shape[1]
        bound = 1.0 / fan_in ** 0.5 if fan_in > 0 else 0.0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` is ``(batch, in_features)``; inputs with more dimensions
        (e.g. ``(batch, C, H, W)``) are flattened after the batch dimension.
        """
        if x.dim() > 2:
            x = x.flatten(1)
        # bias + x @ weight.T, i.e. a fused linear layer.
        return torch.addmm(self.bias, x, self.weight.t())

# Function to load the trained model
def model_fn(model_dir):
    """Build ``SimpleModel`` and load trained weights from *model_dir*.

    Args:
        model_dir: directory containing ``model.pth`` (presumably the
            ``state_dict`` saved at the end of training — confirm).

    Returns:
        The model with weights mapped to CPU, switched to evaluation mode.
    """
    import os  # local import keeps this block self-contained

    model = SimpleModel()
    model_path = os.path.join(model_dir, 'model.pth')
    # torch.load unpickles its input; weights_only=True restricts it to
    # tensors and plain containers, so a tampered model.pth cannot run code.
    # NOTE(review): the weights_only argument needs torch >= 1.13.
    state_dict = torch.load(
        model_path, map_location=torch.device('cpu'), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Function to preprocess the input image
def input_fn(request_body, content_type='image/jpeg'):
    """Decode an encoded image request body into a flattened, normalised tensor.

    Args:
        request_body: raw encoded image bytes.
        content_type: MIME type of *request_body*. ``'image/jpeg'``,
            ``'image/png'`` and ``'application/x-image'`` are accepted; PIL
            detects the actual format from the bytes themselves.

    Returns:
        A float tensor of shape ``(1, 224 * 224 * 3)``, resized to 224x224
        and normalised with the ImageNet mean/std.

    Raises:
        ValueError: if *content_type* is not supported.
    """
    if content_type not in ('image/jpeg', 'image/png', 'application/x-image'):
        raise ValueError(f'Unsupported content type: {content_type}')

    # Grayscale, CMYK or RGBA images do not have exactly 3 channels; without
    # this conversion the 3-channel Normalize or the flatten below fails.
    img = Image.open(io.BytesIO(request_body)).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_t = preprocess(img)
    # Flatten to a single row so it matches the linear layer's input size.
    return img_t.view(-1, 224 * 224 * 3)

# Function to perform inference
def predict_fn(input_data, model):
    """Run *model* on *input_data* and return the highest-scoring class index.

    The forward pass runs with autograd disabled. Because the result is
    extracted with ``.item()``, the output must contain exactly one row.

    Returns:
        The index of the maximum along dimension 1, as a Python int.
    """
    with torch.no_grad():
        logits = model(input_data)
        best = logits.max(dim=1).indices
    return best.item()

# Function to handle the prediction request locally
def predict_local(image_path, model_dir):
    """Classify one image file locally through the serving functions.

    Loads the model from *model_dir*, reads *image_path* as raw bytes, runs
    them through ``input_fn`` and ``predict_fn``, prints the predicted class
    and returns it.
    """
    model = model_fn(model_dir)

    with open(image_path, 'rb') as image_file:
        payload = image_file.read()

    result = predict_fn(input_fn(payload), model)
    print(f"Predicted class: {result}")
    return result

if __name__ == "__main__":
    # Local smoke test of the inference path; adjust both paths to your setup.
    model_dir = '/home/ec2-user/SageMaker'  # directory that holds model.pth
    image_path = '/home/ec2-user/SageMaker/2023_0512_121225_001.JPG'  # image to classify
    predict_local(image_path=image_path, model_dir=model_dir)


# Output: Predicted class: 1
