# 🏌️ SwingAI ML Training Pipeline
**3 models for iOS golf swing analysis**

| Model | Task | Output |
|---|---|---|
| Shaft Segmentation | Instance Seg | Shaft mask → angle extraction |
| Club Head Detector | Object Det | Club head bbox → speed calc |
| Phase Classifier | Classification | P1-P8 swing phase |

**Runtime:** ~2 hours on A100 | **Output:** 3x `.mlpackage` CoreML models

⚠️ **Set Runtime to GPU A100:** Runtime → Change runtime type → A100

In [None]:
# ============================================
# CELL 1: Install dependencies
# ============================================
!pip install -q ultralytics roboflow coremltools

import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')

In [None]:
# ============================================
# CELL 2: Set your Roboflow API key
# ============================================
# Get your free API key at: https://app.roboflow.com/settings/api
# The key is taken from the ROBOFLOW_API_KEY environment variable when that is
# set, so the secret does not have to be saved inside the notebook. Otherwise
# replace the placeholder below with your key.
import os

ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "YOUR_API_KEY_HERE")  # <-- PASTE YOUR KEY

---
## 🔧 Model 1: Shaft Segmentation
**Datasets:** Segmentation Batch 13 (9.6K) + Golf Segmentation (3.2K) + Golf-Swing-Analyzer-DTL (1.5K)

**Output:** Pixel mask of club shaft → extract angle with `cv2.fitLine()`

In [None]:
# ============================================
# CELL 3: Download shaft segmentation datasets
# ============================================
from roboflow import Roboflow
import os, shutil

rf = Roboflow(api_key=ROBOFLOW_API_KEY)

# Dataset 1: Segmentation Batch 13 (9.6K images, shaft/arm/body masks)
# This is the base dataset for the merge step, so a failure here is
# deliberately NOT caught: there is nothing to train on without it.
print('üì• Downloading Segmentation Batch 13...')
ds1 = rf.workspace('fp-srwrm').project('segmentation-batch-13-sequence').version(2).download('yolov8', location='/content/shaft_ds1')

# Dataset 2: Golf Segmentation (3.2K images, hand/head/shaft)
# Optional: catch Exception (not a bare except) so a kernel interrupt still
# stops the cell, and report why the download failed.
print('üì• Downloading Golf Segmentation...')
try:
    ds2 = rf.workspace('golf-bpbfr').project('golf-segmentation-oocvp').version(1).download('yolov8', location='/content/shaft_ds2')
except Exception as exc:
    print('‚ö†Ô∏è Golf Segmentation dataset not accessible ‚Äî continuing with primary dataset')
    print(f'   Reason: {exc}')
    ds2 = None

# Dataset 3: Golf-Swing-Analyzer-DTL (1.5K images, hand/head/shaft)
# Optional, handled the same way as dataset 2.
print('üì• Downloading Golf-Swing-Analyzer-DTL...')
try:
    ds3 = rf.workspace('bosharluke').project('golf-swing-analyzer-dtl').version(1).download('yolov8', location='/content/shaft_ds3')
except Exception as exc:
    print('‚ö†Ô∏è DTL dataset not accessible ‚Äî continuing with available datasets')
    print(f'   Reason: {exc}')
    ds3 = None

print('‚úÖ Shaft datasets downloaded')

In [None]:
# ============================================
# CELL 4: Merge shaft datasets
# ============================================
import yaml, glob

# Use primary dataset as base
SHAFT_DATA = '/content/shaft_ds1'


def _read_class_names(dataset_dir):
    """Return the ordered class names declared in ``<dataset_dir>/data.yaml``.

    Both the list form and the ``{id: name}`` mapping form of ``names`` are
    handled. Returns an empty list when the file or the key is missing.
    """
    yaml_path = os.path.join(dataset_dir, 'data.yaml')
    if not os.path.exists(yaml_path):
        return []
    with open(yaml_path, encoding='utf-8') as fh:
        names = (yaml.safe_load(fh) or {}).get('names') or []
    if isinstance(names, dict):
        names = [names[key] for key in sorted(names)]
    return [str(name) for name in names]


