In [None]:
import torch

# Startup banner: report whether PyTorch can see a CUDA device.
divider = "=" * 60

print(divider)
print("üöÄ TRAFFIC SIGNAL VIOLATION DETECTION SYSTEM")
print(divider)

if not torch.cuda.is_available():
    # No CUDA device visible: tell the user how to switch the Colab runtime.
    print("\n‚ö†Ô∏è  Running on CPU (slower)")
    print("   üí° For faster processing in Colab:")
    print("   Runtime ‚Üí Change runtime type ‚Üí T4 GPU")
else:
    print(f"\n‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")

print("\n" + divider)

## 2. Install Dependencies

In [None]:
print("üì¶ Installing required packages...\n")

!pip install -q opencv-python-headless
print("‚úÖ OpenCV installed")

!pip install -q ultralytics>=8.3.0
print("‚úÖ Ultralytics YOLOv8 installed")

!pip install -q easyocr
print("‚úÖ EasyOCR installed")

print("\n‚úÖ All dependencies ready!")

## 3. Upload Images

In [None]:
# Detect Colab by probing for its upload helper. Only ImportError is caught so
# that Ctrl-C / SystemExit and unrelated failures are not silently swallowed.
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os
import shutil

# Working directories used by the rest of the notebook.
os.makedirs('input_images', exist_ok=True)
os.makedirs('output_images', exist_ok=True)

if IN_COLAB:
    print("üì∏ Upload your traffic images:")
    uploaded = files.upload()

    if uploaded:
        # The uploaded names are treated as files in the current working
        # directory and moved into the input folder.
        for filename in uploaded:
            shutil.move(filename, f'input_images/{filename}')
        print(f"\n‚úÖ Uploaded {len(uploaded)} image(s)")
    else:
        print("‚ùå No files uploaded!")
else:
    print("üì∏ Place your traffic images in the 'input_images' folder")
    print("‚úÖ Directories created!")

## 4. Load Models

In [None]:
from ultralytics import YOLO
import cv2
import numpy as np
import json
from datetime import datetime
import easyocr
import re
import glob
import torch
from IPython.display import HTML, display

# Load the detector and the OCR reader once; later cells reuse
# `model`, `reader`, `coco` and `use_gpu` as globals.
section_rule = "=" * 60

print(section_rule)
print("üîÑ Loading AI Models...")
print(section_rule)

# EasyOCR is told explicitly whether it may use CUDA.
use_gpu = torch.cuda.is_available()
print(f"\nüîß GPU Available: {use_gpu}")

# YOLOv8x weights (the notebook's "best accuracy" choice).
print("\nüì¶ Loading YOLOv8x (best detection model)...")
model = YOLO("yolov8x.pt")
print("‚úÖ YOLO model loaded!")

# English-only OCR reader used for number plates.
print("\nüìù Loading EasyOCR (number plate reader)...")
reader = easyocr.Reader(['en'], gpu=use_gpu)
print("‚úÖ OCR loaded!")

# Class-id -> class-name mapping exposed by the loaded model.
coco = model.model.names

print("\n" + section_rule)
print("‚úÖ ALL MODELS READY!")
print(section_rule)

## 5. Define Detection Functions

In [None]:
# Configuration - RED LIGHT DETECTION DISABLED
# The polygons below are pixel coordinates in the 1100x700 frame that the
# processing cell resizes every image to. They are camera-specific and will not
# match images taken from a different angle, so the feature is off by default.
ENABLE_RED_LIGHT_DETECTION = False  # Set to True only if you know your camera position

# OpenCV's polygon routines (fillPoly, polylines, pointPolygonTest) require
# int32 points. NumPy's default integer is int64 on Linux/Colab, which makes
# those calls raise, so the dtype is pinned explicitly.
RedLight = np.array([[998, 125], [998, 155], [972, 152], [970, 127]], dtype=np.int32)
GreenLight = np.array([[971, 200], [996, 200], [1001, 228], [971, 230]], dtype=np.int32)
ROI = np.array([[910, 372], [388, 365], [338, 428], [917, 441]], dtype=np.int32)

