# YOLOv8 Microplastics Detection

This notebook will guide you through training a YOLOv8 object detection model to detect microplastics using your dataset.

> **Troubleshooting DataLoader Errors**: If you encounter `DataLoader worker exited unexpectedly` errors, this notebook has been updated to fix these issues by:
> 1. Setting workers=0 to avoid multiprocessing issues
> 2. Reducing batch size to prevent memory overflows
> 3. Using CPU instead of GPU for more stable processing
> 4. Disabling caching to prevent file access conflicts

> **Important Update**: Your dataset is in detection format (bounding boxes), not segmentation format. The notebook has been updated to use YOLOv8 detection instead of segmentation.

In [2]:
# 1. Install Required Libraries
%pip install ultralytics

# The following line will download the YOLOv8 detection model if it doesn't exist
# Uncomment if you need to download it
# !python -c "from ultralytics import YOLO; YOLO('yolov8n.pt')"

Note: you may need to restart the kernel to use updated packages.


## 2. Dataset Structure and Format

Your dataset should be structured as:
- data/train/images/, data/train/labels/
- data/valid/images/, data/valid/labels/
- data/test/images/, data/test/labels/
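
Before training, it can help to confirm that every image has a label file with the same stem. The sketch below assumes the layout above and `.jpg`/`.png` images; `find_unlabeled_images` is a hypothetical helper written for this notebook, not part of Ultralytics.

```python
from pathlib import Path

def find_unlabeled_images(split_dir, image_exts=(".jpg", ".png")):
    """Return sorted image file names under split_dir/images that have
    no matching .txt file under split_dir/labels."""
    split_dir = Path(split_dir)
    images_dir = split_dir / "images"
    labels_dir = split_dir / "labels"
    # Label files are matched to images by file stem (name without extension)
    label_stems = {p.stem for p in labels_dir.glob("*.txt")}
    missing = [
        p.name
        for p in images_dir.iterdir()
        if p.suffix.lower() in image_exts and p.stem not in label_stems
    ]
    return sorted(missing)

# Example usage (assumes the data/ layout described above):
# for split in ("train", "valid", "test"):
#     print(split, find_unlabeled_images(Path("data") / split))
```

An empty list for a split means every image in it has a label file; it does not check that the label contents are valid.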

### Label Format for Object Detection
YOLOv8 detection labels contain normalized bounding box coordinates for each object:
```
<class-id> <x_center> <y_center> <width> <height>
```
Where:
- `class-id`: The object class (0 for microplastics)
- `x_center, y_center`: Normalized center coordinates of the bounding box (0-1)
- `width, height`: Normalized width and height of the bounding box (0-1)

Each object in an image has its own line in the label file. The dataset already contains these labels.
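
To make the format concrete, the following sketch converts one label line into pixel-space corner coordinates. `yolo_line_to_pixel_box` is a hypothetical helper name, and the 640x640 image size in the usage line is an assumption chosen for illustration.

```python
def yolo_line_to_pixel_box(line, img_w, img_h):
    """Convert '<class-id> <x_center> <y_center> <width> <height>'
    (normalized) to (class_id, x1, y1, x2, y2) in pixels."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"Expected 5 fields, got {len(parts)}: {line!r}")
    class_id = int(parts[0])
    x_c, y_c, w, h = (float(v) for v in parts[1:])
    # Scale the normalized center/size to pixel corner coordinates
    x1 = (x_c - w / 2) * img_w
    y1 = (y_c - h / 2) * img_h
    x2 = (x_c + w / 2) * img_w
    y2 = (y_c + h / 2) * img_h
    return class_id, x1, y1, x2, y2

print(yolo_line_to_pixel_box("0 0.5 0.5 0.25 0.5", img_w=640, img_h=640))
# (0, 240.0, 160.0, 400.0, 480.0)
```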

In [None]:
# 3. Training YOLOv8 Detection Model
import os
from pathlib import Path
import torch
from ultralytics import YOLO
import yaml

DATASET_YAML = 'data.yaml'

# 3.0 Data Configuration Validation and Repair
def validate_and_fix_data_yaml(yaml_path):
    """Validate and fix data.yaml file for YOLOv8 compatibility"""
    print(f"Validating and fixing {yaml_path}...")
    try:
        with open(yaml_path, 'r') as f:
            data_cfg = yaml.safe_load(f)
        
        # Make a copy to check if changes were made
        original_cfg = data_cfg.copy()
        
        # Get project root directory (absolute path)
        project_root = os.path.abspath(os.path.dirname(yaml_path))
        
        # Fix fields for YOLOv8 with absolute paths
        data_cfg["path"] = project_root.replace('\\', '/')  # Use forward slashes for paths
        
        # These are relative to the path
        data_cfg["train"] = "data/train/images"
        data_cfg["val"] = "data/valid/images"
        data_cfg["test"] = "data/test/images"
        data_cfg["nc"] = 1
        data_cfg["names"] = ["microplastic"]
        
        # Remove problematic or redundant fields
        keys_to_remove = ["batch", "cache", "workers"]
        for key in keys_to_remove:
            if key in data_cfg:
                del data_cfg[key]
        
        # Check if changes were made
        if data_cfg != original_cfg:
            print("Changes needed in data.yaml. Updating file...")
            with open(yaml_path, 'w') as f:
                yaml.dump(data_cfg, f, default_flow_style=False)
            print("data.yaml has been updated with correct configuration.")
        else:
            print("data.yaml is already correctly configured.")
        
        # Validate absolute paths - important for troubleshooting
        train_path = os.path.join(data_cfg["path"], "data/train/images")
        val_path = os.path.join(data_cfg["path"], "data/valid/images")
        test_path = os.path.join(data_cfg["path"], "data/test/images")
        
        train_path_alt = os.path.join(project_root, "data/train/images")
        val_path_alt = os.path.join(project_root, "data/valid/images")
        test_path_alt = os.path.join(project_root, "data/test/images")
        
        print("\nValidating absolute paths:")
        print(f"Configured path: {data_cfg['path']}")
        
        def check_path(path, name):
            print(f"- Checking {name} path: {path}")
            if os.path.exists(path):
                images = len(list(Path(path).glob('*.jpg'))) + len(list(Path(path).glob('*.png')))
                print(f"  ✓ Path exists! Found {images} images.")
                # Check for corresponding labels
                label_path = path.replace('images', 'labels')
                if os.path.exists(label_path):
                    labels = len(list(Path(label_path).glob('*.txt')))
                    print(f"  ✓ Found {labels} labels.")
                    if labels < images:
                        print(f"  ⚠ WARNING: {images-labels} images may be missing labels!")
                else:
                    print(f"  ✗ WARNING: Label directory {label_path} does not exist!")
                return images > 0
            else:
                print(f"  ✗ ERROR: Directory does not exist!")
                return False
        
        # Check primary paths
        print("\nPrimary configurations:")
        train_ok = check_path(train_path.replace('\\', '/'), "Train")
        val_ok = check_path(val_path.replace('\\', '/'), "Validation")
        test_ok = check_path(test_path.replace('\\', '/'), "Test")
        
        # If paths are missing, check alternative paths
        if not (train_ok and val_ok and test_ok):
            print("\nChecking alternative paths:")
            check_path(train_path_alt.replace('\\', '/'), "Alt Train")
            check_path(val_path_alt.replace('\\', '/'), "Alt Validation")
            check_path(test_path_alt.replace('\\', '/'), "Alt Test")
        
        # Check paths that YOLOv8 might be trying to use (for debugging)
        datasets_path = os.path.join(project_root, "datasets")
        if os.path.exists(datasets_path):
            print(f"\nWARNING: A 'datasets' directory exists at {datasets_path}")
            print("This might be causing path resolution conflicts. YOLOv8 might be looking here instead of your data directory.")
        
        # Check YOLOv8 settings
        settings_path = os.path.expandvars(r"%APPDATA%\Ultralytics\settings.json")
        if os.path.exists(settings_path):
            print(f"\nYOLOv8 settings found at: {settings_path}")
            try:
                import json
                with open(settings_path, 'r') as f:
                    settings = json.load(f)
                if 'datasets_dir' in settings:
                    print(f"YOLOv8 datasets_dir: {settings['datasets_dir']}")
            except (OSError, ValueError):  # unreadable file or invalid JSON
                print("Could not read YOLOv8 settings file.")
        
        return data_cfg
    except Exception as e:
        print(f"ERROR processing data.yaml: {e}")
        return None

# Run validation and fix
data_cfg = validate_and_fix_data_yaml(DATASET_YAML)

def verify_dataset(yaml_path):
    if not os.path.exists(yaml_path):
        print(f"ERROR: Dataset YAML file not found: {yaml_path}")
        return False
    try:
        with open(yaml_path, 'r') as f:
            data_cfg = yaml.safe_load(f)
    except Exception as e:
        print(f"ERROR: Failed to load YAML: {e}")
        return False
    if 'path' not in data_cfg:
        print("ERROR: 'path' key missing in YAML file.")
        return False
    base_path = Path(data_cfg['path'])
    print(f"Base path: {base_path}")
    all_ok = True
    for split in ['train', 'val', 'test']:
        if split in data_cfg:
            split_path = base_path / data_cfg[split]
            print(f"{split.capitalize()} path: {split_path}")
            if not split_path.exists():
                print(f"WARNING: {split} path does not exist: {split_path}")
                try:
                    os.makedirs(split_path, exist_ok=True)
                    print(f"Created directory: {split_path}")
                except Exception as e:
                    print(f"ERROR: Could not create directory {split_path}: {e}")
                    all_ok = False
            else:
                img_count = len(list(split_path.glob('*.jpg'))) + len(list(split_path.glob('*.png')))
                print(f"  Found {img_count} images in {split} folder")
                if img_count == 0:
                    print(f"WARNING: No images found in {split_path}")
    return all_ok

# Device selection
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(f"CUDA available: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print("No GPU available, using CPU instead.")

print("Verifying dataset paths...")
dataset_ok = verify_dataset(DATASET_YAML)
if not dataset_ok:
    print("Dataset verification failed. Please check your dataset structure and YAML file.")
else:
    # Load YOLOv8 detection model
    try:
        model = YOLO('yolov8n.pt')
    except Exception as e:
        print(f"ERROR: Could not load YOLOv8 model: {e}")
        model = None
    if model is not None:
        try:
            print("Starting training...")
            
            # Create a copy of data.yaml with absolute paths to prevent path resolution issues
            import shutil
            temp_yaml = 'temp_data.yaml'
            shutil.copy(DATASET_YAML, temp_yaml)
            
            # Update the temporary YAML with absolute paths
            with open(temp_yaml, 'r') as f:
                temp_data = yaml.safe_load(f)
            
            # Ensure path is absolute with forward slashes
            project_root = os.path.abspath(os.path.dirname(DATASET_YAML))
            temp_data['path'] = project_root.replace('\\', '/')
            
            with open(temp_yaml, 'w') as f:
                yaml.dump(temp_data, f, default_flow_style=False)
            
            print(f"Temporary data.yaml created with absolute paths:")
            print(f"- path: {temp_data['path']}")
            print(f"- train: {temp_data['train']}")
            print(f"- val: {temp_data['val']}")
            
            # Display actual paths that will be used
            train_path = os.path.join(temp_data['path'], temp_data['train'])
            val_path = os.path.join(temp_data['path'], temp_data['val'])
            print(f"Full train path: {train_path}")
            print(f"Full val path: {val_path}")
            
            # TODO: 1. Increase epochs to at least 50-100.
            # TODO: 2. Increase batch size to 16 or 32 (maximum).
            # TODO: 3. Try workers between 0 and 8 (maximum).
            # Several parameters can cause problems; tune the number of workers first.
            results = model.train(
                data=temp_yaml,  # Use the temporary YAML with absolute paths
                epochs=1,              # Single epoch as a quick smoke test; raise to 50-100 for real training
                imgsz=640,             # Input image size
                batch=8,               # Smaller batch size to prevent memory issues
                workers=0,             # Set workers to 0 to avoid DataLoader multiprocessing issues
                mosaic=1.0,            # Mosaic augmentation
                scale=0.5,             # Scale augmentation
                perspective=0.0,       # No perspective augmentation for small objects
                flipud=0.5,            # Flip up-down augmentation
                fliplr=0.5,            # Flip left-right augmentation
                hsv_h=0.015,           # HSV hue augmentation (reduced for consistency)
                hsv_s=0.7,             # HSV saturation augmentation
                hsv_v=0.4,             # HSV value augmentation
                patience=50,           # Early stopping patience
                device='cpu',          # Force CPU to avoid CUDA/DataLoader issues
                project='runs/detect', # Project directory
                name='train',          # Run name
                exist_ok=True,         # Overwrite existing directory
                cache=False,           # Disable cache to prevent potential issues
                amp=False,             # Disable mixed precision to avoid potential issues
                single_cls=True,       # Force single class detection since we only have microplastics
                rect=True              # Rectangular training: batches images of similar aspect ratio to reduce padding; Ultralytics disables mosaic when rect=True
            )
            
            # Clean up temporary file
            try:
                os.remove(temp_yaml)
                print(f"Temporary data file {temp_yaml} removed")
            except OSError:
                pass
                
            print("Training completed successfully.")
        except Exception as e:
            print(f"ERROR during training: {e}")

Validating and fixing data.yaml...
data.yaml is already correctly configured.

Validating absolute paths:
Configured path: c:/Users/blasi/CS-ML/FINAL_PROJ

Primary configurations:
- Checking Train path: c:/Users/blasi/CS-ML/FINAL_PROJ/data/train/images
  ✓ Path exists! Found 3226 images.
  ✓ Found 3226 labels.
- Checking Validation path: c:/Users/blasi/CS-ML/FINAL_PROJ/data/valid/images
  ✓ Path exists! Found 928 images.
  ✓ Found 928 labels.
- Checking Test path: c:/Users/blasi/CS-ML/FINAL_PROJ/data/test/images
  ✓ Path exists! Found 453 images.
  ✓ Found 453 labels.

YOLOv8 settings found at: C:\Users\blasi\AppData\Roaming\Ultralytics\settings.json
YOLOv8 datasets_dir: C:\Users\blasi\CS-ML\FINAL_PROJ\datasets
CUDA available: NVIDIA GeForce RTX 3050 Ti Laptop GPU
Verifying dataset paths...
Base path: c:\Users\blasi\CS-ML\FINAL_PROJ
Train path: c:\Users\blasi\CS-ML\FINAL_PROJ\data\train\images
  Found 3226 images in train folder
Val path: c:\Users\blasi\CS-ML\FINAL_PROJ\data\valid\imag

train: Scanning C:\Users\blasi\CS-ML\FINAL_PROJ\data\train\labels.cache... 3226 images, 0 backgrounds, 0 corrupt: 100%|██████████| 3226/3226 [00:00<?, ?it/s]

val: Fast image access  (ping: 0.0±0.0 ms, read: 366.8±63.7 MB/s, size: 32.7 KB)

val: Scanning C:\Users\blasi\CS-ML\FINAL_PROJ\data\valid\labels.cache... 928 images, 0 backgrounds, 0 corrupt: 100%|██████████| 928/928 [00:00<?, ?it/s]



Plotting labels to runs\detect\train\labels.jpg... 
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
optimizer: AdamW(lr=0.002, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to runs\detect\train
Starting training for 1 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size

        1/1         0G      1.891      2.166      1.429          8        640: 100%|██████████| 404/404 [10:38<00:00,  1.58s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 58/58 [01:25<00:00,  1.47s/it]



                   all        928       6475      0.719      0.596      0.654      0.297

1 epochs completed in 0.202 hours.
Optimizer stripped from runs\detect\train\weights\last.pt, 6.2MB
Optimizer stripped from runs\detect\train\weights\best.pt, 6.2MB

Validating runs\detect\train\weights\best.pt...
Ultralytics 8.3.133  Python-3.11.7 torch-2.5.1+cu121 CPU (AMD Ryzen 7 6800H with Radeon Graphics)
Model summary (fused): 72 layers, 3,005,843 parameters, 0 gradients, 8.1 GFLOPs

ERROR during training: Dataset 'temp_data.yaml' images not found, missing path 'C:\Users\blasi\CS-ML\FINAL_PROJ\datasets\da

In [4]:
# 3.1 Training Visualization
from IPython.display import display, Image
from pathlib import Path
import time
import os

# Function to display training progress during or after training
def show_training_plots():
    results_path = Path('runs/detect/train')  # Changed from segment to detect
    
    # Check if the directory exists first
    if not results_path.exists():
        print(f"Warning: Results directory not found at {results_path}")
        return
    
    # Results plots
    plots = {
        'Training Loss': results_path / 'results.png',
        'Validation Confusion Matrix': results_path / 'confusion_matrix_normalized.png',  # Ultralytics saves this file without a 'val_' prefix
        'PR Curve': results_path / 'PR_curve.png'
    }
    
    found_plots = False
    for title, plot_path in plots.items():
        if plot_path.exists():
            found_plots = True
            print(f"\n{title}:")
            try:
                display(Image(str(plot_path)))
            except Exception as e:
                print(f"Error displaying {title}: {str(e)}")
        else:
            print(f"\n{title} plot not found at {plot_path}")
    
    if not found_plots:
        print("No training plots found. Training may not have completed successfully.")


In [5]:
# 3.2 Label Sanity Check
import random
from glob import glob

def check_labels(label_dir, num_samples=5):
    label_files = glob(os.path.join(label_dir, '*.txt'))
    if not label_files:
        print(f"No label files found in {label_dir}")
        return
    print(f"Checking {min(num_samples, len(label_files))} random label files in {label_dir}...")
    for lf in random.sample(label_files, min(num_samples, len(label_files))):
        print(f"\nFile: {os.path.basename(lf)}")
        with open(lf, 'r') as f:
            lines = f.readlines()
            if not lines:
                print("  WARNING: Empty label file!")
            for line in lines:
                parts = line.strip().split()
                if len(parts) != 5:
                    print(f"  WARNING: Malformed line: {line.strip()}")
                else:
                    cls_idx, x, y, w, h = parts
                    print(f"  class: {cls_idx}, x: {x}, y: {y}, w: {w}, h: {h}")
                    try:
                        coords = [float(v) for v in (x, y, w, h)]
                    except ValueError:
                        print(f"  WARNING: Non-numeric value: {line.strip()}")
                        continue
                    # Normalized coordinates must lie in [0, 1]
                    if not all(0 <= c <= 1 for c in coords):
                        print(f"  WARNING: Coordinates out of range: {line.strip()}")

# Check a few label files from train, val, and test
print("\n--- LABEL SANITY CHECK ---")
for split in ['train', 'valid', 'test']:
    label_dir = os.path.join('data', split, 'labels')
    check_labels(label_dir)
print("--- END LABEL CHECK ---\n")


--- LABEL SANITY CHECK ---
Checking 5 random label files in data\train\labels...

File: b-43-_jpg.rf.12e49b63335756b011829335dd63d188.txt
  class: 0, x: 0.59765625, y: 0.21015625, w: 0.0328125, h: 0.0296875
  class: 0, x: 0.83125, y: 0.2875, w: 0.05, h: 0.059375
  class: 0, x: 0.02265625, y: 0.290625, w: 0.0453125, h: 0.071875
  class: 0, x: 0.18515625, y: 0.290625, w: 0.0421875, h: 0.04375
  class: 0, x: 0.91484375, y: 0.340625, w: 0.0671875, h: 0.06875
  class: 0, x: 0.38359375, y: 0.5171875, w: 0.0671875, h: 0.053125
  class: 0, x: 0.421875, y: 0.80234375, w: 0.0625, h: 0.0484375
  class: 0, x: 0.203125, y: 0.8390625, w: 0.05, h: 0.05625
  class: 0, x: 0.60703125, y: 0.86953125, w: 0.1140625, h: 0.0578125
  class: 0, x: 0.18984375, y: 0.953125, w: 0.0703125, h: 0.090625

File: a-59-_jpg.rf.20eab7733ef1c795a2ebb015b4110ee3.txt
  class: 0, x: 0.58984375, y: 0.33125, w: 0.1859375, h: 0.184375
  class: 0, x: 0.61484375, y: 0.609375, w: 0.1890625, h: 0.190625
  class: 0, x: 0.10625, y: 

In [6]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Define IoU calculation function first to avoid "not defined" error
def calculate_iou(box1, box2):
    """
    Calculate IoU of two normalized bounding boxes in format [x_center, y_center, width, height]
    """
    # Convert from [x_center, y_center, width, height] to [x1, y1, x2, y2]
    box1_x1 = box1[0] - box1[2] / 2
    box1_y1 = box1[1] - box1[3] / 2
    box1_x2 = box1[0] + box1[2] / 2
    box1_y2 = box1[1] + box1[3] / 2
    
    box2_x1 = box2[0] - box2[2] / 2
    box2_y1 = box2[1] - box2[3] / 2
    box2_x2 = box2[0] + box2[2] / 2
    box2_y2 = box2[1] + box2[3] / 2
    
    # Calculate intersection area
    x_left = max(box1_x1, box2_x1)
    y_top = max(box1_y1, box2_y1)
    x_right = min(box1_x2, box2_x2)
    y_bottom = min(box1_y2, box2_y2)
    
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    
    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    
    # Calculate union area
    box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1)
    box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1)
    union_area = box1_area + box2_area - intersection_area
    
    if union_area == 0:
        return 0.0
    
    return intersection_area / union_area

# 4. Evaluate on Test Set (with Confusion Matrix)

# Ensure model is loaded from previous training
if 'model' not in locals() or model is None:
    print("Model not loaded. Please run the training cell first.")
else:
    # Run inference on test images
    test_images_dir = 'data/test/images'
    test_images = list(Path(test_images_dir).glob('*.jpg')) + list(Path(test_images_dir).glob('*.png'))
    if not test_images:
        print(f"No test images found in {test_images_dir}")
    else:
        y_true = []
        y_pred = []
        all_iou = []
        confidence_scores = []
        detection_examples = []
        
        # Set a confidence threshold for predictions
        conf_threshold = 0.25
        
        print(f"Evaluating model on {len(test_images)} test images...")
        
        # Process in batches to be more efficient
        batch_size = 4
        for i in range(0, len(test_images), batch_size):
            batch = test_images[i:i+batch_size]
            batch_paths = [str(p) for p in batch]
            
            # Run batch prediction with low confidence threshold to get all potential predictions
            results = model(batch_paths, conf=0.1, iou=0.5, verbose=False)
            
            for idx, (img_path, result) in enumerate(zip(batch, results)):
                img_name = os.path.basename(img_path)
                
                # Get ground truth labels
                label_path = Path(str(img_path).replace('\\images\\', '\\labels\\').replace('/images/', '/labels/').rsplit('.', 1)[0] + '.txt')
                gt_boxes = []
                if label_path.exists():
                    with open(label_path, 'r') as f:
                        for line in f:
                            parts = line.strip().split()
                            if len(parts) == 5:
                                cls_id, x, y, w, h = map(float, parts)
                                gt_boxes.append({
                                    'class': int(cls_id),
                                    'bbox': [x, y, w, h]  # Normalized coordinates
                                })
                
                # Get model predictions
                pred_boxes = []
                if hasattr(result, 'boxes') and result.boxes is not None:
                    boxes = result.boxes
                    for box_idx in range(len(boxes.cls)):
                        conf = float(boxes.conf[box_idx])
                        if conf >= conf_threshold:  # Filter by confidence
                            cls_id = int(boxes.cls[box_idx])
                            xyxy = boxes.xyxy[box_idx].cpu().numpy()  # Get box in xyxy format
                            
                            # Convert xyxy to normalized xywh for comparison with ground truth
                            # This assumes result.orig_shape contains [height, width]
                            img_h, img_w = result.orig_shape
                            x_center = (xyxy[0] + xyxy[2]) / 2 / img_w
                            y_center = (xyxy[1] + xyxy[3]) / 2 / img_h
                            width = (xyxy[2] - xyxy[0]) / img_w
                            height = (xyxy[3] - xyxy[1]) / img_h
                            
                            pred_boxes.append({
                                'class': cls_id,
                                'conf': conf,
                                'bbox': [x_center, y_center, width, height]  # Normalized coordinates
                            })
                            confidence_scores.append(conf)
                
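                # Record image-level presence: does the image have any ground-truth box / any prediction?
                # These flags feed the image-level metrics below; they are not box-level TP/FP counts.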
                has_gt = len(gt_boxes) > 0
                has_pred = len(pred_boxes) > 0
                
                y_true.append(1 if has_gt else 0)
                y_pred.append(1 if has_pred else 0)
                
                # Store interesting examples for visualization (mismatches or high-confidence detections)
                if (has_gt and not has_pred) or (has_pred and not has_gt) or (has_pred and has_gt and pred_boxes[0]['conf'] > 0.8):
                    detection_examples.append({
                        'img_path': str(img_path),
                        'gt_boxes': gt_boxes,
                        'pred_boxes': pred_boxes,
                        'type': 'FN' if (has_gt and not has_pred) else 'FP' if (has_pred and not has_gt) else 'TP'
                    })
                
                # Calculate IoU for each ground truth box with best matching prediction
                if has_gt and has_pred:
                    for gt_box in gt_boxes:
                        best_iou = 0
                        # Calculate IoU with each prediction
                        for pred_box in pred_boxes:
                            iou = calculate_iou(gt_box['bbox'], pred_box['bbox'])
                            if iou > best_iou:
                                best_iou = iou
                        all_iou.append(best_iou)
        
        # Calculate image-level metrics (an image counts as positive if it has at least one box)
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        
        print(f"\nTest set evaluation (confidence threshold: {conf_threshold}):")
        print(f"  Accuracy:  {acc:.4f}")
        print(f"  Precision: {prec:.4f}")
        print(f"  Recall:    {rec:.4f}")
        print(f"  F1-score:  {f1:.4f}")
        
        if all_iou:
            print(f"  Average IoU: {sum(all_iou)/len(all_iou):.4f}")
        
        if confidence_scores:
            print(f"  Average confidence: {sum(confidence_scores)/len(confidence_scores):.4f}")
            print(f"  Min confidence: {min(confidence_scores):.4f}")
            print(f"  Max confidence: {max(confidence_scores):.4f}")
        
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=[0,1])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["No Microplastic", "Microplastic"])
        disp.plot(cmap=plt.cm.Blues)
        plt.title("Confusion Matrix: Microplastic Detection")
        plt.show()
        
        # Display prediction confidence distribution if available
        if confidence_scores:
            plt.figure(figsize=(10, 5))
            plt.hist(confidence_scores, bins=20, alpha=0.7)
            plt.title("Distribution of Prediction Confidence Scores")
            plt.xlabel("Confidence")
            plt.ylabel("Count")
            plt.axvline(x=conf_threshold, color='r', linestyle='--', label=f'Threshold: {conf_threshold}')
            plt.legend()
            plt.show()
        
        # Visualize a few example detections
        if detection_examples:
            print(f"\nShowing {min(3, len(detection_examples))} detection examples:")
            for i, example in enumerate(detection_examples[:3]):
                img = plt.imread(example['img_path'])
                plt.figure(figsize=(10, 8))
                plt.imshow(img)
                plt.title(f"Example {i+1}: {example['type']} - {'Ground Truth' if example['gt_boxes'] else 'No Ground Truth'} | {'Predicted' if example['pred_boxes'] else 'No Prediction'}")
                
                # Draw ground truth boxes in green
                for box in example['gt_boxes']:
                    x, y, w, h = box['bbox']
                    img_h, img_w = img.shape[:2]
                    rect = plt.Rectangle(
                        ((x - w/2) * img_w, (y - h/2) * img_h),
                        w * img_w, h * img_h,
                        linewidth=2, edgecolor='g', facecolor='none',
                        label='Ground Truth'
                    )
                    plt.gca().add_patch(rect)
                
                # Draw prediction boxes in red
                for box in example['pred_boxes']:
                    x, y, w, h = box['bbox']
                    conf = box.get('conf', 0)
                    img_h, img_w = img.shape[:2]
                    rect = plt.Rectangle(
                        ((x - w/2) * img_w, (y - h/2) * img_h),
                        w * img_w, h * img_h,
                        linewidth=2, edgecolor='r', facecolor='none',
                        label=f'Prediction (conf: {conf:.2f})'
                    )
                    plt.gca().add_patch(rect)
                    plt.annotate(f'{conf:.2f}', ((x - w/2) * img_w, (y - h/2) * img_h - 5), 
                                 color='r', fontsize=12, weight='bold')
                
                plt.legend()
                plt.show()

Evaluating model on 453 test images...

Test set evaluation (confidence threshold: 0.25):
  Accuracy:  1.0000
  Precision: 1.0000
  Recall:    1.0000
  F1-score:  1.0000
  Average IoU: 0.5565
  Average confidence: 0.6340
  Min confidence: 0.2501
  Max confidence: 0.9954


<Figure size 640x480 with 2 Axes>

<Figure size 1000x500 with 1 Axes>


Showing 3 detection examples:


<Figure size 1000x800 with 1 Axes>

<Figure size 1000x800 with 1 Axes>

<Figure size 1000x800 with 1 Axes>

---

## Next Steps
- You can adjust the number of epochs, image size, or model variant (e.g., yolov8m.pt) as needed.
- Run the notebook cells to train and evaluate your model.

## Understanding the Results

### Metrics Explanation
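- **mAP (mean Average Precision)**: The primary metric for object detection performance. `mAP50` is the average precision at an IoU threshold of 0.5; `mAP50-95` averages it over IoU thresholds from 0.5 to 0.95.
- **Precision**: The fraction of predicted boxes that are correct detections
- **Recall**: The fraction of ground-truth microplastics in the image that the model finds
- **IoU (Intersection over Union)**: The area of overlap between a predicted box and a ground-truth box divided by the area of their union

Note that the accuracy, precision, recall, and F1-score printed by the test-set evaluation cell are computed per image (whether an image has any ground-truth box and any prediction above the confidence threshold). The scores of 1.0000 there therefore do not indicate perfect box-level detection; the validation mAP50 of 0.654 and the average IoU of 0.5565 are more informative.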
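
Precision and recall are most meaningful when counted per box rather than per image. As a minimal sketch, assuming boxes in normalized `[x_center, y_center, width, height]` format and predictions already sorted by descending confidence, box-level counts can be obtained by greedily matching each prediction to the best still-unmatched ground-truth box. The names `box_iou` and `match_boxes` and the 0.5 IoU threshold are illustrative choices; this is not how Ultralytics computes its reported metrics.

```python
def box_iou(a, b):
    """IoU of two boxes given as [x_center, y_center, width, height]."""
    ax1, ay1, ax2, ay2 = a[0] - a[2] / 2, a[1] - a[3] / 2, a[0] + a[2] / 2, a[1] + a[3] / 2
    bx1, by1, bx2, by2 = b[0] - b[2] / 2, b[1] - b[3] / 2, b[0] + b[2] / 2, b[1] + b[3] / 2
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0

def match_boxes(gt_boxes, pred_boxes, iou_threshold=0.5):
    """Greedy one-to-one matching; returns (tp, fp, fn) for one image."""
    unmatched_gt = list(range(len(gt_boxes)))
    tp = 0
    for pred in pred_boxes:  # assumed sorted by descending confidence
        best_idx, best_iou = None, iou_threshold
        for idx in unmatched_gt:
            iou = box_iou(gt_boxes[idx], pred)
            if iou >= best_iou:
                best_idx, best_iou = idx, iou
        if best_idx is not None:
            unmatched_gt.remove(best_idx)  # each ground-truth box matches at most once
            tp += 1
    fp = len(pred_boxes) - tp  # predictions with no matching ground truth
    fn = len(unmatched_gt)     # ground-truth boxes that were missed
    return tp, fp, fn

gt = [[0.5, 0.5, 0.2, 0.2], [0.2, 0.2, 0.1, 0.1]]
preds = [[0.5, 0.5, 0.2, 0.2], [0.8, 0.8, 0.1, 0.1]]
tp, fp, fn = match_boxes(gt, preds)
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
print(tp, fp, fn, precision, recall)  # 1 1 1 0.5 0.5
```

In this toy example an image-level check would score the image as a correct positive, while the box-level counts show one missed particle and one spurious detection. Summing `tp`, `fp`, and `fn` over all test images before dividing gives dataset-level precision and recall.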

### Model Improvements

To improve model performance:

1. **Try larger models**: Replace `yolov8n.pt` with:
   - `yolov8s.pt` (small) 
   - `yolov8m.pt` (medium)
   - `yolov8l.pt` (large)
   - `yolov8x.pt` (extra large)

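2. **Data augmentation**: Strengthen the training-time augmentations to reduce overfitting:
   ```python
   model.train(
       # Other parameters
       mixup=0.1,     # blend pairs of images and their labels
       fliplr=0.5,    # horizontal flip probability
       flipud=0.5,    # vertical flip probability (fine if orientation carries no meaning)
       degrees=10.0   # random rotation range in degrees
       # copy_paste needs segmentation masks, so it has no effect on box-only labels
   )
   ```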

3. **Optimization**: Try different optimizers:
   ```python
   model.train(
       # Other parameters
       optimizer="AdamW",
       lr0=0.001
   )
   ```

4. **Export model**: Save the model for deployment:
   ```python
   model.export(format="onnx") # or "torchscript", "openvino", etc.
   ```

In [7]:
# 5. Visualize Model Predictions
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt

def visualize_predictions(model, image_path, confidence=0.25):
    """Visualize model predictions on a single image with detailed metrics"""
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        return
    
    # Run inference
    results = model(image_path, conf=confidence)
    result = results[0]
    
    # Get image
    img = plt.imread(image_path)
    
    # Plot the image
    plt.figure(figsize=(12, 8))
    plt.imshow(img)
    plt.title(f"YOLOv8 Detection: {os.path.basename(image_path)}")
    
    # Get the ground truth labels: same path with 'images' -> 'labels' and a .txt extension
    label_str = str(image_path).replace('/images/', '/labels/').replace('\\images\\', '\\labels\\')
    label_path = Path(label_str).with_suffix('.txt')
    gt_boxes = []
    if label_path.exists():
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 5:
                    cls_id, x, y, w, h = map(float, parts)
                    gt_boxes.append({
                        'class': int(cls_id),
                        'bbox': [x, y, w, h]  # Normalized coordinates
                    })
    
    # Draw ground truth boxes in green
    for box in gt_boxes:
        x, y, w, h = box['bbox']
        img_h, img_w = img.shape[:2]
        rect = plt.Rectangle(
            ((x - w/2) * img_w, (y - h/2) * img_h),
            w * img_w, h * img_h,
            linewidth=2, edgecolor='g', facecolor='none',
            label='Ground Truth'
        )
        plt.gca().add_patch(rect)
    
    # Draw detection boxes
    if hasattr(result, 'boxes') and result.boxes is not None:
        boxes = result.boxes
        for i in range(len(boxes.cls)):
            # Get box coordinates
            box = boxes.xyxy[i].cpu().numpy()
            x1, y1, x2, y2 = box
            conf = boxes.conf[i].item()
            cls_id = int(boxes.cls[i].item())
            
            # Draw rectangle
            rect = plt.Rectangle(
                (x1, y1),
                x2 - x1, y2 - y1,
                linewidth=2, edgecolor='r', facecolor='none',
                label='Prediction'  # shared label so the legend shows a single entry
            )
            plt.gca().add_patch(rect)
            
            # Add confidence score
            plt.text(
                x1, y1 - 5,
                f"{conf:.2f}",
                color='white', fontsize=10, bbox=dict(facecolor='red', alpha=0.5)
            )
    
    # Only add unique legend entries
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys(), loc='upper right')
    
    plt.show()
    
    # Print detection details
    num_detections = len(result.boxes) if getattr(result, 'boxes', None) is not None else 0
    print(f"Found {len(gt_boxes)} ground truth boxes and {num_detections} detections.")
    
    if hasattr(result, 'boxes') and result.boxes is not None and len(result.boxes) > 0:
        print("\nDetection Details:")
        for i in range(len(result.boxes)):
            conf = result.boxes.conf[i].item()
            cls_id = int(result.boxes.cls[i].item())
            print(f"  Box {i+1}: Class {cls_id} (microplastic), Confidence: {conf:.4f}")

# Function to visualize random test images
def visualize_random_samples(model, num_samples=3, data_path="data/test/images", confidence=0.25):
    """Visualize predictions on random samples from the dataset"""
    image_files = list(Path(data_path).glob('*.jpg')) + list(Path(data_path).glob('*.png'))
    if not image_files:
        print(f"No images found in {data_path}")
        return
        
    print(f"Found {len(image_files)} images in {data_path}")
    print(f"Visualizing {min(num_samples, len(image_files))} random samples...\n")
    
    # Select random samples
    samples = random.sample(image_files, min(num_samples, len(image_files)))
    
    # Visualize each sample
    for img_path in samples:
        visualize_predictions(model, str(img_path), confidence=confidence)
        print("\n" + "-"*50 + "\n")

# Try on a few random test images
if 'model' in locals() and model is not None:
    try:
        # Set a lower confidence threshold to see more potential detections
        visualize_random_samples(model, num_samples=3, confidence=0.1)
    except Exception as e:
        print(f"Error during visualization: {e}")

Found 453 images in data/test/images
Visualizing 3 random samples...



<Figure size 1200x800 with 1 Axes>

Found 4 ground truth boxes and 10 detections.

Detection Details:
  Box 1: Class 0 (microplastic), Confidence: 0.8816
  Box 2: Class 0 (microplastic), Confidence: 0.8204
  Box 3: Class 0 (microplastic), Confidence: 0.7243
  Box 4: Class 0 (microplastic), Confidence: 0.4579
  Box 5: Class 0 (microplastic), Confidence: 0.4028
  Box 6: Class 0 (microplastic), Confidence: 0.3163
  Box 7: Class 0 (microplastic), Confidence: 0.2672
  Box 8: Class 0 (microplastic), Confidence: 0.1938
  Box 9: Class 0 (microplastic), Confidence: 0.1909
  Box 10: Class 0 (microplastic), Confidence: 0.1191

--------------------------------------------------



<Figure size 1200x800 with 1 Axes>

Found 13 ground truth boxes and 23 detections.

Detection Details:
  Box 1: Class 0 (microplastic), Confidence: 0.9194
  Box 2: Class 0 (microplastic), Confidence: 0.8698
  Box 3: Class 0 (microplastic), Confidence: 0.7011
  Box 4: Class 0 (microplastic), Confidence: 0.6959
  Box 5: Class 0 (microplastic), Confidence: 0.6897
  Box 6: Class 0 (microplastic), Confidence: 0.5368
  Box 7: Class 0 (microplastic), Confidence: 0.5021
  Box 8: Class 0 (microplastic), Confidence: 0.4199
  Box 9: Class 0 (microplastic), Confidence: 0.4135
  Box 10: Class 0 (microplastic), Confidence: 0.3224
  Box 11: Class 0 (microplastic), Confidence: 0.3097
  Box 12: Class 0 (microplastic), Confidence: 0.2987
  Box 13: Class 0 (microplastic), Confidence: 0.2771
  Box 14: Class 0 (microplastic), Confidence: 0.2423
  Box 15: Class 0 (microplastic), Confidence: 0.2188
  Box 16: Class 0 (microplastic), Confidence: 0.2118
  Box 17: Class 0 (microplastic), Confidence: 0.2108
  Box 18: Class 0 (microplastic), Confide

<Figure size 1200x800 with 1 Axes>

Found 4 ground truth boxes and 8 detections.

Detection Details:
  Box 1: Class 0 (microplastic), Confidence: 0.9690
  Box 2: Class 0 (microplastic), Confidence: 0.9634
  Box 3: Class 0 (microplastic), Confidence: 0.8262
  Box 4: Class 0 (microplastic), Confidence: 0.8018
  Box 5: Class 0 (microplastic), Confidence: 0.4220
  Box 6: Class 0 (microplastic), Confidence: 0.2429
  Box 7: Class 0 (microplastic), Confidence: 0.1225
  Box 8: Class 0 (microplastic), Confidence: 0.1224

--------------------------------------------------
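
The first sample above reports 10 detections against 4 ground truth boxes because the threshold was lowered to 0.1. As a rough illustration of how sensitive the count is to that threshold, the sketch below re-filters the confidences printed for that sample. The helper name `count_at_threshold` is hypothetical, and the counts say nothing about which detections are correct, only how many survive each cutoff.

```python
# Confidences copied from the first sample's "Detection Details" output above
confidences = [0.8816, 0.8204, 0.7243, 0.4579, 0.4028,
               0.3163, 0.2672, 0.1938, 0.1909, 0.1191]
NUM_GROUND_TRUTH = 4  # ground truth boxes reported for the same image


def count_at_threshold(confs, threshold):
    """Number of detections whose confidence is at least `threshold`."""
    return sum(c >= threshold for c in confs)


for threshold in (0.10, 0.25, 0.50, 0.80):
    kept = count_at_threshold(confidences, threshold)
    print(f"conf >= {threshold:.2f}: {kept} detections (ground truth: {NUM_GROUND_TRUTH})")
```

For this image, 7 detections remain at the default 0.25 and 3 at 0.5, so a cutoff somewhere above 0.4 brings the count closest to the 4 labeled objects. Checking whether the surviving boxes actually overlap the ground truth still requires an IoU-based match.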

