In [4]:
import torch
import os
import cv2
from PIL import Image
from torchvision import transforms

# Load the models
condition_classifier = torch.load('../VGG19/model_vgg19_adverse_env.pth')  # Load your condition classifier model
enhance_net_rain = torch.load('../EnhanceNet/enhance_net_rain.pth')  # Load your EnhanceNet model for rain
enhance_net_snow = torch.load('../EnhanceNet/enhance_net_snow.pth')  # Load your EnhanceNet model for snow
enhance_net_haze = torch.load('../EnhanceNet/enhance_net_haze.pth')  # Load your EnhanceNet model for haze
enhance_net_shadow = torch.load('../EnhanceNet/enhance_net_shadow.pth')  # Load your EnhanceNet model for shadow
enhance_net_lens_blur = torch.load('../EnhanceNet/enhance_net_lens_blur.pth')  # Load your EnhanceNet model for lens blur
recognition_model = torch.hub.load('ultralytics/yolov5', 'custom', path='../Yolo/yolov5s.pt', force_reload=True)  # Load your YOLOv5 model

# Preprocessing function for the classifier
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension

def classify_condition(image):
    """Classify the condition of the input image using the trained classifier."""
    input_tensor = preprocess_image(image)
    output = condition_classifier(input_tensor)
    _, predicted = torch.max(output, 1)
    return predicted.item()  # Return the class index

def enhance_image(image, condition):
    """Enhance the image based on the detected condition."""
    image_tensor = preprocess_image(image).to(device)

    if condition == 2:  # Rain
        enhanced_image = enhance_net_rain(image_tensor)
    elif condition == 4:  # Snow
        enhanced_image = enhance_net_snow(image_tensor)
    elif condition == 0:  # Haze
        enhanced_image = enhance_net_haze(image_tensor)
    elif condition == 3:  # Shadow
        enhanced_image = enhance_net_shadow(image_tensor)
    elif condition == 1:  # Lens Blur
        enhanced_image = enhance_net_lens_blur(image_tensor)
    else:
        return image_tensor  # No enhancement if no condition is detected
    
    return enhanced_image.detach().cpu()  # Return enhanced image

def detect_sign(image):
    """Detect traffic signs using YOLOv5."""
    results = recognition_model(image)
    return results

def process_image(image_path):
    """Main function to process the input image."""
    image = Image.open(image_path).convert("RGB")

    # Classify the condition
    condition = classify_condition(image)

    if condition in [0, 1, 2, 3, 4]:  # If a condition is detected
        enhanced_image = enhance_image(image, condition)
        results = detect_sign(enhanced_image)
    else:  # No conditions detected
        results = detect_sign(image)

    return results

def process_video(video_path):
    """Process video frames."""
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Convert BGR to RGB
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = process_image(image)

        # Display results (optional)
        # Process results for visualization if necessary

    cap.release()

# Input: Either an image or a video
input_type = input("Enter 'image' for single image or 'video' for video file: ").strip().lower()
if input_type == 'image':
    image_path = input("Enter the path to the image: ")
    results = process_image(image_path)
    print(results)  # Output results from detection
elif input_type == 'video':
    video_path = input("Enter the path to the video: ")
    process_video(video_path)
else:
    print("Invalid input type.")

  condition_classifier = torch.load('../VGG19/model_vgg19_adverse_env.pth')  # Load your condition classifier model
  enhance_net_rain = torch.load('../EnhanceNet/enhance_net_rain.pth')  # Load your EnhanceNet model for rain
  enhance_net_snow = torch.load('../EnhanceNet/enhance_net_snow.pth')  # Load your EnhanceNet model for snow
  enhance_net_haze = torch.load('../EnhanceNet/enhance_net_haze.pth')  # Load your EnhanceNet model for haze
  enhance_net_shadow = torch.load('../EnhanceNet/enhance_net_shadow.pth')  # Load your EnhanceNet model for shadow
  enhance_net_lens_blur = torch.load('../EnhanceNet/enhance_net_lens_blur.pth')  # Load your EnhanceNet model for lens blur
Downloading: "https://github.com/ultralytics/yolov5/zipball/master" to /Users/taufiqnoorani/.cache/torch/hub/master.zip
YOLOv5 🚀 2024-9-27 Python-3.9.20 torch-2.6.0.dev20240923 CPU

Fusing layers... 
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape... 


Enter 'image' for single image or 'video' for video file:  123.bmp


Invalid input type.
