In [2]:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import cv2
import time
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import interact
import os

# Class labels
class_names = ['herb paris', 'karela', 'small weed', 'grass', 'tori', 'horseweed', 'Bhindi', 'weed']

class TRTInference:
    def __init__(self, engine_path):
        # Load TRT engine
        self.logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        
        self.context = self.engine.create_execution_context()
        
        # Print engine information
        print(f"Number of bindings: {self.engine.num_bindings}")
        
        # Allocate memory for input/output bindings
        self.inputs = []
        self.outputs = []
        self.allocations = []
        
        for i in range(self.engine.num_bindings):
            is_input = self.engine.binding_is_input(i)
            name = self.engine.get_binding_name(i)
            dtype = self.engine.get_binding_dtype(i)
            shape = self.engine.get_binding_shape(i)
            
            # Print binding information
            print(f"Binding {i}: name={name}, is_input={is_input}, shape={shape}, dtype={dtype}")
            
            # Get the size in bytes for the data type
            if dtype == trt.DataType.FLOAT:
                dtype_size = 4  # 4 bytes for float32
            elif dtype == trt.DataType.HALF:
                dtype_size = 2  # 2 bytes for float16
            elif dtype == trt.DataType.INT8:
                dtype_size = 1  # 1 byte for int8
            else:
                dtype_size = 4  # Default to 4 bytes
                
            # Calculate total size
            size = trt.volume(shape) * dtype_size
            
            # Allocate CUDA memory
            allocation = cuda.mem_alloc(size)
            
            if is_input:
                self.inputs.append({"index": i, "name": name, "dtype": dtype, "shape": shape, "allocation": allocation})
            else:
                self.outputs.append({"index": i, "name": name, "dtype": dtype, "shape": shape, "allocation": allocation})
            
            self.allocations.append(allocation)
    
    def infer(self, img_input):
        try:
            # Copy input data to GPU
            cuda.memcpy_htod(self.inputs[0]["allocation"], img_input.astype(np.float32).ravel())
            
            # Run inference
            try:
                # Try the newer API first
                self.context.execute_v2(self.allocations)
            except AttributeError:
                # Fall back to older API if execute_v2 is not available
                print("Falling back to execute_async...")
                self.context.execute_async(batch_size=1, bindings=self.allocations, stream_handle=cuda.Stream().handle)
            
            # Copy output back to CPU
            output_shape = self.outputs[0]["shape"]
            output = np.zeros(output_shape, dtype=np.float32)
            cuda.memcpy_dtoh(output, self.outputs[0]["allocation"])
            
            return [output]
        except Exception as e:
            print(f"Error during inference: {e}")
            import traceback
            traceback.print_exc()
            # Return an empty array with the expected shape
            return [np.zeros(self.outputs[0]["shape"], dtype=np.float32)]

# Preprocess image
def preprocess(img_path, img_size=640):
    img = cv2.imread(img_path)
    if img is None:
        print(f"Error: Image not found at {img_path}")
        return None
    img = cv2.resize(img, (img_size, img_size))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))  # HWC to CHW
    img = np.expand_dims(img, axis=0)  # Add batch dimension
    return img

# Postprocess with NMS
def postprocess(predictions, conf_thres=0.25, iou_thres=0.45):
    try:
        preds = np.squeeze(predictions[0])  # shape: (N, 85)
        
        # Print shape for debugging
        print(f"Output shape: {preds.shape}")
        
        # Handle different output formats (some TensorRT engines may have different shapes)
        if len(preds.shape) == 1:
            # If output is flattened, reshape it based on YOLOv5 output format
            num_classes = len(class_names)
            num_boxes = preds.shape[0] // (5 + num_classes)
            preds = preds.reshape(num_boxes, 5 + num_classes)
            print(f"Reshaped to: {preds.shape}")
        
        # Ensure we have the correct number of columns
        if preds.shape[1] < 5 + len(class_names):
            print(f"Warning: Output shape {preds.shape} doesn't match expected format")
            return []
        
        boxes = preds[:, :4]
        objectness = preds[:, 4]
        class_probs = preds[:, 5:5+len(class_names)]  # Only take as many classes as we have names for
        class_ids = np.argmax(class_probs, axis=1)
        class_scores = class_probs[np.arange(len(class_ids)), class_ids]
        scores = objectness * class_scores
        
        results = []
        for box, score, cls in zip(boxes, scores, class_ids):
            if score > conf_thres:
                cx, cy, w, h = box
                x = int(cx - w / 2)
                y = int(cy - h / 2)
                results.append(([x, y, int(w), int(h)], float(score), int(cls)))
        
        # Apply NMS
        if not results:
            return []
        
        boxes_xywh = [r[0] for r in results]
        scores = [r[1] for r in results]
        indices = cv2.dnn.NMSBoxes(boxes_xywh, scores, score_threshold=conf_thres, nms_threshold=iou_thres)
        
        # Handle different return formats of NMSBoxes (OpenCV version differences)
        if len(indices) > 0 and isinstance(indices[0], (list, tuple, np.ndarray)):
            indices = [i[0] for i in indices]
        
        return [results[i] for i in indices]
    
    except Exception as e:
        print(f"Error in postprocessing: {e}")
        import traceback
        traceback.print_exc()
        return []