# Regexes for plate strings of the form <2 letters><digits><letters><digits>.
valid_patterns = [
    r'^[A-Z]{2}\d{2}[A-Z]{1,2}\d{4}$',
    r'^[A-Z]{2}\d{2}[A-Z]{3}\d{4}$',
    r'^[A-Z]{2}\d{1,2}[A-Z]{1,3}\d{1,4}$'
]

# Common false positive texts to filter out (class labels and brand names that
# OCR may pick up instead of a plate).
INVALID_PLATE_KEYWORDS = ['MOTORCY', 'MOTORCYCLE', 'BIKE', 'SCOOTER', 'HERO', 'HONDA',
                          'YAMAHA', 'BAJAJ', 'TVS', 'ROYAL', 'ENFIELD', 'SUZUKI']


def draw_text_bg(img, text, pos, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=0.6,
                 text_color=(255, 255, 255), thickness=2, bg_color=(0, 0, 0),
                 padding=5, border=(255, 0, 0)):
    """Render `text` on `img` inside a filled, outlined label box.

    `pos` is the (x, y) of the text baseline origin. The box extends `padding`
    pixels above the text and below the baseline, and `2 * padding` beyond the
    text width. `img` is modified in place; nothing is returned.
    """
    text_size, _ = cv2.getTextSize(text, font, font_scale, thickness)
    label_w, label_h = text_size
    left, base_y = pos

    # Both rectangles share the same corners: one filled, one 2px outline.
    top_left = (left, base_y - label_h - padding)
    bottom_right = (left + label_w + padding * 2, base_y + padding)

    cv2.rectangle(img, top_left, bottom_right, bg_color, -1)
    cv2.rectangle(img, top_left, bottom_right, border, 2)
    cv2.putText(img, text, (left + padding, base_y), font, font_scale, text_color, thickness)


def is_red_light(image, polygon, threshold=150):
    """Decide whether the lamp inside `polygon` is lit, based on brightness.

    DISABLED by default: only enable with a fixed camera whose traffic-light
    position is known.

    Note that despite the name this does not look at colour: it converts the
    image to grayscale and compares the mean intensity inside `polygon`
    against `threshold`.

    Args:
        image: BGR image (it is converted with cv2.COLOR_BGR2GRAY).
        polygon: sequence/array of (x, y) vertices in `image` pixel coordinates.
        threshold: mean grayscale value above which the lamp counts as lit.

    Returns:
        None when ENABLE_RED_LIGHT_DETECTION is false, or when the region's
        mean brightness is below 50 (treated as "no light here");
        otherwise True/False for brightness > threshold.
    """
    if not ENABLE_RED_LIGHT_DETECTION:
        return None

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # cv2.fillPoly only accepts int32 vertices; cast so that polygons built
    # with NumPy's default int64 dtype do not raise.
    polygon_i32 = np.asarray(polygon, dtype=np.int32)

    # Binary mask selecting the lamp region.
    mask = np.zeros_like(gray)
    cv2.fillPoly(mask, [polygon_i32], 255)

    # Mean intensity restricted to the mask. (Masking the image first is
    # unnecessary: cv2.mean already ignores pixels outside the mask.)
    brightness = cv2.mean(gray, mask=mask)[0]

    # If region is too dark, there's no traffic light at all.
    if brightness < 50:
        return None

    return brightness > threshold


def detect_helmet(frame, person_box, motorcycle_box):
    """Heuristic helmet check on the top quarter of a person's bounding box.

    Deliberately biased towards "no helmet": a helmet is reported only when a
    Hough circle is found in the head crop AND the crop is bright
    (mean > 110 and max > 180 in grayscale). An unusable crop or any
    processing error also yields False.

    `motorcycle_box` is accepted for interface compatibility but not used.
    """
    left, top, right, bottom = person_box

    # Head area = top 25% of the person box, clipped to the frame.
    head_rows = int((bottom - top) * 0.25)
    frame_h, frame_w = frame.shape[0], frame.shape[1]
    head_region = frame[max(0, top):min(frame_h, top + head_rows),
                        max(0, left):min(frame_w, right)]

    crop_unusable = (
        head_region.size == 0
        or head_region.shape[0] < 5
        or head_region.shape[1] < 5
    )
    if crop_unusable:
        return False  # Can't see head = NO HELMET

    try:
        gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
        smoothed = cv2.GaussianBlur(gray, (5, 5), 0)

        # Shape cue: any circle found in the blurred head crop.
        circles = cv2.HoughCircles(smoothed, cv2.HOUGH_GRADIENT, dp=1, minDist=8,
                                   param1=50, param2=25, minRadius=8, maxRadius=100)

        # Brightness cue: bright on average and with a strong highlight.
        mean_level = np.mean(gray)
        peak_level = np.max(gray)

        has_circle = circles is not None and len(circles[0]) > 0
        is_bright = mean_level > 110 and peak_level > 180

        # Both cues are required; anything less counts as no helmet.
        return has_circle and is_bright
    except Exception:
        return False  # Error = NO HELMET


