In [1]:
!pip install opencv-python numpy ultralytics transformers pillow matplotlib seaborn pandas torch reportlab

Collecting ultralytics
  Downloading ultralytics-8.3.111-py3-none-any.whl.metadata (37 kB)
Collecting reportlab
  Downloading reportlab-4.4.0-py3-none-any.whl.metadata (1.8 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-

In [None]:
import cv2
import numpy as np
from ultralytics import YOLO
from transformers import pipeline, logging
import re
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import torch
import gc
from reportlab.lib.pagesizes import letter, A4
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as ReportLabImage, PageBreak, Flowable, ListFlowable, ListItem
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.lib.enums import TA_JUSTIFY, TA_LEFT, TA_CENTER, TA_RIGHT
import json
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from datetime import datetime
import warnings
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import io

# Suppress warnings
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

# Disable progress bars
import datasets
datasets.disable_progress_bar()

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords', quiet=True)

# Custom flowable for horizontal line
class HorizontalLine(Flowable):
    def __init__(self, width, thickness=1, color=colors.black):
        Flowable.__init__(self)
        self.width = width
        self.thickness = thickness
        self.color = color

    def draw(self):
        self.canv.setStrokeColor(self.color)
        self.canv.setLineWidth(self.thickness)
        self.canv.line(0, 0, self.width, 0)

class TrafficAnalyzer:
    def __init__(self, model_path, video_path):
        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        
        # Verify input files
        self.model_path = model_path
        self.video_path = video_path
        if not os.path.exists(model_path) or not os.path.exists(video_path):
            print("Error: Model or video file not found.")
            exit()
            
        # Initialize variables
        self.selected_frame = None
        self.peak_frame = None
        self.vehicle_counts = {}
        self.track_id_to_class = {}
        self.frame_vehicle_counts = []
        self.max_vehicles = 0
        self.max_frame_index = 0
        self.emergency_alerts = []
        self.track_positions = defaultdict(list)
        self.average_speeds = {}
        self.congestion_indices = []
        
        # Data processing variables
        self.processed = False
        self.context = ""
        self.average_counts = {}
        self.total_frames = 0
        self.fps = 0
        self.max_time_sec = 0
        self.middle_index = 0
        
        # Track history for unique vehicle counting
        self.track_history = defaultdict(lambda: {
            'first_seen': float('inf'),
            'last_seen': -1,
            'vehicle_type': None,
            'positions': []
        })
        
        # Load YOLO model
        try:
            self.model = YOLO(model_path)
            print("YOLO model loaded successfully.")
        except Exception as e:
            print(f"Error loading YOLO model: {e}")
            exit()
            
        # Initialize NLP components
        self.load_nlp_models()
    
    def load_nlp_models(self):
        try:
            self.text_generator = pipeline(
                "text2text-generation", 
                model="google/flan-t5-base", 
                device=0 if torch.cuda.is_available() else -1,
                model_kwargs={"cache_dir": "./model_cache"},
                truncation=True
            )
            print("Flan-T5-Base loaded for text generation.")
        except Exception as e:
            print(f"Warning: Could not load Flan-T5-Base: {e}")
            self.text_generator = None
            
        try:
            # Configure summarizer with appropriate parameters to avoid warnings
            self.summarizer = pipeline(
                "summarization", 
                model="facebook/bart-large-cnn", 
                device=0 if torch.cuda.is_available() else -1,
                model_kwargs={"cache_dir": "./model_cache"},
                truncation=True,
                framework="pt",
                # Disable progress bar
                disable_tqdm=True
            )
            print("BART-Large-CNN loaded for summarization.")
        except Exception as e:
            print(f"Warning: Could not load BART-Large-CNN: {e}")
            self.summarizer = None
    
    def process_video(self):
        # Video setup
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video at {self.video_path}")
            return False
            
        self.total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if self.total_frames == 0:
            print("Error: Video has no frames.")
            return False
            
        self.middle_index = self.total_frames // 2
        self.fps = cap.get(cv2.CAP_PROP_FPS)
        frame_index = 0
        
        # Process each frame
        print("Processing video...")
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break
                
            frame = cv2.resize(frame, (700, 500))
            results = self.model.track(frame, persist=True)
            
            if results[0].boxes is None or len(results[0].boxes) == 0:
                self.frame_vehicle_counts.append({})
                frame_index += 1
                continue
                
            boxes = results[0].boxes.xyxy.cpu().numpy()
            confidences = results[0].boxes.conf.cpu().numpy()
            classes = results[0].boxes.cls.cpu().numpy().astype(int)
            track_ids = results[0].boxes.id.cpu().numpy() if results[0].boxes.id is not None else []
            
            current_frame_counts = defaultdict(int)
            current_time = frame_index / self.fps
            
            for conf, cls, track_id, box in zip(confidences, classes, track_ids, boxes):
                if conf < 0.5:
                    continue
                label = results[0].names[cls]
                
                # Only count vehicles (not traffic lights, people, etc.)
                if label in ['car', 'truck', 'bus', 'motorcycle', 'bicycle', 'ambulance', 'police car']:
                    # Update track history for unique vehicle counting
                    if track_id not in self.track_history:
                        self.track_history[track_id]['vehicle_type'] = label
                        self.vehicle_counts[label] = self.vehicle_counts.get(label, 0) + 1
                    
                    self.track_history[track_id]['first_seen'] = min(self.track_history[track_id]['first_seen'], current_time)
                    self.track_history[track_id]['last_seen'] = max(self.track_history[track_id]['last_seen'], current_time)
                    
                    # Update track ID to class mapping
                    self.track_id_to_class[track_id] = label
                    
                    # Count for current frame
                    current_frame_counts[label] += 1
                    
                    # Store position for speed calculation
                    center_x = (box[0] + box[2]) / 2
                    center_y = (box[1] + box[3]) / 2
                    self.track_positions[track_id].append((frame_index, center_x, center_y))
            
            self.frame_vehicle_counts.append(dict(current_frame_counts))
            
            # Check for emergency vehicles
            if current_frame_counts.get("ambulance", 0) > 0:
                alert = f"Ambulance activity at {frame_index/self.fps:.2f}s: {current_frame_counts['ambulance']} ambulance(s)"
                self.emergency_alerts.append(alert)
            if current_frame_counts.get("police car", 0) > 0:
                alert = f"Police car activity at {frame_index/self.fps:.2f}s: {current_frame_counts['police car']} police car(s)"
                self.emergency_alerts.append(alert)
            
            # Calculate congestion
            total_in_frame = sum(current_frame_counts.values())
            congestion_index = total_in_frame / 5.0  # Normalize to reasonable scale
            self.congestion_indices.append(congestion_index)
            
            # Track peak traffic
            if total_in_frame > self.max_vehicles:
                self.max_vehicles = total_in_frame
                self.max_frame_index = frame_index
                self.peak_frame = results[0].plot()
            
            # Save middle frame
            if frame_index == self.middle_index:
                self.selected_frame = results[0].plot()
            
            frame_index += 1
        
        cap.release()
        cv2.destroyAllWindows()
        print(f"Video processing complete: {self.total_frames} frames processed.")
        
        # Calculate average speeds
        print("Calculating average speeds...")
        self.calculate_average_speeds()
        
        # Save annotated frames
        self.save_frames()
        
        # Calculate statistics and create context
        self.calculate_statistics()
        
        # Save data to JSON for accurate retrieval later
        with open('traffic_data.json', 'w') as f:
            json.dump({
                'vehicle_counts': self.vehicle_counts,
                'average_counts': {k: float(v) for k, v in self.average_counts.items()},
                'average_speeds': {k: float(v) for k, v in self.average_speeds.items()},
                'max_vehicles': self.max_vehicles,
                'max_time_sec': self.max_time_sec,
                'total_frames': self.total_frames,
                'fps': float(self.fps),
                'emergency_alerts': self.emergency_alerts,
                'middle_index': self.middle_index,
                'middle_frame_counts': self.frame_vehicle_counts[self.middle_index] if self.middle_index < len(self.frame_vehicle_counts) else {}
            }, f)
        
        # Export to CSV
        self.export_to_csv()
        
        # Generate visualizations
        self.generate_visualizations()
        
        # Generate report
        self.generate_enhanced_report()
        
        self.processed = True
        return True
        
    def calculate_average_speeds(self):
        for track_id, positions in self.track_positions.items():
            label = self.track_id_to_class[track_id]
            if len(positions) < 2:
                continue
            total_speed = 0
            count = 0
            for i in range(1, len(positions)):
                frame_diff = positions[i][0] - positions[i-1][0]
                if frame_diff == 0:
                    continue
                dx = positions[i][1] - positions[i-1][1]
                dy = positions[i][2] - positions[i-1][2]
                distance = np.sqrt(dx**2 + dy**2)
                time = frame_diff / self.fps
                speed = distance / time
                total_speed += speed
                count += 1
            if count > 0:
                if label not in self.average_speeds:
                    self.average_speeds[label] = 0
                    self.average_speeds[f"{label}_count"] = 0
                self.average_speeds[label] += total_speed / count
                self.average_speeds[f"{label}_count"] += 1
        
        for label in self.vehicle_counts.keys():
            count_key = f"{label}_count"
            if count_key in self.average_speeds:
                self.average_speeds[label] = self.average_speeds[label] / self.average_speeds[count_key]
                del self.average_speeds[count_key]
    
    def save_frames(self):
        try:
            if self.selected_frame is not None:
                cv2.imwrite("middle_frame.jpg", self.selected_frame)
                print("Middle frame saved as 'middle_frame.jpg'.")
            if self.peak_frame is not None:
                cv2.imwrite("peak_frame.jpg", self.peak_frame)
                print("Peak frame saved as 'peak_frame.jpg'.")
        except Exception as e:
            print(f"Error saving frames: {e}")
    
    def calculate_statistics(self):
        # Calculate averages and max time
        for vtype in self.vehicle_counts.keys():
            total = sum(frame_counts.get(vtype, 0) for frame_counts in self.frame_vehicle_counts)
            self.average_counts[vtype] = total / len(self.frame_vehicle_counts) if self.frame_vehicle_counts else 0
        
        self.max_time_sec = self.max_frame_index / self.fps if self.fps > 0 else 0
        average_congestion = np.mean(self.congestion_indices) if self.congestion_indices else 0
        
        # First section, middle section, last section vehicle counts
        first_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[:len(self.frame_vehicle_counts)//4])
        middle_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[len(self.frame_vehicle_counts)//4:3*len(self.frame_vehicle_counts)//4])
        last_section = sum(sum(fc.values()) for fc in self.frame_vehicle_counts[3*len(self.frame_vehicle_counts)//4:])
        
        # Create context string
        self.context = (
            f"- Video duration: {(self.total_frames / self.fps):.2f} seconds\n"
            f"- Unique vehicles: {', '.join([f'{v}: {c}' for v, c in self.vehicle_counts.items()])} (total: {sum(self.vehicle_counts.values())})\n"
            f"- Average vehicles per frame: {', '.join([f'{v}: {c:.2f}' for v, c in self.average_counts.items()])}\n"
            f"- Average speeds (pixels/s): {', '.join([f'{v}: {s:.2f}' for v, s in self.average_speeds.items()])}\n"
            f"- Peak traffic: {self.max_vehicles} vehicles at {self.max_time_sec:.2f} seconds\n"
            f"- Middle frame (at {(self.middle_index / self.fps):.2f} seconds): {', '.join([f'{v}: {c}' for v, c in self.frame_vehicle_counts[self.middle_index].items()]) if self.middle_index < len(self.frame_vehicle_counts) else 'no vehicles'}\n"
            f"- Average congestion index: {average_congestion:.2f} (0=low, 1=moderate, >2=high)\n"
            f"- Emergency alerts: {'; '.join(self.emergency_alerts) if self.emergency_alerts else 'None'}\n"
            f"- Temporal trends: First 25%: {first_section} vehicles, "
            f"Middle 50%: {middle_section} vehicles, "
            f"Last 25%: {last_section} vehicles"
        )
        print("\n--- Context Generated ---\n")
        print(self.context)
    
    def export_to_csv(self):
        print("Exporting data to CSV...")
        times = [i / self.fps for i in range(len(self.frame_vehicle_counts))]
        total_vehicles_per_frame = [sum(frame_counts.values()) for frame_counts in self.frame_vehicle_counts]
        
        csv_data = {
            "Frame": list(range(len(self.frame_vehicle_counts))),
            "Time (s)": times,
            "Total Vehicles": total_vehicles_per_frame,
            "Congestion Index": self.congestion_indices
        }
        
        for vtype in self.vehicle_counts.keys():
            csv_data[vtype] = [frame_counts.get(vtype, 0) for frame_counts in self.frame_vehicle_counts]
        
        df = pd.DataFrame(csv_data)
        try:
            df.to_csv("traffic_data.csv", index=False)
            print("Traffic data exported to 'traffic_data.csv'.")
        except Exception as e:
            print(f"Error exporting CSV: {e}")
    
    def generate_visualizations(self):
        print("Generating visualizations...")
        
        # 1. Traffic Density Heatmap
        self.generate_heatmap()
        
        # 2. Vehicle Count Timeline
        self.generate_timeline_graph()
        
        # 3. Vehicle Type Distribution Pie Chart
        self.generate_vehicle_distribution_chart()
        
        # 4. Congestion Index Over Time
        self.generate_congestion_graph()
        
        # 5. Vehicle Speed Comparison
        self.generate_speed_comparison()

    def generate_heatmap(self):
        print("Generating traffic density heatmap...")
        times = [i / self.fps for i in range(len(self.frame_vehicle_counts))]
        total_vehicles_per_frame = [sum(frame_counts.values()) for frame_counts in self.frame_vehicle_counts]
        
        plt.figure(figsize=(10, 4))
        sns.heatmap([total_vehicles_per_frame], cmap="YlOrRd", xticklabels=50, cbar_kws={'label': 'Vehicle Count'})
        plt.xlabel("Time (seconds)")
        plt.ylabel("Density")
        plt.title("Traffic Density Heatmap")
        plt.xticks(ticks=np.linspace(0, len(times)-1, 5), labels=[f"{t:.1f}" for t in np.linspace(0, max(times), 5)])
        try:
            plt.savefig("heatmap.png", dpi=300, bbox_inches='tight')
            plt.close()
            print("Heatmap saved as 'heatmap.png'.")
        except Exception as e:
            print(f"Error saving heatmap: {e}")
    
    def generate_timeline_graph(self):
        print("Generating vehicle count timeline...")
        times = [i / self.fps for i in range(len(self.frame_vehicle_counts))]
        total_vehicles_per_frame = [sum(frame_counts.values()) for frame_counts in self.frame_vehicle_counts]
        
        # Create a figure with a specific size
        plt.figure(figsize=(12, 6))
        
        # Plot the data with a thicker line and markers
        plt.plot(times, total_vehicles_per_frame, linewidth=2, marker='o', markersize=3, alpha=0.7)
        
        # Add a trend line (moving average)
        window_size = max(1, len(total_vehicles_per_frame) // 20)  # 5% of total frames
        if window_size > 1:
            moving_avg = np.convolve(total_vehicles_per_frame, np.ones(window_size)/window_size, mode='valid')
            plt.plot(times[window_size-1:], moving_avg, 'r-', linewidth=3, label=f'Trend (Moving Avg, window={window_size})')
        
        # Highlight peak traffic point
        peak_idx = total_vehicles_per_frame.index(max(total_vehicles_per_frame))
        plt.plot(times[peak_idx], total_vehicles_per_frame[peak_idx], 'ro', markersize=10, label='Peak Traffic')
        
        # Add labels and title
        plt.xlabel('Time (seconds)')
        plt.ylabel('Number of Vehicles')
        plt.title('Vehicle Count Over Time')
        plt.grid(True, alpha=0.3)
        plt.legend()
        
        # Add annotations for key points
        plt.annotate(f'Peak: {total_vehicles_per_frame[peak_idx]} vehicles', 
                    xy=(times[peak_idx], total_vehicles_per_frame[peak_idx]),
                    xytext=(times[peak_idx]+0.5, total_vehicles_per_frame[peak_idx]+1),
                    arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8))
        
        # Save the figure
        try:
            plt.savefig("vehicle_timeline.png", dpi=300, bbox_inches='tight')
            plt.close()
            print("Vehicle timeline saved as 'vehicle_timeline.png'.")
        except Exception as e:
            print(f"Error saving vehicle timeline: {e}")
    
    def generate_vehicle_distribution_chart(self):
        print("Generating vehicle type distribution chart...")
        
        # Create a pie chart of vehicle types
        plt.figure(figsize=(10, 8))
        
        # Extract data
        labels = list(self.vehicle_counts.keys())
        sizes = list(self.vehicle_counts.values())
        
        # Add percentage to labels
        total = sum(sizes)
        labels = [f'{l} ({s}, {s/total*100:.1f}%)' for l, s in zip(labels, sizes)]
        
        # Create explode list to emphasize the largest segment
        explode = [0.1 if s == max(sizes) else 0 for s in sizes]
        
        # Create pie chart with shadow and explosion
        plt.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%',
                shadow=True, startangle=90, textprops={'fontsize': 12})
        plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
        plt.title('Distribution of Vehicle Types', fontsize=16)
        
        # Save the figure
        try:
            plt.savefig("vehicle_distribution.png", dpi=300, bbox_inches='tight')
            plt.close()
            print("Vehicle distribution chart saved as 'vehicle_distribution.png'.")
        except Exception as e:
            print(f"Error saving vehicle distribution chart: {e}")
    
    def generate_congestion_graph(self):
        print("Generating congestion index graph...")
        times = [i / self.fps for i in range(len(self.congestion_indices))]
        
        plt.figure(figsize=(12, 6))
        plt.plot(times, self.congestion_indices, linewidth=2, color='purple')
        
        # Add horizontal lines for congestion levels
        plt.axhline(y=0.5, color='green', linestyle='--', alpha=0.7, label='Low Congestion')
        plt.axhline(y=1.0, color='orange', linestyle='--', alpha=0.7, label='Moderate Congestion')
        plt.axhline(y=2.0, color='red', linestyle='--', alpha=0.7, label='High Congestion')
        
        # Add labels and title
        plt.xlabel('Time (seconds)')
        plt.ylabel('Congestion Index')
        plt.title('Traffic Congestion Over Time')
        plt.grid(True, alpha=0.3)
        plt.legend()
        
        # Save the figure
        try:
            plt.savefig("congestion_index.png", dpi=300, bbox_inches='tight')
            plt.close()
            print("Congestion index graph saved as 'congestion_index.png'.")
        except Exception as e:
            print(f"Error saving congestion index graph: {e}")
    
    def generate_speed_comparison(self):
        print("Generating vehicle speed comparison...")
        
        if not self.average_speeds:
            print("No speed data available.")
            return
        
        # Create bar chart of average speeds
        plt.figure(figsize=(10, 6))
        
        # Extract data
        vehicle_types = list(self.average_speeds.keys())
        speeds = list(self.average_speeds.values())
        
        # Create bar chart
        bars = plt.bar(vehicle_types, speeds, color='skyblue')
        
        # Add value labels on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{height:.1f}',
                    ha='center', va='bottom', rotation=0)
        
        # Add labels and title
        plt.xlabel('Vehicle Type')
        plt.ylabel('Average Speed (pixels/second)')
        plt.title('Average Speed by Vehicle Type')
        plt.grid(True, alpha=0.3, axis='y')
        plt.xticks(rotation=45)
        
        # Save the figure
        try:
            plt.savefig("speed_comparison.png", dpi=300, bbox_inches='tight')
            plt.close()
            print("Speed comparison chart saved as 'speed_comparison.png'.")
        except Exception as e:
            print(f"Error saving speed comparison chart: {e}")
    
    def generate_creative_text(self, prompt):
        if self.text_generator:
            try:
                result = self.text_generator(
                    prompt, 
                    max_new_tokens=300,
                    do_sample=True,
                    temperature=0.7,
                    no_repeat_ngram_size=3
                )
                return result[0]['generated_text'].strip()
            except Exception as e:
                print(f"Error generating text: {e}")
                return "Could not generate creative text."
        else:
            return "Creative text generation not available (NLP models not loaded)."
    
    def enhance_text(self, text):
        if self.summarizer:
            try:
                # Calculate appropriate max_length based on input length
                input_length = len(text.split())
                # For summarization, output should be shorter than input
                max_length = min(input_length - 1, 300)  # Cap at 300 tokens
                min_length = min(max(30, input_length // 4), 150)  # Between 30 and 150
                
                enhanced = self.summarizer(
                    text, 
                    max_length=max_length, 
                    min_length=min_length, 
                    do_sample=True,
                    truncation=True
                )[0]['summary_text']
                return enhanced
            except Exception as e:
                print(f"Error enhancing text: {e}")
                return text
        else:
            return text
    
    def generate_descriptive_answer(self, question):
        """
        Generate a concise, factual description of the traffic without using summarization pipeline
        """
        # Get frame counts at middle of video for a representative snapshot
        middle_time = self.middle_index / self.fps if self.fps > 0 else 0
        middle_frame_counts = self.frame_vehicle_counts[self.middle_index] if self.middle_index < len(self.frame_vehicle_counts) else {}
        
        # Get total unique vehicles
        total_unique = sum(self.vehicle_counts.values())
        
        # Construct a simple description
        description = f"The video shows traffic with {total_unique} unique vehicles over {(self.total_frames / self.fps):.2f} seconds. "
        
        # Add current frame info
        if middle_frame_counts:
            description += f"At {middle_time:.2f} seconds, there are "
            description += ", ".join([f"{count} {vtype}(s)" for vtype, count in middle_frame_counts.items()])
            description += ". "
        
        # Add peak traffic info
        description += f"Peak traffic occurs at {self.max_time_sec:.2f} seconds with {self.max_vehicles} vehicles. "
        
        # Add vehicle type distribution
        if self.vehicle_counts:
            description += f"Vehicle distribution: "
            description += ", ".join([f"{count} {vtype}(s)" for vtype, count in self.vehicle_counts.items()])
            description += "."
        
        return description
    
    def generate_enhanced_report(self):
        print("Generating enhanced PDF report...")
        
        # Generate text content for different sections
        flow_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate a detailed traffic flow analysis starting with 'The traffic flow analysis reveals...'. "
            "Include patterns of congestion, peak times, and vehicle distribution."
        )
        flow_text = self.generate_creative_text(flow_prompt)
        
        behavior_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate a detailed description of vehicle behavior starting with 'Vehicle behavior analysis shows...'. "
            "Include speed patterns, types of vehicles, and any unusual events."
        )
        behavior_text = self.generate_creative_text(behavior_prompt)
        
        recommendations_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate 3-4 traffic management recommendations based on this data, "
            "starting with 'Based on the analysis, we recommend...'"
        )
        recommendations_text = self.generate_creative_text(recommendations_prompt)
        
        insights_prompt = (
            f"Context:\n{self.context}\n"
            "Instruction: Generate 5-6 key insights from the traffic data, focusing on patterns, anomalies, and notable observations. "
            "Format as bullet points."
        )
        insights_text = self.generate_creative_text(insights_prompt)
        
        # Create PDF
        try:
            # Set up the document
            doc = SimpleDocTemplate("traffic_analysis_report.pdf", pagesize=A4)
            styles = getSampleStyleSheet()
            
            # Create custom styles
            title_style = ParagraphStyle(
                'Title',
                parent=styles['Title'],
                fontSize=24,
                spaceAfter=12,
                textColor=colors.darkblue
            )
            
            heading1_style = ParagraphStyle(
                'Heading1',
                parent=styles['Heading1'],
                fontSize=18,
                spaceAfter=10,
                textColor=colors.darkblue
            )
            
            heading2_style = ParagraphStyle(
                'Heading2',
                parent=styles['Heading2'],
                fontSize=14,
                spaceAfter=8,
                textColor=colors.darkblue
            )
            
            normal_style = ParagraphStyle(
                'Normal',
                parent=styles['Normal'],
                fontSize=10,
                leading=14,
                spaceAfter=8,
                alignment=TA_JUSTIFY
            )
            
            caption_style = ParagraphStyle(
                'Caption',
                parent=styles['Normal'],
                fontSize=9,
                leading=12,
                alignment=TA_CENTER,
                textColor=colors.darkslategray
            )
            
            # Create document elements
            elements = []
            
            # Cover page
            elements.append(Paragraph("Traffic Analysis Report", title_style))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Add current date
            current_date = datetime.now().strftime("%B %d, %Y")
            elements.append(Paragraph(f"Generated on: {current_date}", normal_style))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Add video information
            video_name = os.path.basename(self.video_path)
            elements.append(Paragraph(f"Analysis of: {video_name}", normal_style))
            elements.append(Paragraph(f"Duration: {self.total_frames / self.fps:.2f} seconds", normal_style))
            elements.append(Paragraph(f"Total unique vehicles detected: {sum(self.vehicle_counts.values())}", normal_style))
            
            # Add peak frame image
            if os.path.exists("peak_frame.jpg"):
                elements.append(Spacer(1, 0.3 * inch))
                elements.append(Paragraph("Peak Traffic Frame:", heading2_style))
                img = ReportLabImage("peak_frame.jpg", width=6*inch, height=4*inch)
                elements.append(img)
                elements.append(Paragraph(f"Peak traffic occurred at {self.max_time_sec:.2f} seconds with {self.max_vehicles} vehicles", caption_style))
            
            # Add page break
            elements.append(PageBreak())
            
            # Table of Contents
            elements.append(Paragraph("Table of Contents", heading1_style))
            elements.append(Spacer(1, 0.2 * inch))
            
            toc_data = [
                ["1. Executive Summary", "3"],
                ["2. Traffic Flow Analysis", "4"],
                ["3. Vehicle Behavior", "5"],
                ["4. Key Insights", "6"],
                ["5. Visualizations", "7"],
                ["6. Statistical Summary", "9"],
                ["7. Recommendations", "10"]
            ]
            
            toc_table = Table(toc_data, colWidths=[5*inch, 0.5*inch])
            toc_table.setStyle(TableStyle([
                ('FONT', (0, 0), (-1, -1), 'Helvetica'),
                ('FONTSIZE', (0, 0), (-1, -1), 11),
                ('BOTTOMPADDING', (0, 0), (-1, -1), 10),
                ('RIGHTPADDING', (0, 0), (0, -1), 20),
                ('ALIGN', (1, 0), (1, -1), 'RIGHT'),
            ]))
            elements.append(toc_table)
            elements.append(PageBreak())
            
            # Executive Summary
            elements.append(Paragraph("1. Executive Summary", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            summary_text = self.generate_descriptive_answer("Describe the traffic in the video")
            elements.append(Paragraph(summary_text, normal_style))
            
            # Add middle frame image
            if os.path.exists("middle_frame.jpg"):
                elements.append(Spacer(1, 0.3 * inch))
                elements.append(Paragraph("Representative Frame (Middle of Video):", heading2_style))
                img = ReportLabImage("middle_frame.jpg", width=6*inch, height=4*inch)
                elements.append(img)
                middle_time = self.middle_index / self.fps
                elements.append(Paragraph(f"Frame captured at {middle_time:.2f} seconds", caption_style))
            
            elements.append(PageBreak())
            
            # Traffic Flow Analysis
            elements.append(Paragraph("2. Traffic Flow Analysis", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            elements.append(Paragraph(flow_text, normal_style))
            
            # Add timeline graph
            if os.path.exists("vehicle_timeline.png"):
                elements.append(Spacer(1, 0.3 * inch))
                elements.append(Paragraph("Vehicle Count Over Time:", heading2_style))
                img = ReportLabImage("vehicle_timeline.png", width=6*inch, height=3*inch)
                elements.append(img)
                elements.append(Paragraph("This graph shows how the number of vehicles changes throughout the video duration", caption_style))
            
            elements.append(PageBreak())
            
            # Vehicle Behavior
            elements.append(Paragraph("3. Vehicle Behavior", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            elements.append(Paragraph(behavior_text, normal_style))
            
            # Add vehicle distribution chart
            if os.path.exists("vehicle_distribution.png"):
                elements.append(Spacer(1, 0.3 * inch))
                elements.append(Paragraph("Vehicle Type Distribution:", heading2_style))
                img = ReportLabImage("vehicle_distribution.png", width=5*inch, height=4*inch)
                elements.append(img)
                elements.append(Paragraph("Distribution of different vehicle types detected in the video", caption_style))
            
            elements.append(PageBreak())
            
            # Key Insights
            elements.append(Paragraph("4. Key Insights", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Process insights text into bullet points
            insights_lines = insights_text.split('\n')
            for line in insights_lines:
                line = line.strip()
                if line:
                    if line.startswith('-'):
                        line = line[1:].strip()
                    elements.append(Paragraph(f"• {line}", normal_style))
                    elements.append(Spacer(1, 0.1 * inch))
            
            # Add congestion graph
            if os.path.exists("congestion_index.png"):
                elements.append(Spacer(1, 0.3 * inch))
                elements.append(Paragraph("Traffic Congestion Over Time:", heading2_style))
                img = ReportLabImage("congestion_index.png", width=6*inch, height=3*inch)
                elements.append(img)
                elements.append(Paragraph("This graph shows how traffic congestion varies throughout the video", caption_style))
            
            elements.append(PageBreak())
            
            # Visualizations
            elements.append(Paragraph("5. Visualizations", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Add heatmap
            if os.path.exists("heatmap.png"):
                elements.append(Paragraph("Traffic Density Heatmap:", heading2_style))
                img = ReportLabImage("heatmap.png", width=6*inch, height=2.5*inch)
                elements.append(img)
                elements.append(Paragraph("Heatmap showing the intensity of traffic over time", caption_style))
                elements.append(Spacer(1, 0.3 * inch))
            
            # Add speed comparison
            if os.path.exists("speed_comparison.png"):
                elements.append(Paragraph("Average Speed by Vehicle Type:", heading2_style))
                img = ReportLabImage("speed_comparison.png", width=6*inch, height=3*inch)
                elements.append(img)
                elements.append(Paragraph("Comparison of average speeds for different vehicle types", caption_style))
            
            elements.append(PageBreak())
            
            # Statistical Summary
            elements.append(Paragraph("6. Statistical Summary", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            # Vehicle counts table
            elements.append(Paragraph("Vehicle Counts:", heading2_style))
            
            vehicle_data = [['Vehicle Type', 'Unique Count', 'Avg. per Frame', 'Avg. Speed (px/s)']]
            for vtype in self.vehicle_counts.keys():
                vehicle_data.append([
                    vtype, 
                    str(self.vehicle_counts.get(vtype, 0)), 
                    f"{self.average_counts.get(vtype, 0):.2f}", 
                    f"{self.average_speeds.get(vtype, 0):.2f}"
                ])
            
            # Add total row
            vehicle_data.append([
                'Total', 
                str(sum(self.vehicle_counts.values())), 
                f"{sum(self.average_counts.values()):.2f}", 
                'N/A'
            ])
            
            vehicle_table = Table(vehicle_data, colWidths=[1.5*inch, 1.2*inch, 1.5*inch, 1.5*inch])
            vehicle_table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.darkblue),
                ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
                ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
                ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
                ('FONTSIZE', (0, 0), (-1, 0), 11),
                ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
                ('BACKGROUND', (0, -1), (-1, -1), colors.lightgrey),
                ('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
                ('GRID', (0, 0), (-1, -1), 1, colors.black),
                ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ]))
            elements.append(vehicle_table)
            elements.append(Spacer(1, 0.3 * inch))
            
            # Key metrics table
            elements.append(Paragraph("Key Metrics:", heading2_style))
            
            metrics_data = [
                ['Metric', 'Value'],
                ['Video Duration', f"{self.total_frames / self.fps:.2f} seconds"],
                ['Peak Traffic Time', f"{self.max_time_sec:.2f} seconds"],
                ['Peak Vehicle Count', f"{self.max_vehicles} vehicles"],
                ['Average Congestion Index', f"{np.mean(self.congestion_indices):.2f}/5.0"],
                ['Emergency Alerts', f"{len(self.emergency_alerts)}"]
            ]
            
            metrics_table = Table(metrics_data, colWidths=[2.5*inch, 3.5*inch])
            metrics_table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.darkblue),
                ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
                ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
                ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
                ('FONTSIZE', (0, 0), (-1, 0), 11),
                ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
                ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
                ('GRID', (0, 0), (-1, -1), 1, colors.black),
                ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ]))
            elements.append(metrics_table)
            
            elements.append(PageBreak())
            
            # Recommendations
            elements.append(Paragraph("7. Recommendations", heading1_style))
            elements.append(HorizontalLine(450))
            elements.append(Spacer(1, 0.2 * inch))
            
            elements.append(Paragraph(recommendations_text, normal_style))
            
            # Build the document
            doc.build(elements)
            print("Enhanced PDF report generated as 'traffic_analysis_report.pdf'.")
        except Exception as e:
            print(f"Error generating enhanced PDF report: {e}")
    
    def get_unique_vehicles_in_timerange(self, start_time, end_time):
        """
        Count unique vehicles that appeared between start_time and end_time.
        """
        unique_vehicles = defaultdict(int)
        
        for track_id, data in self.track_history.items():
            # Check if the vehicle was present during the time range
            if data['first_seen'] <= end_time and data['last_seen'] >= start_time:
                vehicle_type = data['vehicle_type']
                if vehicle_type:
                    unique_vehicles[vehicle_type] += 1
        
        return dict(unique_vehicles)
    
    def answer_question(self, question):
        """
        Rules-based factual question answering from stored data, with creative descriptions
        where appropriate.
        """
        if not self.processed:
            return "Please process the video first before asking questions."
            
        if not question or not isinstance(question, str):
            return "Please ask a valid question."
            
        # Load saved data to ensure accuracy
        try:
            with open('traffic_data.json', 'r') as f:
                data = json.load(f)
        except:
            # Fall back to in-memory data if file not found
            data = {
                'vehicle_counts': self.vehicle_counts,
                'average_counts': self.average_counts,
                'average_speeds': self.average_speeds,
                'max_vehicles': self.max_vehicles,
                'max_time_sec': self.max_time_sec,
                'total_frames': self.total_frames,
                'fps': self.fps,
                'emergency_alerts': self.emergency_alerts,
                'middle_index': self.middle_index,
                'middle_frame_counts': self.frame_vehicle_counts[self.middle_index] if self.middle_index < len(self.frame_vehicle_counts) else {}
            }
            
        # Normalize question
        question = question.lower().strip()
        tokens = word_tokenize(question)
        stop_words = set(stopwords.words('english'))
        filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
        
        # Special handling for "describe the traffic" to avoid summarization issues
        if ("describe" in question and "traffic" in question) or ("summarize" in question and "traffic" in question):
            return self.generate_descriptive_answer(question)
        
        # Time range handling - IMPROVED for unique vehicle counting
        time_range_match = re.search(r"from (\d+(?:\.\d+)?)\s*(?:to|s(?:econd)?s?\s*to)\s*(\d+(?:\.\d+)?)\s*s(?:econd)?s?", question)
        time_match = re.search(r"at (\d+(?:\.\d+)?) seconds?", question)
        
        if time_range_match:
            start_time = float(time_range_match.group(1))
            end_time = float(time_range_match.group(2))
            
            if start_time >= end_time or end_time > self.total_frames / self.fps:
                return f"Invalid time range: {start_time}s to {end_time}s. Video duration is {self.total_frames/self.fps:.2f}s."
                
            # Get unique vehicles in the time range
            unique_vehicles = self.get_unique_vehicles_in_timerange(start_time, end_time)
            total_unique = sum(unique_vehicles.values())
            
            # Generate response
            response = f"From {start_time}s to {end_time}s, {total_unique} unique vehicles passed: "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in unique_vehicles.items()])
            return response
            
        elif time_match:
            time_sec = float(time_match.group(1))
            
            if time_sec > self.total_frames / self.fps:
                return f"Time {time_sec}s exceeds video duration ({self.total_frames/self.fps:.2f}s)."
                
            frame_index = min(int(time_sec * self.fps), self.total_frames - 1)
            frame_counts = self.frame_vehicle_counts[frame_index]
            
            if not frame_counts:
                return f"At {time_sec}s: No vehicles detected."
                
            response = f"At {time_sec}s: "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in frame_counts.items()])
            return response
        
        # Total vehicle questions
        if any(w in filtered_tokens for w in ['total', 'all']) and any(w in filtered_tokens for w in ['vehicle', 'vehicles', 'cars', 'detected']):
            return f"Total unique vehicles detected: {sum(data['vehicle_counts'].values())}."
        
        # Vehicle type questions
        for vehicle_type in data['vehicle_counts'].keys():
            if vehicle_type.lower() in question:
                count = data['vehicle_counts'].get(vehicle_type, 0)
                avg = data['average_counts'].get(vehicle_type, 0)
                speed = data['average_speeds'].get(vehicle_type, 0)
                response = f"{count} {vehicle_type}(s) detected in total. "
                response += f"Average of {avg:.2f} per frame. "
                
                if speed > 0:
                    response += f"Average speed: {speed:.2f} pixels/second."
                return response
                
        # Peak traffic questions
        if any(w in filtered_tokens for w in ['peak', 'busiest', 'maximum']):
            if any(w in filtered_tokens for w in ['time', 'when']):
                return f"Peak traffic occurred at {data['max_time_sec']:.2f} seconds with {data['max_vehicles']} vehicles in frame."
            else:
                return f"Peak traffic was {data['max_vehicles']} vehicles at {data['max_time_sec']:.2f} seconds."
                
        # Middle of video questions
        if any(w in filtered_tokens for w in ['middle', 'mid']):
            middle_time = data['middle_index'] / data['fps']
            middle_counts = data.get('middle_frame_counts', {})
            
            if not middle_counts:
                return f"At the middle of the video ({middle_time:.2f}s): No vehicles detected."
                
            response = f"At the middle of the video ({middle_time:.2f}s): "
            response += ", ".join([f"{count} {vtype}(s)" for vtype, count in middle_counts.items()])
            return response
            
        # Duration question
        if any(w in filtered_tokens for w in ['duration', 'long', 'length']):
            return f"Video duration: {self.total_frames/self.fps:.2f} seconds."
            
        # Types of vehicles
        if any(w in filtered_tokens for w in ['types', 'kind', 'kinds']) and any(w in filtered_tokens for w in ['vehicle', 'vehicles']):
            return f"Types of vehicles detected: {', '.join(data['vehicle_counts'].keys())}."
            
        # Emergency vehicles
        if any(w in filtered_tokens for w in ['emergency', 'ambulance', 'police']):
            if data['emergency_alerts']:
                return f"Emergency vehicle activity: {'; '.join(data['emergency_alerts'])}"
            else:
                return "No emergency vehicles were detected in the video."
                
        # Fall back to context info
        return f"Here's what I know about the video:\n{self.context}"
        
    def run_chat_loop(self):
        """Run an interactive chat loop to answer questions about the video."""
        print("\nVideo processing complete. Ask any question about the video (type 'exit' to quit).")
        
        # Generate suggestions based on detected vehicles
        suggestions = [
            "How many vehicles were detected in total?",
            "What types of vehicles were detected?",
            "What was the peak traffic time?",
            "How many unique vehicles passed from 2 to 5 seconds?",
            "Describe the traffic in the video.",
            "What was happening in the middle of the video?"
        ]
        
        # Add vehicle-specific suggestions
        for vehicle_type in self.vehicle_counts.keys():
            suggestions.append(f"How many {vehicle_type}s were detected?")
            
        # Add emergency vehicle suggestions if applicable
        if self.emergency_alerts:
            suggestions.append("Were there any emergency vehicles?")
            
        print("Suggested questions:")
        for i, suggestion in enumerate(suggestions[:8], 1):  # Limit to 8 suggestions
            print(f"{i}. {suggestion}")
            
        while True:
            try:
                user_query = input("\nAsk a question (or type 'exit' to quit): ")
                if user_query.lower().strip() == "exit":
                    print("Goodbye.")
                    break
                    
                response = self.answer_question(user_query)
                print("\nAnswer:")
                print(response)
                
            except KeyboardInterrupt:
                print("\nChat interrupted. Type 'exit' to quit or continue.")
            except Exception as e:
                print(f"Error in chat loop: {e}")

def main():
    model_path = "/kaggle/input/testing/yolov8x.pt"  # Update with your model path
    video_path = "/kaggle/input/testing/test_30_Multi - Made with Clipchamp.mp4"  # Update with your video path
    
    analyzer = TrafficAnalyzer(model_path, video_path)
    if analyzer.process_video():
        analyzer.run_chat_loop()
    else:
        print("Failed to process video.")
    
    # Cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

if __name__ == "__main__":
    main()


YOLO model loaded successfully.
Flan-T5-Base loaded for text generation.
BART-Large-CNN loaded for summarization.
Processing video...

0: 480x640 9 cars, 74.7ms
Speed: 2.5ms preprocess, 74.7ms inference, 1.3ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 68.7ms
Speed: 2.1ms preprocess, 68.7ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 44.1ms
Speed: 2.3ms preprocess, 44.1ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 43.1ms
Speed: 2.2ms preprocess, 43.1ms inference, 1.4ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 44.6ms
Speed: 2.7ms preprocess, 44.6ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 47.1ms
Speed: 2.2ms preprocess, 47.1ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640)

0: 480x640 8 cars, 44.0ms
Speed: 2.3ms preprocess, 44.0ms inference, 1.2ms postprocess per image at shape (1, 3, 480, 640


Ask a question (or type 'exit' to quit):  Give me summary of the report



Answer:
Here's what I know about the video:
- Video duration: 30.00 seconds
- Unique vehicles: car: 51, truck: 9, bus: 4 (total: 64)
- Average vehicles per frame: car: 6.29, truck: 0.82, bus: 0.08
- Average speeds (pixels/s): car: 87.65, truck: 111.03, bus: 42.94
- Peak traffic: 12 vehicles at 27.73 seconds
- Middle frame (at 15.00 seconds): car: 4, truck: 1
- Average congestion index: 1.44 (0=low, 1=moderate, >2=high)
- Emergency alerts: None
- Temporal trends: First 25%: 1715 vehicles, Middle 50%: 3007 vehicles, Last 25%: 1750 vehicles



Ask a question (or type 'exit' to quit):  how many vehicles passed from 2 to 5 seconds



Answer:
From 2.0s to 5.0s, 16 unique vehicles passed: 15 car(s), 1 truck(s)



Ask a question (or type 'exit' to quit):  how many cars were detected



Answer:
51 car(s) detected in total. Average of 6.29 per frame. Average speed: 87.65 pixels/second.
