# YOLOv8 Strawberry Detection Training - Google Colab Version

This notebook trains a YOLOv8 model for strawberry detection using Google Colab's free GPU.

## Setup

1. Connect to a GPU runtime: Runtime → Change runtime type → GPU
2. Mount your Google Drive if your dataset is stored there

In [None]:
# Check GPU availability
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
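# Install dependencies. Colab already ships torch and torchvision, and ultralytics
# pulls in OpenCV, matplotlib and tqdm, so only ultralytics (plus tensorboard) is listed.
!pip install -q ultralytics tensorboard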

## Dataset Setup

- Option 1: Upload the dataset to Colab (temporary; deleted when the session ends)
- Option 2: Mount Google Drive for persistent storage (recommended)
- Option 3: Download the dataset from a URL

In [None]:
# Option 2: Mount Google Drive (recommended)
from google.colab import drive
drive.mount('/content/drive')

# Update this path to your dataset location in Google Drive
DATASET_PATH = '/content/drive/MyDrive/strawberry-dataset/straw-detect.v1-straw-detect.yolov8'

# Or use direct path if dataset is uploaded to Colab
# DATASET_PATH = '/content/straw-detect.v1-straw-detect.yolov8'
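Option 3 has no cell above. A minimal sketch, assuming the dataset is published as a single zip archive (for example a Roboflow YOLOv8 export) at a URL you supply. `download_and_extract` and the URL in the usage line are hypothetical, not part of any library. The helper uses only the standard library and returns the folder that contains `data.yaml`, so its result can be assigned to `DATASET_PATH`.

```python
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url: str, dest_dir: str) -> Path:
    """Download a zip archive from `url`, unpack it into `dest_dir`, and
    return the directory containing data.yaml (hypothetical helper)."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    archive = dest / "dataset.zip"

    # Fetch the archive to a local file
    urllib.request.urlretrieve(url, archive)

    # Unpack, then delete the archive to save disk space
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    archive.unlink()

    # Archives may or may not wrap everything in a top-level folder,
    # so search for data.yaml and take the shallowest match
    matches = list(dest.rglob("data.yaml"))
    if not matches:
        raise FileNotFoundError(f"No data.yaml found after extracting {url}")
    return min(matches, key=lambda p: len(p.parts)).parent


# Usage (placeholder URL; replace it with your own export link):
# DATASET_PATH = str(download_and_extract("https://example.com/straw-detect.zip", "/content/dataset"))
```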

In [None]:
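# Validate dataset structure
import yaml
from pathlib import Path

dataset_path = Path(DATASET_PATH)
data_yaml = dataset_path / 'data.yaml'
IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp'}

def resolve_split(entry):
    # Split paths in data.yaml are relative to the dataset root (a single path string
    # is assumed). Roboflow exports often write '../train/images', so fall back to
    # dropping the '../' prefix, similar to how Ultralytics resolves these paths.
    p = (dataset_path / entry).resolve()
    if not p.exists() and entry.startswith('../'):
        p = (dataset_path / entry[3:]).resolve()
    return p

def count_images(folder):
    # Case-insensitive extension match, so .JPG files are counted too
    return sum(1 for f in folder.rglob('*') if f.suffix.lower() in IMG_EXTS)

if not data_yaml.exists():
    raise FileNotFoundError(f"data.yaml not found at {data_yaml}; check DATASET_PATH")

print(f"✓ Found data.yaml at {data_yaml}")
with open(data_yaml, 'r') as f:
    data = yaml.safe_load(f)

print("Dataset info:")
print(f"  Classes: {data['nc']}")
print(f"  Names: {data['names']}")

# Check image counts, and warn instead of silently skipping a missing split
for split in ('train', 'val'):
    split_path = resolve_split(data[split])
    if split_path.exists():
        print(f"  {split} images: {count_images(split_path)} ({split_path})")
    else:
        print(f"  WARNING: {split} path not found: {split_path}")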
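The cell above counts images but not labels. YOLO-format datasets keep one `.txt` label file per image in a `labels/` folder beside `images/`, with the same file stem; Ultralytics reports images without a label file as missing labels and trains on them as background, so a broken export can train on empty targets without failing. A minimal sketch of a stricter check, assuming that standard `images/` and `labels/` layout and detection (box) labels rather than segmentation polygons. `check_split` is a hypothetical helper.

```python
from pathlib import Path

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def check_split(images_dir, num_classes):
    """Return (missing_labels, bad_lines) for one dataset split.

    Assumes <split>/images/x.jpg pairs with <split>/labels/x.txt and that each
    label line is 'class x_center y_center width height', normalized to [0, 1].
    """
    images_dir = Path(images_dir)
    labels_dir = images_dir.parent / "labels"
    missing, bad = [], []
    for img in sorted(images_dir.iterdir()):
        if img.suffix.lower() not in IMG_EXTS:
            continue
        label = labels_dir / f"{img.stem}.txt"
        if not label.exists():
            missing.append(img.name)
            continue
        for n, line in enumerate(label.read_text().splitlines(), start=1):
            parts = line.split()
            if not parts:
                continue  # blank lines are harmless
            ok = len(parts) == 5
            if ok:
                try:
                    cls = float(parts[0])
                    coords = [float(v) for v in parts[1:]]
                    ok = (cls.is_integer() and 0 <= cls < num_classes
                          and all(0.0 <= v <= 1.0 for v in coords))
                except ValueError:
                    ok = False
            if not ok:
                bad.append(f"{label.name}:{n}: {line.strip()}")
    return missing, bad


# Usage, e.g. for a Roboflow export with train/valid/test folders:
# missing, bad = check_split(dataset_path / 'valid' / 'images', data['nc'])
# print(len(missing), "images without labels;", len(bad), "malformed label lines")
```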

## Training Configuration

In [None]:
# Training parameters
import os

EPOCHS = 100
IMG_SIZE = 640
BATCH_SIZE = 16  # Lower it (e.g. 8) on CUDA out-of-memory errors; -1 lets Ultralytics pick a size automatically
MODEL_NAME = 'yolov8n'  # yolov8n, yolov8s, yolov8m, yolov8l, yolov8x (smallest/fastest to largest/most accurate)

# Output directories. Files under /content are deleted when the session ends;
# point these at /content/drive/MyDrive/... to keep them.
RESULTS_DIR = '/content/strawberry-results'
WEIGHTS_DIR = '/content/strawberry-weights'

os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(WEIGHTS_DIR, exist_ok=True)

print(f"Results will be saved to: {RESULTS_DIR}")
print(f"Weights will be saved to: {WEIGHTS_DIR}")

## Train YOLOv8 Model

In [None]:
from ultralytics import YOLO
import torch

# Load pretrained model
print(f"Loading {MODEL_NAME} model...")
model = YOLO(f'{MODEL_NAME}.pt')

# Train the model
print(f"Starting training for {EPOCHS} epochs...")
results = model.train(
    data=str(data_yaml),
    epochs=EPOCHS,
    imgsz=IMG_SIZE,
    batch=BATCH_SIZE,
    device='0' if torch.cuda.is_available() else 'cpu',
    project=RESULTS_DIR,
    name='strawberry_detection',
    exist_ok=True,
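    patience=20,  # Early stopping: stop once validation fitness hasn't improved for 20 epochs
    save=True,
    save_period=10,  # Save checkpoint every 10 epochs
    cache=True,  # Cache images in RAM for faster epochs; use 'disk' or False if RAM is tight
    verbose=True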
)
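With `patience=20`, training ends early once the validation fitness has not improved for 20 consecutive epochs, and the best epoch (not the last) is kept as `best.pt`. The rule can be modeled in a few lines. This is a simplified sketch of patience-based early stopping, not Ultralytics' own code; Ultralytics computes the per-epoch fitness from validation mAP (weighted toward mAP50-95), and the curve in the usage line is illustrative only.

```python
def early_stop_epoch(fitness_per_epoch, patience):
    """Return (stop_epoch, best_epoch), both 0-based, for a run that stops
    once `patience` epochs pass without a new best fitness."""
    best_fitness = float("-inf")
    best_epoch = 0
    for epoch, fitness in enumerate(fitness_per_epoch):
        if fitness > best_fitness:
            best_fitness, best_epoch = fitness, epoch
        # Stop once `patience` epochs have elapsed since the last improvement
        if epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(fitness_per_epoch) - 1, best_epoch  # all epochs ran


# Illustrative fitness curve: improves until epoch 3, then plateaus
curve = [0.30, 0.42, 0.50, 0.55, 0.54, 0.55, 0.53, 0.52, 0.54, 0.51]
print(early_stop_epoch(curve, patience=3))  # → (6, 3)
```

This is also why raising EPOCHS alone may not lengthen training: if the curve has plateaued, the patience rule stops the run first.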

## Save and Export Model

In [None]:
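# Save final model. In recent Ultralytics versions, model.train() reloads the best
# checkpoint into `model`; the same weights are also at
# RESULTS_DIR/strawberry_detection/weights/best.pt
final_model_path = f'{WEIGHTS_DIR}/strawberry_{MODEL_NAME}.pt'
model.save(final_model_path)
print(f"✓ Model saved to: {final_model_path}")

# Export to ONNX format (dynamic=True allows variable input sizes)
print("Exporting to ONNX format...")
onnx_path = model.export(format='onnx', imgsz=IMG_SIZE, dynamic=True)
print(f"✓ ONNX model exported to: {onnx_path}")

# Export to TensorFlow Lite format (for Raspberry Pi). The first run may install
# extra conversion dependencies, so this step can take several minutes.
print("Exporting to TensorFlow Lite format...")
tflite_path = model.export(format='tflite', imgsz=IMG_SIZE)
print(f"✓ TFLite model exported to: {tflite_path}")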
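Both exports are built for an `IMG_SIZE` input, so on the Pi each camera frame has to be resized the way Ultralytics does it for prediction: scaled to fit while keeping its aspect ratio, then padded (letterboxed). Boxes predicted in the padded frame must be mapped back to the original frame before you use them, for example to aim the arm. A minimal sketch of just the geometry, assuming centered padding to a square (Ultralytics' `LetterBox` centers by default); `letterbox_params` and `to_original` are hypothetical helper names.

```python
def letterbox_params(orig_w, orig_h, size=640):
    """Scale factor and (pad_x, pad_y) offsets that fit an orig_w x orig_h
    image into a size x size square without distortion (centered padding)."""
    scale = min(size / orig_w, size / orig_h)
    new_w, new_h = round(orig_w * scale), round(orig_h * scale)
    pad_x = (size - new_w) / 2
    pad_y = (size - new_h) / 2
    return scale, pad_x, pad_y


def _clip(value, upper):
    """Clamp a coordinate to the range [0, upper]."""
    return max(0.0, min(value, float(upper)))


def to_original(box, scale, pad_x, pad_y, orig_w, orig_h):
    """Map an (x1, y1, x2, y2) box from letterboxed coordinates back to the
    original image, clipping it to the image bounds."""
    x1, y1, x2, y2 = box
    x1, x2 = [(v - pad_x) / scale for v in (x1, x2)]
    y1, y2 = [(v - pad_y) / scale for v in (y1, y2)]
    return (_clip(x1, orig_w), _clip(y1, orig_h), _clip(x2, orig_w), _clip(y2, orig_h))


# A 1280x720 frame: scale 0.5 gives 640x360 content with 140 px of padding top and bottom
scale, px, py = letterbox_params(1280, 720, 640)
print(scale, px, py)                                                 # → 0.5 0.0 140.0
print(to_original((100, 200, 300, 400), scale, px, py, 1280, 720))  # → (200.0, 120.0, 600.0, 520.0)
```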

## View Training Results

In [None]:
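# List the plots Ultralytics saved for this run
from pathlib import Path

# Ultralytics writes its plots (results.png, confusion_matrix.png, PR/F1 curves, ...)
# directly into the run directory, not into a 'plots' subfolder
results_dir = Path(RESULTS_DIR) / 'strawberry_detection'

plot_files = sorted(results_dir.glob('*.png')) if results_dir.exists() else []
if plot_files:
    print("Training plots:")
    for plot_file in plot_files:
        print(f"  - {plot_file.name}")
else:
    print(f"No plots found in {results_dir}. Check that training finished and RESULTS_DIR is correct.")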
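Alongside the plots, Ultralytics writes per-epoch metrics to `results.csv` in the same run directory. A minimal sketch that finds the epoch with the highest mAP50-95, assuming the column names Ultralytics uses for detection runs (`epoch`, `metrics/mAP50(B)`, `metrics/mAP50-95(B)`); some versions pad the header names with spaces, so they are stripped first. `best_epoch_from_csv` is a hypothetical helper.

```python
import csv
import io


def best_epoch_from_csv(csv_text, metric="metrics/mAP50-95(B)"):
    """Return (epoch, value) for the row with the highest `metric`."""
    reader = csv.reader(io.StringIO(csv_text))
    header = [h.strip() for h in next(reader)]  # tolerate padded column names
    col = header.index(metric)
    ep_col = header.index("epoch")
    best = None
    for row in reader:
        if not row:
            continue  # skip blank trailing lines
        value = float(row[col])  # float() ignores surrounding whitespace
        if best is None or value > best[1]:
            best = (int(float(row[ep_col])), value)
    return best


# Usage in this notebook:
# text = (results_dir / 'results.csv').read_text()
# print(best_epoch_from_csv(text))
```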

In [None]:
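# Show confusion matrix (saved in the run directory alongside the other plots)
from pathlib import Path
from IPython.display import Image, display

confusion_matrix = Path(RESULTS_DIR) / 'strawberry_detection' / 'confusion_matrix.png'
if confusion_matrix.exists():
    print("Confusion Matrix:")
    display(Image(filename=str(confusion_matrix)))
else:
    print(f"Confusion matrix not found at {confusion_matrix}")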

## Download Trained Model

Download the trained model to your local machine:

In [None]:
from google.colab import files

# Download PyTorch model
files.download(final_model_path)

# Download ONNX model
files.download(str(onnx_path))

# Download TFLite model
files.download(str(tflite_path))

## Next Steps

1. Copy the trained model to your Raspberry Pi
2. Use the `detect_realtime_pi.py` script for inference
3. Integrate with your robotic arm control system
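`detect_realtime_pi.py` is not part of this notebook. If it loads the exported file through Ultralytics (e.g. `YOLO('best.onnx')`), decoding is handled for you; if you instead run the ONNX file directly with a runtime such as onnxruntime, you have to decode the raw output yourself. A minimal sketch, assuming the standard YOLOv8 detection head layout: one array of shape `(1, 4 + nc, N)` whose first four rows are box centers and sizes (`cx, cy, w, h`) in input-pixel coordinates and whose remaining `nc` rows are per-class scores, with no separate objectness score. `decode_yolov8` and `nms` are hypothetical helpers; the NMS is a plain greedy reimplementation, not a library call.

```python
import numpy as np


def nms(boxes, scores, iou_thres):
    """Greedy non-maximum suppression on (x1, y1, x2, y2) boxes; returns kept indices."""
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the current top box with every remaining box
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter + 1e-9)
        order = rest[iou <= iou_thres]  # drop boxes that overlap the kept one too much
    return keep


def decode_yolov8(output, conf_thres=0.25, iou_thres=0.45):
    """Turn a raw (1, 4 + nc, N) YOLOv8 output into a list of
    (x1, y1, x2, y2, score, class_id) tuples in input-pixel coordinates."""
    preds = output[0].T                      # (N, 4 + nc): one row per candidate
    class_scores = preds[:, 4:]
    class_ids = class_scores.argmax(axis=1)
    scores = class_scores.max(axis=1)
    mask = scores >= conf_thres
    preds, class_ids, scores = preds[mask], class_ids[mask], scores[mask]
    cx, cy, w, h = preds[:, 0], preds[:, 1], preds[:, 2], preds[:, 3]
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # Class-agnostic NMS: fine for a single-class model; multi-class models
    # usually apply NMS per class instead
    keep = nms(boxes, scores, iou_thres)
    return [(*boxes[k].tolist(), float(scores[k]), int(class_ids[k])) for k in keep]
```

The returned boxes are in letterboxed input coordinates, so they still need mapping back to the camera frame before use.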

## Tips for Better Results

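- If accuracy is low, increase EPOCHS to 150-200. Training also stops early after `patience=20` epochs without improvement, so raise `patience` too or the extra epochs may never run
- Try larger MODEL_NAME sizes such as yolov8s or yolov8m (slower, but usually more accurate)
- Collect more training images with varied lighting conditions
- Tune data augmentation: Ultralytics already applies mosaic, HSV color jitter and horizontal flips by default, and their strengths can be adjusted through `model.train()` arguments such as `hsv_v`, `degrees`, `fliplr` and `mosaic`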
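Augmentation strengths can be passed straight to `model.train()` as keyword arguments. A sketch, assuming the argument names and defaults from Ultralytics' default training config (`hsv_v` 0.4, `hsv_s` 0.7, `degrees` 0.0, `fliplr` 0.5, `mosaic` 1.0); the override values below are illustrative starting points for varied lighting, not settings tuned on this dataset.

```python
# Illustrative augmentation overrides (untested starting points, not tuned values).
# Ultralytics defaults: hsv_v=0.4, hsv_s=0.7, degrees=0.0, fliplr=0.5, mosaic=1.0
AUG_OVERRIDES = {
    "hsv_v": 0.6,     # stronger brightness jitter for varied lighting
    "hsv_s": 0.7,     # saturation jitter (default)
    "degrees": 10.0,  # random rotation of up to +/-10 degrees
    "fliplr": 0.5,    # horizontal flip probability (default)
    "mosaic": 1.0,    # mosaic probability (default)
}

# Pass them alongside the other training arguments, e.g.:
# results = model.train(data=str(data_yaml), epochs=EPOCHS, imgsz=IMG_SIZE, **AUG_OVERRIDES)
```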