def count_riders_on_motorcycle(motorcycle_box, all_persons):
    """Find the person boxes that overlap a (slightly enlarged) motorcycle box.

    The motorcycle box is grown by 15 px on every side (left/top clamped at 0).
    A person counts as a rider when the fraction of their own box covered by
    that zone exceeds 0.15, or exceeds 0.08 while their vertical centre is no
    more than 70 px below the motorcycle's vertical centre. Boxes with
    non-positive area are ignored.

    Returns (rider_count, riders_list), where riders_list holds the matching
    entries of `all_persons` in their original order.
    """
    mx1, my1, mx2, my2 = motorcycle_box

    margin = 15
    zone_left = max(0, mx1 - margin)
    zone_top = max(0, my1 - margin)
    zone_right = mx2 + margin
    zone_bottom = my2 + margin
    bike_mid_y = (my1 + my2) / 2

    riders = []
    for candidate in all_persons:
        px1, py1, px2, py2 = candidate

        person_area = (px2 - px1) * (py2 - py1)
        if person_area <= 0:
            continue

        shared_w = max(0, min(zone_right, px2) - max(zone_left, px1))
        shared_h = max(0, min(zone_bottom, py2) - max(zone_top, py1))
        coverage = (shared_w * shared_h) / person_area
        person_mid_y = (py1 + py2) / 2

        strong_overlap = coverage > 0.15
        # Looser rule: small overlap is enough if the person sits at or above
        # roughly the motorcycle's mid-height.
        weak_but_seated = coverage > 0.08 and person_mid_y <= bike_mid_y + 70

        if strong_overlap or weak_but_seated:
            riders.append(candidate)

    return len(riders), riders


def generate_random_plate():
    """Build a random plate string shaped like an Indian registration number.

    Layout: two-letter state code, two-digit district (01-99), one or two
    uppercase letters, then a four-digit number (1000-9999).
    """
    import random
    import string

    state_codes = ['MH', 'DL', 'KA', 'TN', 'UP', 'GJ', 'RJ', 'HR', 'AP', 'TG', 'WB', 'MP']

    # The draws happen in this fixed order: state, district, series length,
    # series letters, serial number.
    state_code = random.choice(state_codes)
    district_no = random.randint(1, 99)
    series_len = random.choice([1, 2])
    series = ''.join(random.choices(string.ascii_uppercase, k=series_len))
    serial = random.randint(1000, 9999)

    return f"{state_code}{district_no:02d}{series}{serial}"


# OCR look-alike fixes applied to plate text: letters commonly misread for digits.
_PLATE_OCR_DIGIT_FIXES = str.maketrans({'O': '0', 'I': '1', 'Z': '2', 'S': '5'})


def _plate_ocr_candidates(gray):
    """Return the preprocessing variants of a grayscale crop that are fed to OCR.

    Each variant is best-effort: if an OpenCV step fails it is simply left out,
    so the plain grayscale image is always present.
    """
    candidates = [gray]

    # CLAHE contrast enhancement at several clip limits.
    for clip_limit in (2.0, 3.0, 4.0):
        try:
            clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
            candidates.append(clahe.apply(gray))
        except Exception:
            continue

    # Adaptive (Gaussian) threshold.
    try:
        candidates.append(cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                                cv2.THRESH_BINARY, 11, 2))
    except Exception:
        pass

    # Global Otsu threshold.
    try:
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        candidates.append(binary)
    except Exception:
        pass

    # Morphological closing with a 3x3 rectangular kernel.
    try:
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        candidates.append(cv2.morphologyEx(gray, cv2.MORPH_CLOSE, kernel))
    except Exception:
        pass

    return candidates


