"""Classify a single image with a pretrained DenseNet-161.

Loads a class-index JSON file and a DenseNet-161 state dict from local
paths, prints the prediction for one test image, then prints the total
elapsed time of the script.
"""
import io
import json
import re
import time

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

start_time = time.perf_counter()  # monotonic clock, suited to measuring intervals

# Local file locations; edit these to point at your own copies.
INDEX_TO_NAME_PATH = '/Users/harsh_bafna/serve/examples/image_classifier/index_to_name.json'
STATE_DICT_PATH = '/Users/harsh_bafna/state_dicts/densenet161-8d451a50.pth'
TEST_IMAGE_PATH = "/Users/harsh_bafna/test_images/kitten.jpg"

# Preprocessing pipeline, built once rather than on every transform_image call:
# resize the shorter side to 255, center-crop to 224x224, convert to a float
# tensor, then normalize each RGB channel with the given mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized batch tensor.

    Args:
        image_bytes: raw encoded image data (any format PIL can open).

    Returns:
        A float tensor of shape (1, 3, 224, 224): one preprocessed RGB image
        with a leading batch dimension.
    """
    image = Image.open(io.BytesIO(image_bytes))
    # Force 3 channels: grayscale, palette or RGBA inputs would otherwise give
    # a tensor whose channel count does not match the 3-value Normalize.
    image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)


with open(INDEX_TO_NAME_PATH, encoding="utf-8") as index_file:
    # Keyed by the stringified class index. Each value is presumably a
    # [wnid, label] pair -- verify against the JSON file.
    imagenet_class_index = json.load(index_file)


def get_prediction(image_bytes):
    """Return the class-index entry for the model's top-1 prediction.

    Uses the module-level ``model`` and ``imagenet_class_index``.

    Args:
        image_bytes: raw encoded image data.

    Returns:
        ``imagenet_class_index[str(top1_index)]``.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: skip autograd bookkeeping. Call the module itself
    # (not .forward) so that any registered hooks run.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)  # argmax over dim 1 for the single batch entry
    predicted_idx = str(y_hat.item())  # JSON object keys are strings
    return imagenet_class_index[predicted_idx]


model = torchvision.models.densenet161()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine.
state_dict = torch.load(STATE_DICT_PATH, map_location="cpu")

# Rename keys like '...denselayer1.norm.1.weight' to '...denselayer1.norm1.weight'
# (drop the dot before the 1/2 index). Presumably the checkpoint uses legacy
# key names; load_state_dict below is strict by default, so names must match.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
for key in list(state_dict.keys()):  # snapshot keys: the dict is mutated below
    res = pattern.match(key)
    if res:
        new_key = res.group(1) + res.group(2)
        state_dict[new_key] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.eval()  # switch dropout / batch-norm layers to inference behavior

with open(TEST_IMAGE_PATH, 'rb') as f:
    image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))

print("--- %s seconds ---" % (time.perf_counter() - start_time))