In [10]:
import torch
from torchvision import transforms, models
from PIL import Image

# Define the model architecture (same as used during training).
# init_weights=False skips torchvision's random weight initialisation (and the
# FutureWarning it emits when init_weights is left unset): every parameter is
# overwritten by the saved state_dict that is loaded into this model later.
model = models.googlenet(weights=None, aux_logits=False, init_weights=False)
# Swap the default classification head for a single-output linear layer.
model.fc = torch.nn.Linear(model.fc.in_features, 1)

# Load the saved model weights
def load_model(file_path='banana_freshness_model.pth', *, net=None):
    """Restore trained weights from ``file_path`` and switch the model to eval mode.

    Args:
        file_path: Path to a ``state_dict`` previously written with ``torch.save``.
        net: Module to load the weights into. Defaults to the module-level
            ``model`` of this notebook, which preserves the original behaviour.

    Returns:
        The same module instance, with the weights loaded and ``eval()`` applied.
    """
    target = model if net is None else net
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict then copies each tensor into the target's
    # existing parameters, on whatever device those already live on.
    # weights_only=True restricts unpickling to tensors and plain containers
    # and silences torch.load's FutureWarning about the default changing.
    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
    target.load_state_dict(state_dict)
    target.eval()  # inference behaviour for layers such as dropout / batch-norm
    print(f"Model loaded from {file_path}")
    return target

# Choose the compute device, place the network on it, then restore the trained weights.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)  # nn.Module.to() moves parameters in place and returns the same module
model = load_model('banana_freshness_model.pth')

# Inference function to predict freshness index
def predict(model, image):
    """Run ``model`` on one image and return its single output value.

    Args:
        model: Module whose forward pass yields exactly one element for one
            input (``Tensor.item()`` is applied to it). It is switched to eval
            mode, and that mode persists on the caller's module.
        image: An already-preprocessed tensor, or anything else (normally a
            ``PIL.Image.Image``), which is first passed through the
            module-level ``data_transforms``. A batch dimension is prepended
            unless the tensor is already 4-D, so a batch of one is accepted too.

    Returns:
        The model output as a Python float.
    """
    model.eval()
    with torch.no_grad():
        if not isinstance(image, torch.Tensor):
            image = data_transforms(image)
        if image.dim() != 4:
            image = image.unsqueeze(0)  # add the batch dimension
        # Follow the model's own device rather than a global, so a model passed
        # in on another device still works; parameter-free models run on CPU.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device("cpu")
        output = model(image.to(target_device))
        return output.item()

# Preprocessing applied to every image before inference.
# NOTE(review): presumably this must match the training-time pipeline — confirm.
# These are the standard ImageNet per-channel mean/std used by torchvision models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed square size (aspect ratio not preserved)
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Score one sample image with the loaded model.
image_path = 'data/day6/IMG20240920033429.jpg'
image = data_transforms(Image.open(image_path).convert('RGB'))  # preprocessed tensor

# `freshness_index` holds the raw scalar model output; the printed value rescales it.
# NOTE(review): the divisor 7 looks like the number of labelled days (images live
# under data/dayN) — confirm; the rescaled value is not clamped to [0, 1].
freshness_index = predict(model, image)
print(f"Predicted Freshness Index: {1 - freshness_index / 7:.2f}")

Model loaded from banana_freshness_model.pth
Predicted Freshness Index: 0.17


  model.load_state_dict(torch.load(file_path))
