# In [None]:  (Jupyter notebook cell exported as a script)
import os
import zipfile

import matplotlib.pyplot as plt
import pandas as pd
import requests
from torch.utils.data import random_split
from torchvision import datasets, transforms
from tqdm import tqdm
from ultralytics import YOLO

# ======================================================================
# 1. Download COCO Dataset Directly from Official Website
# ======================================================================
def download_and_extract(url, save_dir, zip_name):
    """Download a ZIP archive from ``url`` and unpack it into ``save_dir``.

    The archive is streamed to ``save_dir/zip_name`` with a progress bar,
    extracted in place, and the ZIP file is deleted after a successful
    extraction.

    Args:
        url: HTTP(S) address of the ZIP archive.
        save_dir: Directory that receives the archive contents; created if
            it does not exist.
        zip_name: File name used for the temporary ZIP inside ``save_dir``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeouts.
        zipfile.BadZipFile: If the downloaded file is not a valid ZIP.
    """
    os.makedirs(save_dir, exist_ok=True)
    zip_path = os.path.join(save_dir, zip_name)

    # 1 MiB chunks: the COCO archives are multi-GB, so 1 KB chunks would mean
    # millions of tiny writes and progress-bar updates.
    block_size = 1024 * 1024

    # The context manager releases the connection even if the download fails;
    # the timeout is (connect, read-between-chunks), not a total-transfer limit.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        # Fail early instead of saving an HTTP error page as a ".zip".
        response.raise_for_status()

        # A missing Content-Length means the size is unknown, not zero.
        total_size = int(response.headers.get('content-length', 0)) or None

        with open(zip_path, 'wb') as f, tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            desc=f"Downloading {zip_name}",
        ) as progress_bar:
            for data in response.iter_content(block_size):
                f.write(data)
                progress_bar.update(len(data))

    # Extract
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(save_dir)
    os.remove(zip_path)  # Delete ZIP after extraction

# COCO Dataset URLs (2017 version)
coco_urls = {
    "train_images": "http://images.cocodataset.org/zips/train2017.zip",
    "val_images": "http://images.cocodataset.org/zips/val2017.zip",
    "annotations": "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
}

# Download and extract all files
data_dir = "data/coco"
for key, url in coco_urls.items():
    if "images" in key:
        # Image archives unpack to "train2017/" / "val2017/", giving
        # data/coco/images/train2017 and data/coco/images/val2017.
        save_path = os.path.join(data_dir, "images")
    else:
        # The annotations archive already contains a top-level "annotations/"
        # folder. Extracting at data_dir yields
        # data/coco/annotations/instances_train2017.json; extracting into
        # data_dir/annotations would nest it as annotations/annotations/...
        save_path = data_dir
    download_and_extract(url, save_path, f"{key}.zip")

