In [3]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import joblib

# Per-attribute checkpoint files. Each diamond attribute is predicted by its
# own ResNet-50 classifier; each key also selects the matching
# 'label_encoder_<feature>.pkl' file below.
models_paths = {
    'shape': 'diamond_cnn_shape.pth',
    'cut': 'diamond_cnn_cut.pth',
    'color': 'diamond_cnn_color.pth',
    'clarity': 'diamond_cnn_clarity.pth',
    'polish': 'diamond_cnn_polish.pth',
    'symmetry': 'diamond_cnn_symmetry.pth',
    'fluorescence': 'diamond_cnn_fluorescence.pth'
}

# Load every label encoder exactly once. They are used both to size each
# model's output layer (below) and to decode predicted class indices back
# into labels at prediction time.
# NOTE(review): joblib.load and torch.load both unpickle their input, which can
# execute arbitrary code -- only load these files from a trusted source.
label_encoders = {
    feature: joblib.load(f'label_encoder_{feature}.pkl')
    for feature in models_paths
}

# Build one ResNet-50 per attribute, replace its final fully connected layer
# with one whose output size is that attribute's number of classes, load the
# trained weights onto the CPU, and switch the model to evaluation mode.
loaded_models = {}
for feature, path in models_paths.items():
    # weights=None: no pretrained download; the trained state_dict is loaded
    # right after. (Replaces the deprecated pretrained=False argument.)
    model = models.resnet50(weights=None)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, len(label_encoders[feature].classes_))
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    model.eval()
    loaded_models[feature] = model

# Preprocess the image.
# The transform pipeline is built once at import time instead of on every call:
# resize to 224x224, convert to a float tensor, then normalize per channel with
# the ImageNet mean/std values used by torchvision's ResNet models.
# NOTE(review): presumably this must match the normalization used in training.
_DATA_TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def preprocess_image(image_path):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        image_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        A tensor of shape (1, 3, 224, 224): the RGB image resized, converted
        to a tensor and normalized, with a leading batch dimension added.
    """
    # The context manager releases the file handle deterministically, even for
    # multi-frame formats or when decoding raises. convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    tensor = _DATA_TRANSFORMS(rgb_image)
    return tensor.unsqueeze(0)  # Add batch dimension

# Predict function
def predict(image_path):
    """Predict every diamond attribute for a single image.

    Args:
        image_path: Path of the image file; preprocessed by
            ``preprocess_image``.

    Returns:
        dict mapping each attribute name (the keys of ``loaded_models``, in
        the same order) to the label decoded by that attribute's encoder in
        ``label_encoders``.
    """
    image = preprocess_image(image_path)
    predictions = {}
    # Pure inference: disable autograd so the forward passes do not build
    # (and keep alive) a gradient graph for every model.
    with torch.no_grad():
        for feature, model in loaded_models.items():
            outputs = model(image)
            # Index of the highest-scoring class for each batch item
            # (first maximum on ties, same as torch.max(outputs, 1)).
            preds = outputs.argmax(dim=1)
            # Only one image is in the batch, so take the first decoded label.
            label = label_encoders[feature].inverse_transform(preds.cpu().numpy())[0]
            predictions[feature] = label
    return predictions

# Example usage: run every attribute classifier on one sample image and print
# one "attribute: label" line per prediction.
image_path = '111000-7314.jpg'
predictions = predict(image_path)
for feature in predictions:
    prediction = predictions[feature]
    print(f'{feature}: {prediction}')


shape: princess
cut: EX
color: G
clarity: SI1
polish: EX
symmetry: VG
fluorescence: N