def process_image(trt_model, img_num):
    try:
        img_path = f"images/{img_num}.jpg"
        print(f"Processing image: {img_path}")
        
        # Preprocess image
        img_input = preprocess(img_path)
        if img_input is None:
            return None
        
        # Measure inference time
        start_time = time.time()
        predictions = trt_model.infer(img_input)
        inference_time = (time.time() - start_time) * 1000  # Convert to milliseconds
        print(f"TensorRT Inference Time: {inference_time:.2f} ms")
        
        # Process results
        results = postprocess(predictions)
        print(f"Found {len(results)} detections")
        
        # Draw results
        original = cv2.imread(img_path)
        original_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)  # Convert to RGB for matplotlib
        original_resized = cv2.resize(original_rgb, (640, 640))
        
        # Create a copy for drawing
        img_with_boxes = original_resized.copy()
        
        for (x, y, w, h), score, cls in results:
            cv2.rectangle(img_with_boxes, (x, y), (x + w, y + h), (0, 255, 0), 2)
            label = f"{class_names[cls]}: {score:.2f}"
            cv2.putText(img_with_boxes, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        
        # Save result
        output_filename = f"output/trt_output_{img_num}.jpg"
        cv2.imwrite(output_filename, cv2.cvtColor(img_with_boxes, cv2.COLOR_RGB2BGR))
        print(f"Results saved to {output_filename}")
        
        return {
            'original': original_resized,
            'with_boxes': img_with_boxes,
            'results': results,
            'inference_time': inference_time
        }
    
    except Exception as e:
        print(f"Error processing image {img_num}: {e}")
        import traceback
        traceback.print_exc()
        return None

# Initialize TensorRT model (run this once)
print("Loading TensorRT engine...")
trt_model = TRTInference("test.engine")
print("TensorRT engine loaded successfully")

# Create a list of available image files
image_files = [f for f in os.listdir("images") if f.endswith(('.jpg', '.jpeg', '.png'))]
image_numbers = [f.split('.')[0] for f in image_files]
print(f"Available images: {image_numbers}")

# Create a dropdown for image selection
image_dropdown = widgets.Dropdown(
    options=image_numbers,
    description='Image:',
    disabled=False,
)

# Create a button to run inference
run_button = widgets.Button(description='Run Inference')
output_area = widgets.Output()

def on_button_click(b):
    with output_area:
        clear_output()
        img_num = image_dropdown.value
        print(f"Running inference on image {img_num}")
        result = process_image(trt_model, img_num)
        
        if result:
            # Display original and detection side by side
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            axes[0].imshow(result['original'])
            axes[0].set_title('Original Image')
            axes[0].axis('off')
            
            axes[1].imshow(result['with_boxes'])
            axes[1].set_title(f"Detections ({len(result['results'])}) - {result['inference_time']:.2f}ms")
            axes[1].axis('off')
            
            plt.tight_layout()
            plt.show()

run_button.on_click(on_button_click)

# Display the UI components
display(widgets.HBox([image_dropdown, run_button]))
display(output_area)

# For manual usage (alternative to the widgets)
def run_detection(img_num):
    result = process_image(trt_model, img_num)
    
    if result:
        plt.figure(figsize=(12, 6))
        
        plt.subplot(1, 2, 1)
        plt.imshow(result['original'])
        plt.title('Original Image')
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(result['with_boxes'])
        plt.title(f"Detections ({len(result['results'])}) - {result['inference_time']:.2f}ms")
        plt.axis('off')
        
        plt.tight_layout()
        plt.show()
        
    return result

# Example usage: run_detection("1")

Loading TensorRT engine...
Number of bindings: 2
Binding 0: name=images, is_input=True, shape=(1, 3, 640, 640), dtype=DataType.FLOAT
Binding 1: name=output0, is_input=False, shape=(1, 25200, 13), dtype=DataType.FLOAT
TensorRT engine loaded successfully
Available images: ['6', '7', '4', '8', '5', '1', '2', '3']


HBox(children=(Dropdown(description='Image:', options=('6', '7', '4', '8', '5', '1', '2', '3'), value='6'), Bu…

Output()