def _extend_class_map(class_names, extra_names):
    """Map an extra dataset's class ids onto *class_names* by class name.

    Names not yet present are appended to *class_names* (mutated in place).
    Returns ``{extra_id: merged_id}``.
    """
    id_map = {}
    for extra_id, name in enumerate(extra_names):
        if name not in class_names:
            class_names.append(name)
        id_map[extra_id] = class_names.index(name)
    return id_map


def _copy_label_remapped(src_path, dst_dir, id_map):
    """Copy one label file into *dst_dir*, translating each row's class id.

    The first whitespace-separated token of every row is treated as the class
    id; ids that are not in *id_map* are written unchanged.
    """
    with open(src_path, encoding='utf-8') as fh:
        rows = [line.split() for line in fh if line.strip()]
    for row in rows:
        if row[0].isdigit():
            row[0] = str(id_map.get(int(row[0]), int(row[0])))
    dst_path = os.path.join(dst_dir, os.path.basename(src_path))
    with open(dst_path, 'w', encoding='utf-8') as fh:
        fh.writelines(' '.join(row) + '\n' for row in rows)


class_names = _read_class_names(SHAFT_DATA)
base_class_count = len(class_names)

# Merge additional datasets if available
for extra_ds in ['/content/shaft_ds2', '/content/shaft_ds3']:
    if not os.path.exists(extra_ds):
        continue

    # Each dataset numbers its classes independently, so label ids have to be
    # translated to the base dataset's numbering; a raw file copy would assign
    # masks to whatever class happens to share the same index.
    extra_names = _read_class_names(extra_ds)
    if class_names and extra_names:
        id_map = _extend_class_map(class_names, extra_names)
    else:
        id_map = {}
        print(f'   Class names unavailable for {extra_ds}; labels copied without remapping')

    for split in ['train', 'valid', 'test']:
        src_imgs = os.path.join(extra_ds, split, 'images')
        src_lbls = os.path.join(extra_ds, split, 'labels')
        dst_imgs = os.path.join(SHAFT_DATA, split, 'images')
        dst_lbls = os.path.join(SHAFT_DATA, split, 'labels')
        if not os.path.exists(src_imgs):
            continue
        # The base dataset may lack a split that an extra dataset provides.
        os.makedirs(dst_imgs, exist_ok=True)
        os.makedirs(dst_lbls, exist_ok=True)
        for f in glob.glob(os.path.join(src_imgs, '*')):
            shutil.copy2(f, dst_imgs)
        for f in glob.glob(os.path.join(src_lbls, '*')):
            if id_map and f.endswith('.txt'):
                _copy_label_remapped(f, dst_lbls, id_map)
            else:
                shutil.copy2(f, dst_lbls)
    print(f'‚úÖ Merged {extra_ds}')

# Persist classes contributed by the extra datasets so the remapped ids are
# declared in the data.yaml that training reads.
base_yaml = os.path.join(SHAFT_DATA, 'data.yaml')
if len(class_names) > base_class_count and os.path.exists(base_yaml):
    with open(base_yaml, encoding='utf-8') as fh:
        base_cfg = yaml.safe_load(fh) or {}
    base_cfg['names'] = class_names
    base_cfg['nc'] = len(class_names)
    with open(base_yaml, 'w', encoding='utf-8') as fh:
        yaml.safe_dump(base_cfg, fh, sort_keys=False)

# Count total images
train_imgs = len(glob.glob(os.path.join(SHAFT_DATA, 'train', 'images', '*')))
val_imgs = len(glob.glob(os.path.join(SHAFT_DATA, 'valid', 'images', '*')))
print(f'üìä Shaft dataset: {train_imgs} train + {val_imgs} val images')

In [None]:
# ============================================
# CELL 5: Train Shaft Segmentation Model
# ============================================
from ultralytics import YOLO