# ======================================================================
# 2. Prepare Dataset for Training
# ======================================================================
# Image preprocessing pipeline: resize to 640x640, convert to a tensor, then
# normalise each channel with the mean/std listed below.
preprocessing_steps = [
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Wrap the extracted train2017 images and their instance annotations.
# NOTE(review): `transform` is passed as the image transform only, so the
# bounding boxes presumably stay in original pixel coordinates after the
# resize -- confirm before feeding these samples to a detector.
image_root = os.path.join(data_dir, "images/train2017")
annotation_file = os.path.join(data_dir, "annotations/instances_train2017.json")
dataset = datasets.CocoDetection(
    root=image_root,
    annFile=annotation_file,
    transform=transform,
)

# Random 80% / 10% / 10% split; the test share absorbs the rounding remainder
# so the three lengths always add up to the dataset size.
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = int(0.1 * num_samples)
test_size = num_samples - train_size - val_size
split_lengths = [train_size, val_size, test_size]
train_set, val_set, test_set = random_split(dataset, split_lengths)

# ======================================================================
# 3. Analyze Class Distribution (Visualization)
# ======================================================================
# For every category id, collect its display name and the number of
# annotations that reference it.
coco_api = dataset.coco
cat_ids = coco_api.getCatIds()
categories = []
category_counts = []
for cat_id in cat_ids:
    category_record = coco_api.loadCats(cat_id)[0]
    categories.append(category_record['name'])
    annotation_ids = coco_api.getAnnIds(catIds=cat_id)
    category_counts.append(len(annotation_ids))

# One bar per category; labels are rotated so the names do not overlap.
plt.figure(figsize=(15, 5))
plt.bar(categories, category_counts)
plt.xticks(rotation=90)
plt.title("COCO Class Distribution")
# Save before show() so the written file contains the drawn figure.
plt.savefig("class_distribution.png")
plt.show()

# ======================================================================
# 4. Train YOLOv5 Model
# ======================================================================
# Start from pretrained weights rather than a random initialisation.
model = YOLO('yolov5s.pt')

# Training configuration.
# NOTE(review): 'coco.yaml' is resolved by Ultralytics and presumably points
# at its own dataset location, not at data_dir or the random_split subsets
# built above -- confirm which data is actually used.
train_args = dict(
    data='coco.yaml',  # dataset config file
    epochs=50,
    imgsz=640,
    batch=16,
    plots=True,        # save training plots
    val=True,          # run validation during training
)

# Train the model
results = model.train(**train_args)

# ======================================================================
# 5. Evaluate Model
# ======================================================================
# Validate with the model's current data config.
# NOTE(review): this presumably evaluates the validation split declared in
# the data YAML, not the `test_set` created by random_split -- confirm.
metrics = model.val()

# `box.map50` is mAP at IoU 0.5 (`box.map` is mAP@0.5:0.95), and `box.mp` /
# `box.mr` are the mean precision / recall over classes. `box.p` / `box.r`
# are per-class arrays and cannot be formatted with ':.2f'.
print(f"mAP@0.5: {metrics.box.map50:.2f}, Precision: {metrics.box.mp:.2f}, Recall: {metrics.box.mr:.2f}")

# Plot training curves
# The object returned by model.train() only carries final scalar metrics; the
# per-epoch history is written to results.csv in the training run directory.
results_csv = os.path.join(str(model.trainer.save_dir), "results.csv")
results_df = pd.read_csv(results_csv)
# Some Ultralytics versions pad the CSV header names with spaces.
results_df.columns = results_df.columns.str.strip()

# There is no single "loss" column: the loss is logged per component
# (e.g. box/cls/dfl), so sum the components to get one curve per split.
train_loss_cols = [c for c in results_df.columns if c.startswith('train/') and c.endswith('loss')]
val_loss_cols = [c for c in results_df.columns if c.startswith('val/') and c.endswith('loss')]

plt.figure()
plt.plot(results_df[train_loss_cols].sum(axis=1), label='Training Loss')
plt.plot(results_df[val_loss_cols].sum(axis=1), label='Validation Loss')
plt.title("Training & Validation Loss")
plt.legend()
plt.savefig("loss_curves.png")
plt.show()

# ======================================================================
# 6. Run Inference on a Satellite Image
# ======================================================================

# Keep only detections with confidence >= 0.5; predict() returns one Results
# object per input image.
results = model.predict('satellite_image.jpg', conf=0.5)
results[0].save('detection_output.jpg')

# Print detected objects
# Ultralytics Results objects have no `.pandas()` accessor (that is the
# YOLOv5 torch-hub API), so the table is built from the Boxes attributes.
detection = results[0]
boxes = detection.boxes
rows = []
for (xmin, ymin, xmax, ymax), confidence, class_id in zip(
    boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
):
    rows.append({
        'name': detection.names[int(class_id)],  # class index -> label
        'confidence': confidence,
        'xmin': xmin,
        'ymin': ymin,
        'xmax': xmax,
        'ymax': ymax,
    })

# Explicit columns keep the header stable even when nothing was detected.
print(pd.DataFrame(rows, columns=['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']))