In [1]:
!pip install opencv-python numpy ultralytics transformers pillow matplotlib seaborn pandas torch reportlab

Collecting ultralytics
  Downloading ultralytics-8.3.111-py3-none-any.whl.metadata (37 kB)
Collecting reportlab
  Downloading reportlab-4.4.0-py3-none-any.whl.metadata (1.8 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-

In [None]:
import cv2
import numpy as np
from ultralytics import YOLO
from transformers import pipeline
import re
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import torch
import gc
from reportlab.lib.pagesizes import letter
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as ReportLabImage
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
import json
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from datetime import datetime

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class TrafficAnalyzer:
    def __init__(self, model_path, video_path):
        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        
        # Verify input files
        self.model_path = model_path
        self.video_path = video_path
        if not os.path.exists(model_path) or not os.path.exists(video_path):
            print("Error: Model or video file not found.")
            exit()
            
        # Initialize variables
        self.selected_frame = None
        self.peak_frame = None
        self.vehicle_counts = {}
        self.track_id_to_class = {}
        self.frame_vehicle_counts = []
        self.max_vehicles = 0
        self.max_frame_index = 0
        self.emergency_alerts = []
        self.track_positions = defaultdict(list)
        self.average_speeds = {}
        self.congestion_indices = []
        
        # Data processing variables
        self.processed = False
        self.context = ""
        self.average_counts = {}
        self.total_frames = 0
        self.fps = 0
        self.max_time_sec = 0
        self.middle_index = 0
        
        # Track history for unique vehicle counting
        self.track_history = defaultdict(lambda: {
            'first_seen': float('inf'),
            'last_seen': -1,
            'vehicle_type': None,
            'positions': []
        })
        
        # Load YOLO model
        try:
            self.model = YOLO(model_path)
            print("YOLO model loaded successfully.")
        except Exception as e:
            print(f"Error loading YOLO model: {e}")
            exit()
            
        # Initialize NLP components
        self.load_nlp_models()
    
    def load_nlp_models(self):
        try:
            self.text_generator = pipeline("text2text-generation", model="google/flan-t5-base", 
                                          device=0 if torch.cuda.is_available() else -1)
            print("Flan-T5-Base loaded for text generation.")
        except Exception as e:
            print(f"Warning: Could not load Flan-T5-Base: {e}")
            self.text_generator = None
            
        try:
            self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn", 
                                      device=0 if torch.cuda.is_available() else -1)
            print("BART-Large-CNN loaded for summarization.")
        except Exception as e:
            print(f"Warning: Could not load BART-Large-CNN: {e}")
            self.summarizer = None
    
    def process_video(self):
        # Video setup
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video at {self.video_path}")
            return False
            
        self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if self.total_frames == 0:
            print("Error: Video has no frames.")
            return False
            
        self.middle_index = self.total_frames // 2
        self.fps = cap.get(cv2.CAP_PROP_FPS)
        frame_index = 0
        
        # Process each frame
        print("Processing video...")
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
                
            frame = cv2.resize(frame, (700, 500))
            results = self.model.track(frame, persist=True)
            
            if results[0].boxes is None or len(results[0].boxes) == 0:
                self.frame_vehicle_counts.append({})
                frame_index += 1
                continue
                
            boxes = results[0].boxes.xyxy.cpu().numpy()
            confidences = results[0].boxes.conf.cpu().numpy()
            classes = results[0].boxes.cls.cpu().numpy().astype(int)
            track_ids = results[0].boxes.id.cpu().numpy() if results[0].boxes.id is not None else []
            
            current_frame_counts = defaultdict(int)
            current_time = frame_index / self.fps
            
            for conf, cls, track_id, box in zip(confidences, classes, track_ids, boxes):
                if conf < 0.5:
                    continue
                label = results[0].names[cls]
                
                # Only count vehicles (not traffic lights, people, etc.)
                if label in ['car', 'truck', 'bus', 'motorcycle', 'bicycle', 'ambulance', 'police car']:
                    # Update track history for unique vehicle counting
                    if track_id not in self.track_history:
                        self.track_history[track_id]['vehicle_type'] = label
                        self.vehicle_counts[label] = self.vehicle_counts.get(label, 0) + 1
                    
                    self.track_history[track_id]['first_seen'] = min(self.track_history[track_id]['first_seen'], current_time)
                    self.track_history[track_id]['last_seen'] = max(self.track_history[track_id]['last_seen'], current_time)
                    
                    # Update track ID to class mapping
                    self.track_id_to_class[track_id] = label
                    
                    # Count for current frame
                    current_frame_counts[label] += 1
                    
                    # Store position for speed calculation
                    center_x = (box[0] + box[2]) / 2
                    center_y = (box[1] + box[3]) / 2
                    self.track_positions[track_id].append((frame_index, center_x, center_y))
            
            self.frame_vehicle_counts.append(dict(current_frame_counts))
            
            # Check for emergency vehicles
            if current_frame_counts.get("ambulance", 0) > 0:
                alert = f"Ambulance activity at {frame_index/self.fps:.2f}s: {current_frame_counts['ambulance']} ambulance(s)"
                self.emergency_alerts.append(alert)
            if current_frame_counts.get("police car", 0) > 0:
                alert = f"Police car activity at {frame_index/self.fps:.2f}s: {current_frame_counts['police car']} police car(s)"
                self.emergency_alerts.append(alert)
            
            # Calculate congestion
            total_in_frame = sum(current_frame_counts.values())
            congestion_index = total_in_frame / 5.0  # Normalize to reasonable scale
            self.congestion_indices.append(congestion_index)
            
            # Track peak traffic
            if total_in_frame > self.max_vehicles:
                self.max_vehicles = total_in_frame
                self.max_frame_index = frame_index
                self.peak_frame = results[0].plot()
            
            # Save middle frame
            if frame_index == self.middle_index:
                self.selected_frame = results[0].plot()
            
            frame_index += 1
        
        cap.release()
        cv2.destroyAllWindows()
        print(f"Video processing complete: {self.total_frames} frames processed.")
        
        # Calculate average speeds
        print("Calculating average speeds...")
        self.calculate_average_speeds()
        
        # Save annotated frames
        self.save_frames()
        
        # Calculate statistics and create context
        self.calculate_statistics()
        
        # Save data to JSON for accurate retrieval later
        with open('traffic_data.json', 'w') as f:
            json.dump({
                'vehicle_counts': self.vehicle_counts,
                'average_counts': {k: float(v) for k, v in self.average_counts.items()},
                'average_speeds': {k: float(v) for k, v in self.average_speeds.items()},
                'max_vehicles': self.max_vehicles,
                'max_time_sec': self.max_time_sec,
                'total_frames': self.total_frames,
                'fps': float(self.fps),
                'emergency_alerts': self.emergency_alerts,
                'middle_index': self.middle_index,
                'middle_frame_counts': self.frame_vehicle_counts[self.middle_index] if self.middle_index < len(self.frame_vehicle_counts) else {}
            }, f)
        
        # Export to CSV
        self.export_to_csv()
        
        # Generate heatmap
        self.generate_heatmap()
        
        # Generate report
        self.generate_report()
        
        self.processed = True
        return True
        
    def calculate_average_speeds(self):
        for track_id, positions in self.track_positions.items():
            label = self.track_id_to_class[track_id]
            if len(positions) < 2:
                continue
            total_speed = 0
            count = 0
            for i in range(1, len(positions)):
                frame_diff = positions[i][0] - positions[i-1][0]
                if frame_diff == 0:
                    continue
                dx = positions[i][1] - positions[i-1][1]
                dy = positions[i][2] - positions[i-1][2]
                distance = np.sqrt(dx**2 + dy**2)
                time = frame_diff / self.fps
                speed = distance / time
                total_speed += speed
                count += 1
            if count > 0:
                if label not in self.average_speeds:
                    self.average_speeds[label] = 0
                    self.average_speeds[f"{label}_count"] = 0
                self.average_speeds[label] += total_speed / count
                self.average_speeds[f"{label}_count"] += 1
        
        for label in self.vehicle_counts.keys():
            count_key = f"{label}_count"
            if count_key in self.average_speeds:
                self.average_speeds[label] = self.average_speeds[label] / self.average_speeds[count_key]
                del self.average_speeds[count_key]
    
    def save_frames(self):
        try:
            if self.selected_frame is not None:
                cv2.imwrite("middle_frame.jpg", self.selected_frame)
                print("Middle frame saved as 'middle_frame.jpg'.")
            if self.peak_frame is not None:
                cv2.imwrite("peak_frame.jpg", self.peak_frame)
                print("Peak frame saved as 'peak_frame.jpg'.")
        except Exception as e:
            print(f"Error saving frames: {e}")
    
    def calculate_statistics(self):
        # Calculate averages and max time
        for vtype in self.vehicle_counts.keys():
            total = sum(frame_counts.get(vtype, 0) for frame_counts in self.frame_vehicle_counts)
            self.average_counts[vtype] = total / len(self.frame_vehicle_counts) if self.frame_vehicle_counts else 0
        
        self.max_time_sec = self.max_frame_index / self.fps if self.fps > 0 else 0
        average_congestion = np.mean(self.congestion_indices) if self.congestion_indices else 0
        
        # First section, middle section, last section vehicle counts
        first_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[:len(self.frame_vehicle_counts)//4])
        middle_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[len(self.frame_vehicle_counts)//4:3*len(self.frame_vehicle_counts)//4])
        last_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[3*len(self.frame_vehicle_counts)//4:])
        
        # Create context string
        self.context = (
            f"- Video duration: {(self.total_frames / self.fps):.2f} seconds\n"
            f"- Unique vehicles: {', '.join([f'{v}: {c}' for v, c in self.vehicle_counts.items()])} (total: {sum(self.vehicle_counts.values())})\n"
            f"- Average vehicles per frame: {', '.join([f'{v}: {c:.2f}' for v, c in self.average_counts.items()])}\n"
            f"- Average speeds (pixels/s): {', '.join([f'{v}: {s:.2f}' for v, s in self.average_speeds.items()])}\n"
            f"- Peak traffic: {self.max_vehicles} vehicles at {self.max_time_sec:.2f} seconds\n"
            f"- Middle frame (at {(self.middle_index / self.fps):.2f} seconds): {', '.join([f'{v}: {c}' for v, c in self.frame_vehicle_counts[self.middle_index].items()]) if self.middle_index < len(self.frame_vehicle_counts) else 'no vehicles'}\n"
            f"- Average congestion index: {average_congestion:.2f} (0=low, 1=moderate, >2=high)\n"
            f"- Emergency alerts: {'; '.join(self.emergency_alerts) if self.emergency_alerts else 'None'}\n"
            f"- Temporal trends: First 25%: {first_section} vehicles, "
            f"Middle 50%: {middle_section} vehicles, "
            f"Last 25%: {last_section} vehicles"
        )
        print("\n--- Context Generated ---\n")
        print(self.context)
    
    def export_to_csv(self):
        print("Exporting data to CSV...")
        times = [i / self.fps for i in range(len(self.frame_vehicle_counts))]
        total_vehicles_per_frame = [sum(frame_counts.values()) for frame_counts in self.frame_vehicle_counts]
        
        csv_data = {
            "Frame": list(range(len(self.frame_vehicle_counts))),
            "Time (s)": times,
            "Total Vehicles": total_vehicles_per_frame,
            "Congestion Index": self.congestion_indices
        }
        
        for vtype in self.vehicle_counts.keys():
            csv_data[vtype] = [frame_counts.get(vtype, 0) for frame_counts in self.frame_vehicle_counts]
        
        df = pd.DataFrame(csv_data)
        try:
            df.to_csv("traffic_data.csv", index=False)
            print("Traffic data exported to 'traffic_data.csv'.")
        except Exception as e:
            print(f"Error exporting CSV: {e}")
    
    def generate_heatmap(self):
        print("Generating traffic density heatmap...")
        times = [i / self.fps for i in range(len(self.frame_vehicle_counts))]
        total_vehicles_per_frame = [sum(frame_counts.values()) for frame_counts in self.frame_vehicle_counts]
        
        plt.figure(figsize=(10, 4))
        sns.heatmap([total_vehicles_per_frame], cmap="YlOrRd", xticklabels=50, cbar_kws={'label': 'Vehicle Count'})
        plt.xlabel("Time (seconds)")
        plt.ylabel("Density")
        plt.title("Traffic Density Heatmap")
        plt.xticks(ticks=np.linspace(0, len(times)-1, 5), labels=[f"{t:.1f}" for t in np.linspace(0, max(times), 5)])
        try:
            plt.savefig("heatmap.png")
            plt.close()
            print("Heatmap saved as 'heatmap.png'.")
        except Exception as e:
            print(f"Error saving heatmap: {e}")
    
    def generate_creative_text(self, prompt):
        if self.text_generator:
            try:
                result = self.text_generator(prompt, max_new_tokens=300)
                return result[0]['generated_text'].strip()
            except Exception as e:
                print(f"Error generating text: {e}")
                return "Could not generate creative text."
        else:
            return "Creative text generation not available (NLP models not loaded)."
    
    def enhance_text(self, text):
        if self.summarizer:
            try:
                enhanced = self.summarizer(text, max_length=len(text.split()) + 50, min_length=len(text.split()), 
                                          do_sample=True)[0]['summary_text']
                return enhanced
            except Exception as e:
                print(f"Error enhancing text: {e}")
                return text
        else:
            return text
    
    def generate_report(self):
        print("Generating text report...")
        
        # Generate traffic flow description
        flow_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate a detailed traffic flow analysis starting with 'The traffic flow analysis reveals...'. "
            "Include patterns of congestion, peak times, and vehicle distribution."
        )
        flow_text = self.generate_creative_text(flow_prompt)
        
        # Generate vehicle behavior description
        behavior_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate a detailed description of vehicle behavior starting with 'Vehicle behavior analysis shows...'. "
            "Include speed patterns, types of vehicles, and any unusual events."
        )
        behavior_text = self.generate_creative_text(behavior_prompt)
        
        # Generate recommendations
        recommendations_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate 3-4 traffic management recommendations based on this data, "
            "starting with 'Based on the analysis, we recommend...'"
        )
        recommendations_text = self.generate_creative_text(recommendations_prompt)
        
        # Combine all sections
        report_text = (
            f"# Traffic Analysis Report\n\n"
            f"## Overview\n\n"
            f"This report analyzes traffic patterns observed in a video of duration {(self.total_frames / self.fps):.2f} seconds. "
            f"A total of {sum(self.vehicle_counts.values())} unique vehicles were detected, "
            f"with peak traffic occurring at {self.max_time_sec:.2f} seconds.\n\n"
            f"## Traffic Flow Analysis\n\n{flow_text}\n\n"
            f"## Vehicle Behavior\n\n{behavior_text}\n\n"
            f"## Recommendations\n\n{recommendations_text}\n\n"
            f"## Statistical Summary\n\n"
            f"- Total unique vehicles: {sum(self.vehicle_counts.values())}\n"
            f"- Vehicle types: {', '.join([f'{v}: {c}' for v, c in self.vehicle_counts.items()])}\n"
            f"- Peak congestion: {self.max_vehicles} vehicles at {self.max_time_sec:.2f} seconds\n"
            f"- Average congestion index: {np.mean(self.congestion_indices):.2f}\n"
        )
        
        # Enhance the report with more creative language
        enhanced_report = self.enhance_text(report_text)
        
        print("\n--- Generated Report ---\n")
        print(enhanced_report)
        
        # Generate PDF report
        print("Generating PDF report...")
        try:
            pdf = SimpleDocTemplate("report.pdf", pagesize=letter)
            styles = getSampleStyleSheet()
            normal_style = ParagraphStyle(name='NormalWrap', parent=styles['Normal'], wordWrap='CJK')
            heading_style = styles['Heading1']
            subheading_style = styles['Heading2']
            elements = []
            
            elements.append(Paragraph("Vehicle Detection Report", heading_style))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Add current date
            current_date = datetime.now().strftime("%B %d, %Y")
            elements.append(Paragraph(f"Generated on: {current_date}", normal_style))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Process markdown-like sections in the report
            sections = enhanced_report.split('##')
            for i, section in enumerate(sections):
                if i == 0:  # First part (title)
                    continue
                
                lines = section.strip().split('\n')
                section_title = lines[0].strip()
                section_content = '\n'.join(lines[1:]).strip()
                
                elements.append(Paragraph(section_title, subheading_style))
                elements.append(Spacer(1, 0.1 * inch))
                
                for paragraph in section_content.split('\n\n'):
                    if paragraph.strip():
                        elements.append(Paragraph(paragraph.strip(), normal_style))
                        elements.append(Spacer(1, 0.1 * inch))
            
            # Add vehicle statistics table
            elements.append(Paragraph("Vehicle Statistics", subheading_style))
            elements.append(Spacer(1, 0.1 * inch))
            table_data = [['Vehicle Type', 'Unique Count', 'Avg/Frame', 'Middle Frame', 'Avg Speed (px/s)']]
            
            for vtype in self.vehicle_counts.keys():
                table_data.append([
                    vtype, 
                    str(self.vehicle_counts.get(vtype, 0)), 
                    f"{self.average_counts.get(vtype, 0):.2f}", 
                    str(self.frame_vehicle_counts[self.middle_index].get(vtype, 0) if self.middle_index < len(self.frame_vehicle_counts) else 0), 
                    f"{self.average_speeds.get(vtype, 0):.2f}"
                ])
            
            table = Table(table_data)
            table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
                ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
                ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
                ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
                ('FONTSIZE', (0, 0), (-1, 0), 10),
                ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
                ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
                ('GRID', (0, 0), (-1, -1), 1, colors.black)
            ]))
            elements.append(table)
            elements.append(Spacer(1, 0.2 * inch))
            
            # Add visualizations
            elements.append(Paragraph("Visualizations", subheading_style))
            elements.append(Spacer(1, 0.1 * inch))
            image_paths = [("heatmap.png", "Traffic Density Heatmap"), ("middle_frame.jpg", "Middle Frame"), ("peak_frame.jpg", "Peak Traffic Frame")]
            
            for path, title in image_paths:
                if os.path.exists(path):
                    img = ReportLabImage(path, width=5*inch, height=3*inch)
                    elements.append(Paragraph(title, normal_style))
                    elements.append(img)
                    elements.append(Spacer(1, 0.1 * inch))
                else:
                    elements.append(Paragraph(f"{title} not available.", normal_style))
            
            pdf.build(elements)
            print("PDF report generated as 'report.pdf'.")
        except Exception as e:
            print(f"Error generating PDF: {e}")
    
    def get_unique_vehicles_in_timerange(self, start_time, end_time):
        """
        Count unique vehicles that appeared between start_time and end_time.
        """
        unique_vehicles = defaultdict(int)
        
        for track_id, data in self.track_history.items():
            # Check if the vehicle was present during the time range
            if data['first_seen'] <= end_time and data['last_seen'] >= start_time:
                vehicle_type = data['vehicle_type']
                if vehicle_type:
                    unique_vehicles[vehicle_type] += 1
        
        return dict(unique_vehicles)
    
    def answer_question(self, question):
        """
        Rules-based factual question answering from stored data, with creative descriptions
        where appropriate.
        """
        if not self.processed:
            return "Please process the video first before asking questions."
            
        if not question or not isinstance(question, str):
            return "Please ask a valid question."
            
        # Load saved data to ensure accuracy
        try:
            with open('traffic_data.json', 'r') as f:
                data = json.load(f)
        except:
            # Fall back to in-memory data if file not found
            data = {
                'vehicle_counts': self.vehicle_counts,
                'average_counts': self.average_counts,
                'average_speeds': self.average_speeds,
                'max_vehicles': self.max_vehicles,
                'max_time_sec': self.max_time_sec,
                'total_frames': self.total_frames,
                'fps': self.fps,
                'emergency_alerts': self.emergency_alerts,
                'middle_index': self.middle_index,
                'middle_frame_counts': self.frame_vehicle_counts[self.middle_index] if self.middle_index < len(self.frame_vehicle_counts) else {}
            }
            
        # Normalize question
        question = question.lower().strip()
        tokens = word_tokenize(question)
        stop_words = set(stopwords.words('english'))
        filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
        
        # Time range handling - IMPROVED for unique vehicle counting
        time_range_match = re.search(r"from (\d+(?:\.\d+)?)\s*(?:to|s(?:econd)?s?\s*to)\s*(\d+(?:\.\d+)?)\s*s(?:econd)?s?", question)
        time_match = re.search(r"at (\d+(?:\.\d+)?) seconds?", question)
        
        if time_range_match:
            start_time = float(time_range_match.group(1))
            end_time = float(time_range_match.group(2))
            
            if start_time >= end_time or end_time > self.total_frames / self.fps:
                return f"Invalid time range: {start_time}s to {end_time}s. Video duration is {self.total_frames/self.fps:.2f}s."
                
            # Get unique vehicles in the time range
            unique_vehicles = self.get_unique_vehicles_in_timerange(start_time, end_time)
            total_unique = sum(unique_vehicles.values())
            
            # Generate response
            response = f"From {start_time}s to {end_time}s, {total_unique} unique vehicles passed: "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in unique_vehicles.items()])
            return response
            
        elif time_match:
            time_sec = float(time_match.group(1))
            
            if time_sec > self.total_frames / self.fps:
                return f"Time {time_sec}s exceeds video duration ({self.total_frames/self.fps:.2f}s)."
                
            frame_index = min(int(time_sec * self.fps), self.total_frames - 1)
            frame_counts = self.frame_vehicle_counts[frame_index]
            
            if not frame_counts:
                return f"At {time_sec}s: No vehicles detected."
                
            response = f"At {time_sec}s: "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in frame_counts.items()])
            return response
        
        # Total vehicle questions
        if any(w in filtered_tokens for w in ['total', 'all']) and any(w in filtered_tokens for w in ['vehicle', 'vehicles', 'cars', 'detected']):
            return f"Total unique vehicles detected: {sum(data['vehicle_counts'].values())}."
        
        # Vehicle type questions
        for vehicle_type in data['vehicle_counts'].keys():
            if vehicle_type.lower() in question:
                count = data['vehicle_counts'].get(vehicle_type, 0)
                avg = data['average_counts'].get(vehicle_type, 0)
                speed = data['average_speeds'].get(vehicle_type, 0)
                response = f"{count} {vehicle_type}(s) detected in total. "
                response += f"Average of {avg:.2f} per frame. "
                
                if speed > 0:
                    response += f"Average speed: {speed:.2f} pixels/second."
                return response
                
        # Peak traffic questions
        if any(w in filtered_tokens for w in ['peak', 'busiest', 'maximum']):
            if any(w in filtered_tokens for w in ['time', 'when']):
                return f"Peak traffic occurred at {data['max_time_sec']:.2f} seconds with {data['max_vehicles']} vehicles in frame."
            else:
                return f"Peak traffic was {data['max_vehicles']} vehicles at {data['max_time_sec']:.2f} seconds."
                
        # Middle of video questions
        if any(w in filtered_tokens for w in ['middle', 'mid']):
            middle_time = data['middle_index'] / data['fps']
            middle_counts = data.get('middle_frame_counts', {})
            
            if not middle_counts:
                return f"At the middle of the video ({middle_time:.2f}s): No vehicles detected."
                
            response = f"At the middle of the video ({middle_time:.2f}s): "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in middle_counts.items()])
            return response
            
        # Duration question
        if any(w in filtered_tokens for w in ['duration', 'long', 'length']):
            return f"Video duration: {self.total_frames/self.fps:.2f} seconds."
            
        # Types of vehicles
        if any(w in filtered_tokens for w in ['types', 'kind', 'kinds']) and any(w in filtered_tokens for w in ['vehicle', 'vehicles']):
            return f"Types of vehicles detected: {', '.join(data['vehicle_counts'].keys())}."
            
        # Emergency vehicles
        if any(w in filtered_tokens for w in ['emergency', 'ambulance', 'police']):
            if data['emergency_alerts']:
                return f"Emergency vehicle activity: {'; '.join(data['emergency_alerts'])}"
            else:
                return "No emergency vehicles were detected in the video."
                
        # Descriptive questions
        if any(w in filtered_tokens for w in ['describe', 'description', 'summarize', 'summary', 'explain']):
            # Use language model for creative descriptions
            description_prompt = (
                f"Context:\n{self.context}\n"
                f"Instruction: Describe the traffic scene based on this data in detail. "
                f"Include information about vehicle types, traffic flow, congestion levels, and any notable events."
            )
            description = self.generate_creative_text(description_prompt)
            return self.enhance_text(description)
            
        # Fall back to context info
        return f"Here's what I know about the video:\n{self.context}"
        
    def run_chat_loop(self):
        """Run an interactive chat loop to answer questions about the video."""
        print("\nVideo processing complete. Ask any question about the video (type 'exit' to quit).")
        
        # Generate suggestions based on detected vehicles
        suggestions = [
            "How many vehicles were detected in total?",
            "What types of vehicles were detected?",
            "What was the peak traffic time?",
            "How many unique vehicles passed from 2 to 5 seconds?",
            "Describe the traffic in the video.",
            "What was happening in the middle of the video?"
        ]
        
        # Add vehicle-specific suggestions
        for vehicle_type in self.vehicle_counts.keys():
            suggestions.append(f"How many {vehicle_type}s were detected?")
            
        # Add emergency vehicle suggestions if applicable
        if self.emergency_alerts:
            suggestions.append("Were there any emergency vehicles?")
            
        print("Suggested questions:")
        for i, suggestion in enumerate(suggestions[:8], 1):  # Limit to 8 suggestions
            print(f"{i}. {suggestion}")
            
        while True:
            try:
                user_query = input("\nAsk a question (or type 'exit' to quit): ")
                if user_query.lower().strip() == "exit":
                    print("Goodbye.")
                    break
                    
                response = self.answer_question(user_query)
                print("\nAnswer:")
                print(response)
                
            except KeyboardInterrupt:
                print("\nChat interrupted. Type 'exit' to quit or continue.")
            except Exception as e:
                print(f"Error in chat loop: {e}")

