In [16]:
import os
import sys
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np

In [17]:
# Ensure the project root is in the system path
project_root = os.path.abspath(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

In [18]:
# Paths to data and models
train_csv_path = os.path.join(project_root, "Extracted_SignLanguageMNIST", "sign_mnist_train.csv")
test_csv_path = os.path.join(project_root, "Extracted_SignLanguageMNIST", "sign_mnist_test.csv")

# Create label mapping (0 -> A, 1 -> B, ..., 25 -> Z)
label_mapping = {i: chr(65 + i) for i in range(26)}

In [19]:
# Load train and test datasets
train_data = pd.read_csv(train_csv_path)
test_data = pd.read_csv(test_csv_path)

# Define the number of classes
num_classes = 26

# Load models (caching to avoid reloading)
loaded_models = {}

In [20]:
def load_model(model_type, num_classes, model_path=None, device='cpu'):
    model_type = model_type.lower()
    if model_type == "resnet18":
        model = models.resnet18(weights=None)
    elif model_type == "resnet50":
        model = models.resnet50(weights=None)
    elif model_type == "custom":
        model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(64 * 56 * 56, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    else:
        raise ValueError("Invalid model type. Choose 'resnet18', 'resnet50', or 'custom'.")

    if "resnet" in model_type:
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, num_classes)

    if model_path:
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Model weights loaded from {model_path}")
    else:
        print(f"Using pretrained weights for {model_type}")

    model.eval()
    model.to(device)
    return model

model_paths = {
    "resnet18": os.path.join(project_root, "saved_models", "trained_resnet18.pth"),
    "resnet50": os.path.join(project_root, "saved_models", "trained_resnet50.pth"),
    "custom": os.path.join(project_root, "saved_models", "trained_custom.pth"),
}

In [21]:
model_paths = {
    "resnet18": os.path.join(project_root, "saved_models", "trained_resnet18.pth"),
    "resnet50": os.path.join(project_root, "saved_models", "trained_resnet50.pth"),
    "custom": os.path.join(project_root, "saved_models", "trained_custom.pth"),
}

In [22]:
def fetch_sample_image(sample_index):
    data = train_data
    labels = data.iloc[:, 0].values
    images = data.iloc[:, 1:].values

    if sample_index >= len(images):
        raise IndexError(f"Sample index {sample_index} is out of range.")

    label = labels[sample_index]
    image_data = images[sample_index].reshape(28, 28).astype("uint8")
    image = Image.fromarray(image_data).convert("RGB")
    return image, label

In [23]:
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
    ])
    return transform(image).unsqueeze(0)

In [24]:
def predict_with_model(model_choice, image):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_choice_lower = model_choice.lower()

    # Preprocess the image
    input_tensor = preprocess_image(image).to(device)

    # Load the model from cache or load and cache it
    if model_choice_lower in loaded_models:
        model = loaded_models[model_choice_lower]
    else:
        model_path = model_paths.get(model_choice_lower)
        if not model_path:
            raise ValueError(f"Invalid model choice: {model_choice}")
        model = load_model(model_choice_lower, num_classes=num_classes, model_path=model_path, device=device)
        loaded_models[model_choice_lower] = model  # Cache the loaded model

    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]
        predicted_class = np.argmax(probabilities)
        confidence_scores = probabilities

    return predicted_class, confidence_scores