print('üèãÔ∏è Training Shaft Segmentation (YOLOv8m-seg)...')
print('‚è±Ô∏è Estimated time: ~60-90 min on A100')

shaft_model = YOLO('yolov8m-seg.pt')

shaft_results = shaft_model.train(
    data=os.path.join(SHAFT_DATA, 'data.yaml'),
    epochs=100,
    imgsz=640,
    batch=32,
    device=0,
    project='/content/runs',
    name='shaft_seg',
    # Reuse the same run directory when the cell is re-run. Otherwise the run
    # goes to an auto-incremented folder and the hard-coded best.pt path below
    # (also used by the export cell) would point at the previous run's weights.
    exist_ok=True,
    patience=15,       # early stopping
    augment=True,
    mosaic=1.0,
    # NOTE(review): flipud=0.5 turns half of the training images upside down;
    # confirm this is intended for footage of upright golfers.
    flipud=0.5,
    fliplr=0.5,
    degrees=15,
    scale=0.5,
    verbose=True
)

print('‚úÖ Shaft segmentation training complete!')
print('Best model: /content/runs/shaft_seg/weights/best.pt')

---
## 🎯 Model 2: Club Head Detector
**Datasets:** Golf Driver Tracker (2.7K) + SwingMentor Golf (8.5K) + Golf VFA (2.8K)

**Output:** Club head bounding box → track position → calculate speed

In [None]:
# ============================================
# CELL 6: Download club head datasets
# ============================================
# (title, workspace, project, version, download location, failure message)
clubhead_sources = [
    ('Golf Driver Tracker', 'salo-levy', 'golf-driver-tracker', 3,
     '/content/clubhead_ds1', '‚ö†Ô∏è Golf Driver Tracker not accessible'),
    ('SwingMentor Golf', 'swingmentor', 'golf-0okbs', 9,
     '/content/clubhead_ds2', '‚ö†Ô∏è SwingMentor not accessible'),
    ('Golf VFA', 'trungam', 'golf-vfa', 2,
     '/content/clubhead_ds3', '‚ö†Ô∏è Golf VFA not accessible'),
]

clubhead_downloads = []
for title, workspace, project, version, location, failure_msg in clubhead_sources:
    print(f'üì• Downloading {title}...')
    try:
        dataset = rf.workspace(workspace).project(project).version(version).download('yolov8', location=location)
    except Exception as exc:
        # Every source is optional. Catch Exception rather than using a bare
        # except so a kernel interrupt still stops the cell, and show the cause.
        print(failure_msg)
        print(f'   Reason: {exc}')
        dataset = None
    clubhead_downloads.append(dataset)

# Keep the original per-dataset names; a failed download is now None instead
# of leaving the name undefined.
ch1, ch2, ch3 = clubhead_downloads

if any(dataset is not None for dataset in clubhead_downloads):
    print('‚úÖ Club head datasets downloaded')
else:
    print('No club head dataset could be downloaded')

In [None]:
# ============================================
# CELL 7: Merge club head datasets
# ============================================
import yaml  # already imported by the shaft merge cell; repeated so this cell runs on its own

CLUBHEAD_DATA = '/content/clubhead_ds1'


def _read_class_names(dataset_dir):
    """Return the ordered class names declared in ``<dataset_dir>/data.yaml``.

    Both the list form and the ``{id: name}`` mapping form of ``names`` are
    handled. Returns an empty list when the file or the key is missing.
    """
    yaml_path = os.path.join(dataset_dir, 'data.yaml')
    if not os.path.exists(yaml_path):
        return []
    with open(yaml_path, encoding='utf-8') as fh:
        names = (yaml.safe_load(fh) or {}).get('names') or []
    if isinstance(names, dict):
        names = [names[key] for key in sorted(names)]
    return [str(name) for name in names]


