# In [None]:  (notebook cell marker left over from export)
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import json
import torch.nn as nn
from torchvision import models
import os

# Page-level settings (tab title, tab icon, layout). This is the first
# Streamlit call made in this file.
_PAGE_SETTINGS = {
    "page_title": "Sports Image Classifier",
    "page_icon": "🏃",
    "layout": "centered",
}
st.set_page_config(**_PAGE_SETTINGS)

# Stylesheet injected as raw HTML: page background, the green button (with its
# hover shade) and the white "prediction-box" card class used for the results.
_CUSTOM_CSS = """
    <style>
    .main {
        background-color: #f5f5f5;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 24px;
        border-radius: 8px;
        border: none;
        font-size: 16px;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .prediction-box {
        background-color: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        margin-top: 20px;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Image preprocessing
def _build_pipeline(*augmentations):
    """Compose resize -> center-crop -> *augmentations* -> tensor -> normalize."""
    steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ]
    steps.extend(augmentations)
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


def preprocess_image(image):
    """Turn one image into model-ready tensors for test-time augmentation.

    Four pipelines are applied to ``image``; each resizes to 256x256,
    center-crops to 224, converts to a tensor and normalizes with fixed
    per-channel mean/std values:

    * the base pipeline with no augmentation,
    * a TTA pipeline identical to the base one,
    * a TTA pipeline that always flips horizontally (``p=1.0``),
    * a TTA pipeline with ``RandomRotation(15)``, so this view differs
      from call to call.

    Args:
        image: The image to transform. The caller in this file passes an
            RGB PIL image; the three-value mean/std presumably require
            three channels -- confirm before passing other modes.

    Returns:
        A tuple ``(base_tensor, tta_tensors)``: the base pipeline output with
        a leading batch dimension added, and a list of the three TTA outputs,
        each with the same leading dimension added.
    """
    plain_pipeline = _build_pipeline()
    augmented_pipelines = [
        _build_pipeline(),
        _build_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
        _build_pipeline(transforms.RandomRotation(15)),
    ]

    # unsqueeze(0) adds the batch axis in front of each single-image tensor.
    base_tensor = plain_pipeline(image).unsqueeze(0)

    tta_tensors = []
    for pipeline in augmented_pipelines:
        tta_tensors.append(pipeline(image).unsqueeze(0))

    return base_tensor, tta_tensors

# Load the model and class mapping
@st.cache_resource
def load_model():
    """Build the ResNet-50 classifier and load its trained weights.

    Reads ``class_mapping.json`` and ``best_model.pth`` from the current
    working directory. The final fully-connected layer is replaced so that its
    output size equals the number of entries in the class mapping. The result
    is wrapped in ``st.cache_resource``.

    Returns:
        A tuple ``(model, class_mapping)``: the model moved to CUDA when
        available (CPU otherwise) and switched to eval mode, and the parsed
        JSON mapping.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If the class mapping is not valid JSON.
        RuntimeError: If the stored weights do not match the architecture.
    """
    # Explicit encoding so decoding does not depend on the platform locale.
    with open('class_mapping.json', 'r', encoding='utf-8') as f:
        class_mapping = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No weight arguments: yields a randomly initialised network under both
    # the older ``pretrained=`` API and the newer ``weights=`` API, without
    # the deprecation warning that ``pretrained=False`` triggers.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, len(class_mapping))

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load('best_model.pth', map_location=device)

    # Accept both a training checkpoint that wraps the weights under
    # 'model_state_dict' and a bare state dict saved directly.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()  # inference mode for dropout / batch-norm layers

    return model, class_mapping

# Main app
def _average_tta_probabilities(model, base_tensor, tta_tensors):
    """Return class probabilities averaged over the base view and all TTA views."""
    # Follow the model's own device instead of re-deriving it from CUDA
    # availability, so inputs always land where the weights are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # One batched forward pass over every view (each view has a leading
    # batch dimension of 1, so concatenating along dim 0 stacks them).
    batch = torch.cat([base_tensor, *tta_tensors], dim=0).to(device)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)

    return probabilities.mean(dim=0)


def main():
    """Run the Streamlit page: upload an image, classify it, show the top classes.

    Flow:
        1. Load the model and class mapping; on any failure show the error
           and stop.
        2. Wait for a JPG/JPEG/PNG upload, display it, and wait for the
           "Predict" button.
        3. Preprocess the image into a base view plus TTA views, average the
           softmax probabilities over all views, and list up to three classes
           with the highest averaged probability.

    NOTE(review): the first TTA view produced by ``preprocess_image`` uses the
    same pipeline as the base view, so the un-augmented image effectively
    carries double weight in the average -- confirm this is intended.
    """
    st.title("🏃 Sports Image Classifier")
    st.write("Upload an image of a sport to classify it!")

    # Top-level boundary: any load failure is reported in the page.
    try:
        model, class_mapping = load_model()
    except Exception as e:
        st.error("Error loading model. Please make sure you have trained the model first!")
        st.error(str(e))
        return

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # A corrupt or mislabeled upload should produce a message, not a traceback.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        st.error("Could not read the uploaded image. Please try a different file.")
        st.error(str(e))
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Predict"):
        return

    with st.spinner("Analyzing image..."):
        base_tensor, tta_tensors = preprocess_image(image)
        avg_probs = _average_tta_probabilities(model, base_tensor, tta_tensors)

        # topk raises if k exceeds the number of classes, so clamp it.
        top_k = min(3, avg_probs.numel())
        top_prob, top_class = torch.topk(avg_probs, top_k)

        # Display results
        st.markdown("### Prediction Results")
        st.markdown('<div class="prediction-box">', unsafe_allow_html=True)

        for prob, class_idx in zip(top_prob.tolist(), top_class.tolist()):
            # The mapping is looked up by the class index as a string key.
            sport_name = class_mapping[str(class_idx)]
            st.write(f"**{sport_name}**: {prob * 100:.2f}%")

            # Progress bar; clamp to guard against float rounding outside [0, 1].
            st.progress(min(max(prob, 0.0), 1.0))

        st.markdown('</div>', unsafe_allow_html=True)

# Entry point: build the page only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main() 
    