In [25]:
# Gradio Interface with improved layout
def create_dashboard():
    with gr.Blocks() as demo:
        gr.Markdown("# Sign Language Recognition Dashboard")
        gr.Markdown("Use this app to explore different models for recognizing sign language letters.")
        
        with gr.Tabs():
            with gr.TabItem("Predict"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Model Selection")
                        model_choice = gr.Radio(
                            ["ResNet18", "ResNet50", "Custom"],
                            label="Model",
                            value="ResNet18"
                        )
                        gr.Markdown("### Select Sample Image")
                        sample_index = gr.Slider(
                            0, len(train_data) - 1,
                            step=1,
                            label="Sample Index",
                            value=0
                        )
                        gr.Markdown("### Or Upload Your Own Image")
                        upload_image = gr.Image(type="pil", label="Upload Image (Optional)")
                    with gr.Column(scale=1):
                        image_display = gr.Image(
                            label="Input Image",
                            type="pil",
                            interactive=False,
                            width=224,
                            height=224
                        )
                        prediction_text = gr.Textbox(
                            label="Prediction",
                            interactive=False
                        )
                        actual_label_text = gr.Textbox(
                            label="Actual Label",
                            interactive=False
                        )
                gr.Markdown("### Prediction Confidence")
                confidence_plot = gr.Plot(
                    label="Confidence Scores"
                )
            with gr.TabItem("About"):
                gr.Markdown("""
                ### About This Application

                This Sign Language Recognition Dashboard is developed to demonstrate the capabilities of different neural network architectures in recognizing American Sign Language (ASL) letters. 
                
                This dashboard was delveloped for a cummulative project for Fanshawe College (London, Onatrio), Deep Learning with Pytorch
                
                The application allows users to:

                - **Explore different models**: Compare the performance of ResNet18, ResNet50, and a custom CNN model.
                - **Visualize predictions**: View the predicted letter and confidence scores for each class.
                - **Upload custom images**: Test the models with your own images of ASL letters that have not been included in the training set!


                #### Dataset

                The models are trained on the [Sign Language MNIST](https://www.kaggle.com/datamunge/sign-language-mnist) dataset, which contains images of ASL letters represented in a format similar to the original MNIST dataset.

                #### Acknowledgments

                - **Dataset**: Thanks to [Kaggle](https://www.kaggle.com/) and the contributors for providing the Sign Language MNIST dataset.
                - **Libraries Used**: PyTorch, Torchvision, Gradio, Plotly, NumPy, and Pandas.

                """)
        
        # Add GitHub link at the bottom
        gr.Markdown("""
        ---
        Developed by [Paige Berrigan](https://github.com/paigeberrigan). View the project on [GitHub](https://github.com/yourusername/yourrepository).
        """)

        # Define the interaction with live updates
        def on_change(model_choice, sample_index, upload_image):
            if upload_image is not None:
                image = upload_image
                actual_label = "N/A"
            else:
                image, actual_label_idx = fetch_sample_image(int(sample_index))
                actual_label = label_mapping.get(actual_label_idx, "Unknown")

            # Resize the image for display purposes
            image_display_resized = image.resize((224, 224), Image.NEAREST)

            predicted_class, confidence_scores = predict_with_model(model_choice, image)
            predicted_letter = label_mapping.get(predicted_class, "Unknown")

            # Prepare the confidence bar graph using Plotly
            labels = [label_mapping[i] for i in range(num_classes)]
            import plotly.graph_objects as go
            fig = go.Figure([go.Bar(x=labels, y=confidence_scores)])
            fig.update_layout(
                title='Prediction Confidence',
                xaxis_title='Classes',
                yaxis_title='Confidence',
                xaxis_tickangle=-45,
                height=400
            )

            return image_display_resized, f"Predicted: {predicted_letter}", f"Actual: {actual_label}", fig

        # make sure the app is updating live
        inputs = [model_choice, sample_index, upload_image]
        outputs = [image_display, prediction_text, actual_label_text, confidence_plot]
        model_choice.change(on_change, inputs=inputs, outputs=outputs)
        sample_index.change(on_change, inputs=inputs, outputs=outputs)
        upload_image.change(on_change, inputs=inputs, outputs=outputs)

    return demo


In [26]:
# launch the dashboard
demo = create_dashboard()
demo.launch()

* Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.





You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



Model weights loaded from c:\Users\paige\OneDrive\Desktop\6147 - PYTORCH\Capstone_SignLanguageMNIST\saved_models\trained_resnet50.pth
Model weights loaded from c:\Users\paige\OneDrive\Desktop\6147 - PYTORCH\Capstone_SignLanguageMNIST\saved_models\trained_resnet18.pth
Model weights loaded from c:\Users\paige\OneDrive\Desktop\6147 - PYTORCH\Capstone_SignLanguageMNIST\saved_models\trained_custom.pth