def _extend_class_map(class_names, extra_names):
    """Map an extra dataset's class ids onto *class_names* by class name.

    Names not yet present are appended to *class_names* (mutated in place).
    Returns ``{extra_id: merged_id}``.
    """
    id_map = {}
    for extra_id, name in enumerate(extra_names):
        if name not in class_names:
            class_names.append(name)
        id_map[extra_id] = class_names.index(name)
    return id_map


def _copy_label_remapped(src_path, dst_dir, id_map):
    """Copy one label file into *dst_dir*, translating each row's class id.

    The first whitespace-separated token of every row is treated as the class
    id; ids that are not in *id_map* are written unchanged.
    """
    with open(src_path, encoding='utf-8') as fh:
        rows = [line.split() for line in fh if line.strip()]
    for row in rows:
        if row[0].isdigit():
            row[0] = str(id_map.get(int(row[0]), int(row[0])))
    dst_path = os.path.join(dst_dir, os.path.basename(src_path))
    with open(dst_path, 'w', encoding='utf-8') as fh:
        fh.writelines(' '.join(row) + '\n' for row in rows)


# Use first available dataset as base, merge others
available = [
    ds_path
    for ds_path in ['/content/clubhead_ds1', '/content/clubhead_ds2', '/content/clubhead_ds3']
    if os.path.exists(ds_path)
]
base_found = bool(available)
if base_found:
    CLUBHEAD_DATA = available[0]
else:
    print('No club head dataset found on disk; run the download cell first')

class_names = _read_class_names(CLUBHEAD_DATA)
base_class_count = len(class_names)

for ds_path in available[1:]:
    # Each dataset numbers its classes independently, so label ids have to be
    # translated to the base dataset's numbering; a raw file copy would assign
    # boxes to whatever class happens to share the same index.
    extra_names = _read_class_names(ds_path)
    if class_names and extra_names:
        id_map = _extend_class_map(class_names, extra_names)
    else:
        id_map = {}
        print(f'   Class names unavailable for {ds_path}; labels copied without remapping')

    for split in ['train', 'valid', 'test']:
        src_imgs = os.path.join(ds_path, split, 'images')
        src_lbls = os.path.join(ds_path, split, 'labels')
        dst_imgs = os.path.join(CLUBHEAD_DATA, split, 'images')
        dst_lbls = os.path.join(CLUBHEAD_DATA, split, 'labels')
        if not os.path.exists(src_imgs):
            continue
        os.makedirs(dst_imgs, exist_ok=True)
        os.makedirs(dst_lbls, exist_ok=True)
        for f in glob.glob(os.path.join(src_imgs, '*')):
            shutil.copy2(f, dst_imgs)
        for f in glob.glob(os.path.join(src_lbls, '*')):
            if id_map and f.endswith('.txt'):
                _copy_label_remapped(f, dst_lbls, id_map)
            else:
                shutil.copy2(f, dst_lbls)
    print(f'‚úÖ Merged {ds_path}')

# Persist classes contributed by the extra datasets so the remapped ids are
# declared in the data.yaml that training reads.
base_yaml = os.path.join(CLUBHEAD_DATA, 'data.yaml')
if len(class_names) > base_class_count and os.path.exists(base_yaml):
    with open(base_yaml, encoding='utf-8') as fh:
        base_cfg = yaml.safe_load(fh) or {}
    base_cfg['names'] = class_names
    base_cfg['nc'] = len(class_names)
    with open(base_yaml, 'w', encoding='utf-8') as fh:
        yaml.safe_dump(base_cfg, fh, sort_keys=False)

train_imgs = len(glob.glob(os.path.join(CLUBHEAD_DATA, 'train', 'images', '*')))
val_imgs = len(glob.glob(os.path.join(CLUBHEAD_DATA, 'valid', 'images', '*')))
print(f'üìä Club head dataset: {train_imgs} train + {val_imgs} val images')

In [None]:
# ============================================
# CELL 8: Train Club Head Detector
# ============================================
print('üèãÔ∏è Training Club Head Detector (YOLOv8m)...')
print('‚è±Ô∏è Estimated time: ~45-60 min on A100')