def detect_number_plate(frame, vehicle_box):
    """Try to read a number plate from the lower half of a vehicle's box.

    The vehicle box is padded by 20 px (clipped to the frame), its bottom half
    is upscaled at several factors, several preprocessing variants of each are
    passed to the global EasyOCR `reader`, and the highest-confidence readings
    are screened.

    Args:
        frame: BGR image. Pass an un-annotated frame: any text already drawn
            inside the box will be read by OCR as well.
        vehicle_box: (x1, y1, x2, y2) in `frame` pixel coordinates.

    Returns:
        {'detected': True, 'text': <str>, 'valid': True, 'confidence': <float>}
        for the first accepted reading, otherwise
        {'detected': False, 'text': '', 'valid': False}.
    """
    not_found = {'detected': False, 'text': '', 'valid': False}

    x1, y1, x2, y2 = vehicle_box
    expand = 20
    x1 = max(0, x1 - expand)
    y1 = max(0, y1 - expand)
    x2 = min(frame.shape[1], x2 + expand)
    y2 = min(frame.shape[0], y2 + expand)

    vehicle_region = frame[y1:y2, x1:x2]
    if vehicle_region.size == 0:
        return dict(not_found)

    # Focus on the bottom 50% of the crop, unless the crop is very short.
    height = vehicle_region.shape[0]
    plate_region = vehicle_region[int(height * 0.5):, :] if height > 40 else vehicle_region

    all_texts = []
    for scale in (2.5, 3.5, 4.5):
        try:
            new_width = int(plate_region.shape[1] * scale)
            new_height = int(plate_region.shape[0] * scale)

            # Enforce a minimum width of 250 px, keeping the aspect ratio.
            if new_width < 250:
                new_width = 250
                new_height = int(plate_region.shape[0] * (250 / plate_region.shape[1]))

            resized = cv2.resize(plate_region, (new_width, new_height),
                                 interpolation=cv2.INTER_CUBIC)
            gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        except Exception:
            # Best-effort: a scale that cannot be prepared is skipped.
            continue

        for candidate in _plate_ocr_candidates(gray):
            try:
                results = reader.readtext(candidate, detail=1, paragraph=False,
                                          min_size=1, text_threshold=0.2,
                                          low_text=0.15, width_ths=0.4)
            except Exception:
                # OCR failure on one variant must not discard the others.
                continue
            all_texts.extend((r[1], r[2]) for r in results)

    # Very low confidence floor, then best readings first.
    valid_texts = sorted(((text, conf) for text, conf in all_texts if conf > 0.15),
                         key=lambda item: item[1], reverse=True)
    if not valid_texts:
        return dict(not_found)

    # Screen the ten most confident readings.
    for raw_text, confidence in valid_texts[:10]:
        # Upper-case and keep only letters/digits (drops spaces, '-', '.', ',').
        cleaned = ''.join(ch for ch in raw_text.upper() if ch.isalnum())

        # Reject brand names / class labels BEFORE the look-alike substitution.
        # Doing it afterwards (as "M0T0RCY", "H0NDA", "5C00TER", ...) meant
        # most keywords could never match and such text was returned as a plate.
        if any(keyword in cleaned for keyword in INVALID_PLATE_KEYWORDS):
            continue

        text = cleaned.translate(_PLATE_OCR_DIGIT_FIXES)

        # Lenient acceptance: 3-15 characters containing both a letter and a digit.
        if not 3 <= len(text) <= 15:
            continue
        if any(c.isalpha() for c in text) and any(c.isdigit() for c in text):
            return {'detected': True, 'text': text, 'valid': True, 'confidence': confidence}

    return dict(not_found)

# Status summary printed once the helper functions above are defined.
for status_line in (
    "‚úÖ All detection functions loaded!",
    "‚ö†Ô∏è  Red light detection: DISABLED (enable only for fixed camera setup)",
    "‚úÖ Number plate detection: SUPER LENIENT (detects maximum plates)",
    "üé≤ Random plate generation: Only when NO plate detected",
    "üéØ Using YOLOv8x: Best accuracy model",
):
    print(status_line)

## 6. Process Images & Detect Violations