def main():
    model_path = "/kaggle/input/testing/yolov8x.pt"  # Update with your model path
    video_path = "/kaggle/input/testing/test_30_Multi - Made with Clipchamp.mp4"  # Update with your video path
    
    analyzer = TrafficAnalyzer(model_path, video_path)
    if analyzer.process_video():
        analyzer.run_chat_loop()
    else:
        print("Failed to process video.")
    
    # Cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

if __name__ == "__main__":
    main()


YOLO model loaded successfully.


config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Device set to use cuda:0


Flan-T5-Base loaded for text generation.


config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cuda:0


BART-Large-CNN loaded for summarization.
Processing video...
[31m[1mrequirements:[0m Ultralytics requirement ['lap>=0.5.12'] not found, attempting AutoUpdate...
Collecting lap>=0.5.12
  Downloading lap-0.5.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)
Downloading lap-0.5.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 27.0 MB/s eta 0:00:00
Installing collected packages: lap
Successfully installed lap-0.5.12

[31m[1mrequirements:[0m AutoUpdate success ✅ 3.6s, installed 1 package: ['lap>=0.5.12']
[31m[1mrequirements:[0m ⚠️ [1mRestart runtime or rerun command for updates to take effect[0m


0: 480x640 9 cars, 68.0ms
Speed: 3.8ms preprocess, 68.0ms inference, 257.0ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 59.8ms
Speed: 2.2ms preprocess, 59.8ms inference, 1.4ms pos


Ask a question (or type 'exit' to quit):  How many vehicles were detected in total?



Answer:
Total unique vehicles detected: 64.



Ask a question (or type 'exit' to quit):  What types of vehicles were detected?



Answer:
Types of vehicles detected: car, truck, bus.



Ask a question (or type 'exit' to quit):  What was the peak traffic time?



Answer:
Peak traffic occurred at 27.73 seconds with 12 vehicles in frame.



Ask a question (or type 'exit' to quit):  How many unique vehicles passed from 2 to 5 seconds?



Answer:
From 2.0s to 5.0s, 16 unique vehicles passed: 15 car(s), 1 truck(s)



Ask a question (or type 'exit' to quit):  Describe the traffic in the video.


Your max_length is set to 58, but your input_length is only 13. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=6)



Answer:
The video shows 12 vehicles at 27.73 seconds.



Ask a question (or type 'exit' to quit):  How many cars were detected?



Answer:
51 car(s) detected in total. Average of 6.29 per frame. Average speed: 87.65 pixels/second.



Ask a question (or type 'exit' to quit):  How many trucks were detected?



Answer:
9 truck(s) detected in total. Average of 0.82 per frame. Average speed: 111.03 pixels/second.