clubhead_model = YOLO('yolov8m.pt')

clubhead_results = clubhead_model.train(
    data=os.path.join(CLUBHEAD_DATA, 'data.yaml'),
    epochs=100,
    imgsz=640,
    batch=32,
    device=0,
    project='/content/runs',
    name='clubhead_det',
    # Reuse the same run directory when the cell is re-run. Otherwise the run
    # goes to an auto-incremented folder and the hard-coded best.pt path below
    # (also used by the export cell) would point at the previous run's weights.
    exist_ok=True,
    patience=15,       # early stopping
    augment=True,
    mosaic=1.0,
    verbose=True
)

print('‚úÖ Club head detector training complete!')
print('Best model: /content/runs/clubhead_det/weights/best.pt')

---
## 📊 Model 3: Phase Classifier (SwingNet Boost)
**Datasets:** Golf_Swing_Phases_8 (4.6K) + golf swing 2 (6.8K)

**Output:** P1-P8 classification per frame

In [None]:
# ============================================
# CELL 9: Download phase classifier datasets
# ============================================
# Both sources are optional. Failures are caught with `except Exception` (not
# a bare except) so a kernel interrupt still stops the cell, the cause is
# printed, and the variable is set to None instead of being left undefined.
print('üì• Downloading Golf_Swing_Phases_8...')
try:
    pc1 = rf.workspace('container-number-dectection').project('golf_swing_phases_8-mrk0i').version(1).download('folder', location='/content/phase_ds1')
except Exception as exc:
    print('‚ö†Ô∏è Phases dataset not accessible')
    print(f'   Reason: {exc}')
    pc1 = None

print('üì• Downloading golf swing 2...')
try:
    pc2 = rf.workspace('pose-7amrv').project('golf-swing-2-bycnn').version(1).download('folder', location='/content/phase_ds2')
except Exception as exc:
    print('‚ö†Ô∏è Golf swing 2 not accessible')
    print(f'   Reason: {exc}')
    pc2 = None

# Only claim success when at least one dataset actually arrived.
if pc1 is not None or pc2 is not None:
    print('‚úÖ Phase classifier datasets downloaded')
else:
    print('No phase classifier dataset could be downloaded')

In [None]:
# ============================================
# CELL 10: Train Phase Classifier
# ============================================
print('üèãÔ∏è Training Phase Classifier (YOLOv8m-cls)...')
print('‚è±Ô∏è Estimated time: ~30-45 min on A100')

# Find the dataset that downloaded successfully: the first candidate folder
# that exists on disk is used as the classification dataset root.
PHASE_DATA = next(
    (p for p in ['/content/phase_ds1', '/content/phase_ds2'] if os.path.exists(p)),
    None,
)

if PHASE_DATA:
    phase_model = YOLO('yolov8m-cls.pt')

    phase_results = phase_model.train(
        data=PHASE_DATA,
        epochs=100,
        imgsz=224,
        batch=64,
        device=0,
        project='/content/runs',
        name='phase_cls',
        # Reuse the same run directory when the cell is re-run. Otherwise the
        # run goes to an auto-incremented folder and the hard-coded best.pt
        # path below (also used by the export cell) would be stale.
        exist_ok=True,
        patience=15,       # early stopping
        verbose=True
    )
    print('‚úÖ Phase classifier training complete!')
    print('Best model: /content/runs/phase_cls/weights/best.pt')
else:
    print('‚ùå No phase dataset available')

---
## 📱 Export to CoreML for iOS

In [None]:
# ============================================
# CELL 11: Export all models to CoreML
# ============================================
import glob

# Model name -> trained weights produced by the training cells.
models_to_export = {
    'shaft_seg': '/content/runs/shaft_seg/weights/best.pt',
    'clubhead_det': '/content/runs/clubhead_det/weights/best.pt',
    'phase_cls': '/content/runs/phase_cls/weights/best.pt',
}