In [None]:
# Main pipeline: for every input image run YOLO, derive violations, write an
# annotated copy to output_images/ and collect a JSON-serialisable summary.
print("\nüöÄ Starting violation detection...\n" + "="*70)
print(f"‚öôÔ∏è  Red Light Detection: {'ENABLED' if ENABLE_RED_LIGHT_DETECTION else 'DISABLED'}")
print("="*70)

# Match extensions case-insensitively (cameras/phones often write ".JPG"), and
# sort so the processing order is deterministic.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
image_files = sorted(
    path for path in glob.glob('input_images/*')
    if path.lower().endswith(IMAGE_EXTENSIONS)
)
all_results = []

# OpenCV polygon routines need int32 vertices; cast defensively so this cell
# works regardless of how the configuration arrays were created.
red_light_poly = np.asarray(RedLight, dtype=np.int32)
green_light_poly = np.asarray(GreenLight, dtype=np.int32)
roi_poly = np.asarray(ROI, dtype=np.int32)

if not image_files:
    print("‚ùå No images found!")
else:
    for idx, img_path in enumerate(image_files, 1):
        print(f"\n[{idx}/{len(image_files)}] Processing: {os.path.basename(img_path)}")

        image = cv2.imread(img_path)
        if image is None:
            print(f"  ‚ö†Ô∏è Failed to read image, skipping...")
            continue

        # `clean_frame` is never drawn on and is what every detector sees;
        # `frame` is the canvas that receives boxes and labels. Running OCR or
        # the helmet heuristic on the annotated canvas would let them pick up
        # the labels/rectangles drawn by this very loop.
        clean_frame = cv2.resize(image, (1100, 700))
        frame = clean_frame.copy()

        red_light_on = is_red_light(clean_frame, red_light_poly)

        # Only draw ROI polygons if red light detection is enabled
        if ENABLE_RED_LIGHT_DETECTION and red_light_on is not None:
            cv2.polylines(frame, [red_light_poly], True, [0, 0, 255], 2)
            cv2.polylines(frame, [green_light_poly], True, [0, 255, 0], 2)
            cv2.polylines(frame, [roi_poly], True, [255, 0, 0], 3)

        # Lower confidence for better vehicle detection
        results = model.predict(clean_frame, conf=0.2, verbose=False, iou=0.4)

        violations = {
            'red_light': [],
            'no_helmet': [],
            'triple_riding': [],
            'no_number_plate': []
        }

        all_persons = []
        all_vehicles = []

        # Split detections into persons and vehicles, with integer pixel boxes.
        for detection in results:
            boxes = detection.boxes.xyxy
            confs = detection.boxes.conf
            classes = detection.boxes.cls

            for box, conf, cls in zip(boxes, confs, classes):
                class_name = coco[int(cls)]
                x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])

                if class_name == 'person':
                    all_persons.append([x1, y1, x2, y2])
                elif class_name in ['car', 'bus', 'truck', 'motorcycle', 'bicycle']:
                    all_vehicles.append({
                        'class': class_name,
                        'bbox': [x1, y1, x2, y2],
                        'conf': float(conf)
                    })

        # Vertical cursor for the stacked violation banners on the left edge.
        violation_y = 40

        for vehicle in all_vehicles:
            v_class = vehicle['class']
            v_bbox = vehicle['bbox']
            x1, y1, x2, y2 = v_bbox

            cv2.rectangle(frame, (x1, y1), (x2, y2), [0, 255, 0], 2)
            cv2.putText(frame, v_class, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, [0, 255, 0], 2)

            # Red light violation - only check if detection is enabled
            if ENABLE_RED_LIGHT_DETECTION and red_light_on is True:
                # A vehicle is "in the ROI" if any corner or its centre is
                # inside (or on the edge of) the stop-zone polygon.
                probe_points = [
                    (x1, y1), (x2, y2), ((x1+x2)//2, (y1+y2)//2), (x1, y2), (x2, y1)
                ]
                in_roi = any(cv2.pointPolygonTest(roi_poly, pt, False) >= 0
                             for pt in probe_points)

                if in_roi:
                    violations['red_light'].append({'type': v_class, 'bbox': v_bbox})
                    cv2.rectangle(frame, (x1, y1), (x2, y2), [0, 0, 255], 4)
                    draw_text_bg(frame, "RED LIGHT VIOLATION", (10, violation_y), border=(0,0,255))
                    violation_y += 35

            # Motorcycle checks
            if v_class == 'motorcycle':
                rider_count, riders = count_riders_on_motorcycle(v_bbox, all_persons)

                if rider_count >= 3:
                    violations['triple_riding'].append({'riders': rider_count, 'bbox': v_bbox})
                    cv2.rectangle(frame, (x1, y1), (x2, y2), [255, 0, 255], 4)
                    draw_text_bg(frame, f"TRIPLE RIDING ({rider_count} riders)", (10, violation_y), border=(255,0,255))
                    violation_y += 35

                for rider_idx, rider_box in enumerate(riders):
                    has_helmet = detect_helmet(clean_frame, rider_box, v_bbox)
                    if not has_helmet:
                        violations['no_helmet'].append({
                            'bbox': v_bbox,
                            'rider': rider_idx+1,
                            'rider_box': rider_box
                        })
                        cv2.rectangle(frame, (x1, y1), (x2, y2), [0, 165, 255], 4)
                        rx1, ry1, rx2, ry2 = rider_box
                        cv2.rectangle(frame, (rx1, ry1), (rx2, ry2), [255, 255, 0], 3)
                        cv2.putText(frame, f"No Helmet", (rx1, ry1-5),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.5, [255, 255, 0], 2)
                        draw_text_bg(frame, f"NO HELMET - Rider {rider_idx+1}", (10, violation_y), border=(0,165,255))
                        violation_y += 35

            # Number plate detection - show only if detected or generate if not found
            if v_class in ['car', 'bus', 'truck', 'motorcycle']:
                plate_info = detect_number_plate(clean_frame, v_bbox)

                if plate_info['detected'] and plate_info['text']:
                    # Real plate detected - show in GREEN
                    plate_text = plate_info['text']
                    confidence = plate_info.get('confidence', 0)
                    cv2.putText(frame, f"Plate: {plate_text} ({confidence:.2f})", (x1, y2+20),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.6, [0, 255, 0], 2)
                else:
                    # No plate detected - mark as violation and generate random
                    # NOTE(review): the generated plate is a random placeholder,
                    # not a reading of the vehicle.
                    plate_text = generate_random_plate()
                    violations['no_number_plate'].append({
                        'type': v_class, 
                        'bbox': v_bbox,
                        'generated_plate': plate_text
                    })
                    cv2.rectangle(frame, (x1, y1), (x2, y2), [255, 165, 0], 3)
                    cv2.putText(frame, f"No Plate (Gen: {plate_text})", (x1, y2+20),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.5, [255, 165, 0], 2)
                    draw_text_bg(frame, f"NO PLATE - Generated: {plate_text}", (10, violation_y), 
                               border=(255,165,0))
                    violation_y += 35

        # Signal indicator - only if detection is enabled
        if ENABLE_RED_LIGHT_DETECTION and red_light_on is not None:
            signal_color = "RED" if red_light_on else "GREEN"
            signal_bg = (0, 0, 255) if red_light_on else (0, 255, 0)
            cv2.rectangle(frame, (frame.shape[1]-200, 10), (frame.shape[1]-10, 50), signal_bg, -1)
            cv2.putText(frame, f"Signal: {signal_color}", (frame.shape[1]-190, 35),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        output_path = f'output_images/{os.path.basename(img_path)}'
        cv2.imwrite(output_path, frame)

        total = sum(len(v) for v in violations.values())

        # Per-image summary; kept JSON-serialisable (plain ints/str/bool/list).
        image_result = {
            'image': img_path,
            'output': output_path,
            'red_light_on': 'DISABLED' if not ENABLE_RED_LIGHT_DETECTION else (red_light_on if red_light_on is not None else 'N/A'),
            'total_violations': total,
            'violations': violations,
            'vehicle_count': len(all_vehicles),
            'person_count': len(all_persons),
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        all_results.append(image_result)

        if not ENABLE_RED_LIGHT_DETECTION:
            print(f"  Signal: ‚ö†Ô∏è  DISABLED")
        elif red_light_on is None:
            print(f"  Signal: ‚ö™ NO TRAFFIC LIGHT")
        else:
            print(f"  Signal: {'üî¥ RED' if red_light_on else 'üü¢ GREEN'}")
        print(f"  Detected: {len(all_vehicles)} vehicles, {len(all_persons)} persons")
        print(f"  Total Violations: {total}")
        if violations['red_light']:
            print(f"    üî¥ Red Light: {len(violations['red_light'])}")
        if violations['no_helmet']:
            print(f"    ü™ñ No Helmet: {len(violations['no_helmet'])}")
        if violations['triple_riding']:
            print(f"    üë• Triple Riding: {len(violations['triple_riding'])}")
        if violations['no_number_plate']:
            print(f"    üö´ No Plate (Generated): {len(violations['no_number_plate'])}")

    with open('output_images/violations_report.json', 'w') as f:
        json.dump(all_results, f, indent=2)

    print("\n" + "="*70)
    print("\n‚úÖ Processing complete!")
    # Count only images that were actually read and processed.
    print(f"üìä Total images processed: {len(all_results)}")
    print(f"üö® Total violations found: {sum(r['total_violations'] for r in all_results)}")
    print(f"üìÅ Results saved in 'output_images/' folder")

## 7. Display Results

In [None]:
import base64
import mimetypes


def _image_data_uri(path):
    """Return `path` encoded as a base64 data URI for use in an <img> tag.

    Notebook front-ends such as Colab render cell output in a frame that
    cannot read the runtime's filesystem, so a relative file path in `src`
    shows a broken image. Embedding the bytes makes the gallery render
    everywhere. If the file cannot be read, the original path is returned so
    behaviour degrades to a plain path.
    """
    mime_type = mimetypes.guess_type(path)[0] or 'image/jpeg'
    try:
        with open(path, 'rb') as image_file:
            encoded = base64.b64encode(image_file.read()).decode('ascii')
    except OSError:
        return path
    return f"data:{mime_type};base64,{encoded}"


# Side-by-side gallery: original image next to its annotated output, followed
# by a per-category violation count.
print("\nüì∏ RESULTS GALLERY\n" + "="*70)

for i, result in enumerate(all_results, 1):
    print(f"\n{'='*70}")
    print(f"Image {i}: {os.path.basename(result['image'])}")
    print(f"{'='*70}")

    original_src = _image_data_uri(result['image'])
    detected_src = _image_data_uri(result['output'])
    frame_colour = 'red' if result['total_violations'] > 0 else 'green'

    display(HTML(f'''
    <div style="display: flex; gap: 10px; margin: 20px 0;">
        <div style="flex: 1;">
            <h4>Original</h4>
            <img src="{original_src}" style="width: 100%; border: 2px solid #ddd;">
        </div>
        <div style="flex: 1;">
            <h4>Detected ({result['total_violations']} violations)</h4>
            <img src="{detected_src}" style="width: 100%; border: 3px solid {frame_colour};">
        </div>
    </div>
    '''))

    v = result['violations']
    if result['total_violations'] > 0:
        print("‚ö†Ô∏è Violations Found:")
        if v['red_light']: print(f"  üî¥ Red Light: {len(v['red_light'])}")
        if v['no_helmet']: print(f"  ü™ñ No Helmet: {len(v['no_helmet'])}")
        if v['triple_riding']: print(f"  üë• Triple Riding: {len(v['triple_riding'])}")
        if v['no_number_plate']: print(f"  üö´ No Number Plate: {len(v['no_number_plate'])}")
    else:
        print("‚úÖ No violations detected")

## 8. Download Results

In [None]:
import shutil

# Detect Colab by probing for its download helper. Only ImportError is caught
# so that Ctrl-C / SystemExit and unrelated failures are not silently swallowed.
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Zip the whole output folder (annotated images and the JSON report).
print("üì¶ Creating download package...")
shutil.make_archive('complete_violation_results', 'zip', 'output_images')
print("‚úÖ Package ready!\n")

if IN_COLAB:
    print("üì• Downloading results...")
    files.download('complete_violation_results.zip')
    files.download('output_images/violations_report.json')
    print("‚úÖ Downloads started! Check your browser.")
else:
    print("üìÅ Results saved as 'complete_violation_results.zip'")
    print("üìÑ JSON report: 'output_images/violations_report.json'")