# (name, path to .mlpackage) for every export that produced output.
exported = []
for name, path in models_to_export.items():
    if not os.path.exists(path):
        print(f'  ‚è≠Ô∏è Skipping {name} (not trained)')
        continue

    print(f'üì± Exporting {name} to CoreML...')
    model = YOLO(path)
    try:
        # Request embedded NMS only for the detector: it is a detection
        # post-processing step and is not applicable to the classifier.
        # NOTE(review): confirm segmentation CoreML export ignores `nms` in
        # the installed Ultralytics version.
        export_path = model.export(format='coreml', nms=(model.task == 'detect'))
    except Exception as exc:
        # One failed export must not prevent the remaining models from exporting.
        print(f'  {name} export failed: {exc}')
        continue

    # Prefer the path reported by the exporter; fall back to the conventional
    # sibling location next to the weights.
    mlpkg = str(export_path) if export_path else path.replace('.pt', '.mlpackage')
    if os.path.exists(mlpkg):
        exported.append((name, mlpkg))
        print(f'  ‚úÖ {name} ‚Üí {mlpkg}')
    else:
        print(f'  {name} export finished but {mlpkg} was not found')

print(f'\nüéâ Exported {len(exported)} CoreML models!')

In [None]:
# ============================================
# CELL 12: Package & download all models
# ============================================
!mkdir -p /content/swingai_models

# Copy all CoreML models to one folder
for name, mlpkg in exported:
    dst = f'/content/swingai_models/{name}.mlpackage'
    if os.path.exists(mlpkg):
        shutil.copytree(mlpkg, dst, dirs_exist_ok=True)

# Also copy the .pt files for future fine-tuning
for name, path in models_to_export.items():
    if os.path.exists(path):
        shutil.copy2(path, f'/content/swingai_models/{name}.pt')

# Zip everything
!cd /content && zip -r swingai_models.zip swingai_models/

print('\nüì¶ All models packaged!')
print('Download: /content/swingai_models.zip')
print('\nModels included:')
!ls -lh /content/swingai_models/

# Auto-download in Colab
try:
    from google.colab import files
    files.download('/content/swingai_models.zip')
except:
    print('\nüí° Download manually from the file browser on the left')

---
## 📊 Training Results Summary

In [None]:
# ============================================
# CELL 13: Print results summary
# ============================================
import pandas as pd  # imported once here instead of inside the loop

# (results.csv column, printed label). A metric is reported only when its
# column is present, so each model type prints just the metrics it has. The
# (M) mask columns are included so the segmentation model's mask quality is
# reported alongside its box metrics.
summary_metrics = [
    ('metrics/mAP50(B)', 'Best mAP50'),
    ('metrics/mAP50-95(B)', 'Best mAP50-95'),
    ('metrics/mAP50(M)', 'Best mask mAP50'),
    ('metrics/mAP50-95(M)', 'Best mask mAP50-95'),
    ('metrics/accuracy_top1', 'Best Top-1 Accuracy'),
]

print('=' * 60)
print('üèåÔ∏è SwingAI Training Results')
print('=' * 60)

for name, path in models_to_export.items():
    # results.csv sits in the run directory, one level above weights/.
    results_csv = path.replace('weights/best.pt', 'results.csv')
    if not os.path.exists(results_csv):
        continue
    df = pd.read_csv(results_csv)
    # Strip whitespace padding from the header names so the lookups below
    # match regardless of how the columns were written.
    df.columns = df.columns.str.strip()
    print(f'\nüìä {name}:')
    for column, label in summary_metrics:
        if column in df.columns:
            print(f'   {label}: {df[column].max():.3f}')
    print(f'   Epochs trained: {len(df)}')

print('\n' + '=' * 60)
print('\nüöÄ Next steps:')
print('1. Download swingai_models.zip')
print('2. Unzip and drag .mlpackage files into Xcode')
print('3. Add inference code to SwingAI app')
print('4. Ship it! üèåÔ∏è‚Äç‚ôÇÔ∏è')