# 🏌️ SwingAI Training v3
Set Runtime → A100, then Run All (⌘+F9)

Fixes: label cleaning + keep-alive

In [None]:
# Keep-alive: prevents idle timeout
import threading, time, IPython
def keep_alive():
    """Emit a timestamped ping every two minutes, forever.

    Each cycle sleeps 120 seconds, asks IPython to clear the current output
    (``wait=True``) and then prints the wall-clock time. The function never
    returns, so it is meant to be run on a background daemon thread.
    """
    ping_interval_s = 120
    while True:
        time.sleep(ping_interval_s)
        IPython.display.clear_output(wait=True)
        stamp = time.strftime("%H:%M:%S")
        print(f'‚è∞ Keep-alive ping: {stamp}')


# Run the pinger as a daemon thread so it never blocks interpreter shutdown.
threading.Thread(target=keep_alive, daemon=True).start()
print('‚úÖ Keep-alive started (pings every 2 min)')

In [None]:
!pip install -q ultralytics roboflow coremltools
import torch, os, shutil, glob
# Report the CUDA device the runtime exposes, or flag that none is attached.
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('NO GPU!')

In [None]:
import getpass
import os


def load_roboflow_api_key(environ=None, prompt=None):
    """Return the Roboflow API key without keeping it in the notebook source.

    Args:
        environ: Mapping to read ``ROBOFLOW_API_KEY`` from; defaults to
            ``os.environ``.
        prompt: Optional callable taking a prompt string and returning the
            typed key. It is only called when the mapping has no usable key.

    Returns:
        The key with surrounding whitespace removed, or ``''`` when no key
        could be obtained.
    """
    source = os.environ if environ is None else environ
    key = (source.get('ROBOFLOW_API_KEY') or '').strip()
    if not key and prompt is not None:
        key = (prompt('Roboflow API key: ') or '').strip()
    return key


# SECURITY: a live API key used to be hard-coded on this line. Anything that
# was committed or shared with the notebook must be treated as leaked and
# rotated in the Roboflow dashboard.
try:
    get_ipython  # only defined inside an IPython/Colab kernel
except NameError:
    # Plain Python process: never block waiting for input.
    _key_prompt = None
else:
    # Notebook kernel: ask interactively, with the input hidden.
    _key_prompt = getpass.getpass
ROBOFLOW_API_KEY = load_roboflow_api_key(prompt=_key_prompt)
if not ROBOFLOW_API_KEY:
    print('ROBOFLOW_API_KEY is not set; export it in the environment before downloading datasets')

In [None]:
from roboflow import Roboflow
# Authenticate once; the same client is reused by the later download cells.
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
print('üì• Downloading Segmentation Batch 13...')
# Primary shaft dataset: required, so a failure here is allowed to propagate.
ds1 = rf.workspace('fp-srwrm').project('segmentation-batch-13-sequence').version(2).download('yolov8', location='/content/shaft_ds1')
# Optional extra dataset: carry on without it when the download fails.
# `except Exception` (rather than a bare `except:`) keeps KeyboardInterrupt
# and SystemExit working, and the reason is printed instead of discarded.
try:
    rf.workspace('bosharluke').project('golf-swing-analyzer-dtl').version(1).download('yolov8', location='/content/shaft_ds3')
except Exception as err:
    print('‚ö†Ô∏è DTL not accessible')
    print(f'   ({type(err).__name__}: {err})')
SHAFT_DATA = '/content/shaft_ds1'
# Make sure every split has images/ and labels/ folders to merge into.
for split in ['train', 'valid', 'test']:
    os.makedirs(os.path.join(SHAFT_DATA, split, 'images'), exist_ok=True)
    os.makedirs(os.path.join(SHAFT_DATA, split, 'labels'), exist_ok=True)
# Copy each extra dataset that exists on disk into the primary one, split by
# split. copy2 overwrites a destination file that has the same name.
for extra in ['/content/shaft_ds3']:
    if os.path.exists(extra):
        for split in ['train', 'valid', 'test']:
            si = os.path.join(extra, split, 'images')
            sl = os.path.join(extra, split, 'labels')
            di = os.path.join(SHAFT_DATA, split, 'images')
            dl = os.path.join(SHAFT_DATA, split, 'labels')
            if os.path.exists(si):
                for f in glob.glob(os.path.join(si, '*')):
                    shutil.copy2(f, di)
            if os.path.exists(sl):
                for f in glob.glob(os.path.join(sl, '*')):
                    shutil.copy2(f, dl)
        print(f'‚úÖ Merged {extra}')
# Summarize the merged dataset size by counting image files per split.
tc = len(glob.glob(os.path.join(SHAFT_DATA, 'train', 'images', '*')))
vc = len(glob.glob(os.path.join(SHAFT_DATA, 'valid', 'images', '*')))
print(f'üìä Shaft: {tc} train + {vc} val')

In [None]:
# Clean corrupt segmentation labels (box without segment points)
def clean_seg_labels(data_dir):
    """Remove label rows that are too short to describe a polygon.

    Scans ``<data_dir>/{train,valid,test}/labels/*.txt``. A row is kept only
    when it has at least seven whitespace-separated tokens; shorter rows
    (including blank ones) are dropped. A file is rewritten only when at
    least one of its rows was dropped. Missing split folders are skipped.

    Args:
        data_dir: Root folder of the dataset.

    Returns:
        Total number of rows removed across all label files.
    """
    min_tokens = 7
    removed_total = 0
    for split_name in ('train', 'valid', 'test'):
        labels_path = os.path.join(data_dir, split_name, 'labels')
        if os.path.exists(labels_path):
            for label_file in glob.glob(os.path.join(labels_path, '*.txt')):
                with open(label_file, 'r') as src:
                    original_rows = src.readlines()
                kept_rows = []
                for row in original_rows:
                    if len(row.split()) >= min_tokens:
                        kept_rows.append(row)
                dropped = len(original_rows) - len(kept_rows)
                if dropped:
                    removed_total += dropped
                    with open(label_file, 'w') as dst:
                        dst.writelines(kept_rows)
    return removed_total

# Strip too-short label rows from the merged shaft dataset and report how
# many rows were removed.
n = clean_seg_labels(SHAFT_DATA)
print(f'üßπ Cleaned {n} corrupt label lines')

In [None]:
from ultralytics import YOLO
print('üèãÔ∏è Training Shaft Segmentation...')
shaft_model = YOLO('yolov8m-seg.pt')
shaft_model.train(data=os.path.join(SHAFT_DATA, 'data.yaml'), epochs=100, imgsz=640, batch=16, device=0, project='/content/runs', name='shaft_seg', patience=15, augment=True, mosaic=1.0, flipud=0.5, fliplr=0.5, degrees=15, scale=0.5, workers=4)
print('‚úÖ Shaft done!')

In [None]:
print('üì• Downloading club head datasets...')
CLUBHEAD_DATA = None
for ws, proj, ver, loc in [('salo-levy','golf-driver-tracker',3,'/content/ch1'),('swingmentor','golf-0okbs',9,'/content/ch2'),('trungam','golf-vfa',2,'/content/ch3')]:
    try:
        rf.workspace(ws).project(proj).version(ver).download('yolov8', location=loc)
        if CLUBHEAD_DATA is None:
            CLUBHEAD_DATA = loc
        else:
            for split in ['train','valid','test']:
                si=os.path.join(loc,split,'images'); sl=os.path.join(loc,split,'labels')
                di=os.path.join(CLUBHEAD_DATA,split,'images'); dl=os.path.join(CLUBHEAD_DATA,split,'labels')
                os.makedirs(di,exist_ok=True); os.makedirs(dl,exist_ok=True)
                if os.path.exists(si):
                    for f in glob.glob(os.path.join(si,'*')): shutil.copy2(f,di)
                if os.path.exists(sl):
                    for f in glob.glob(os.path.join(sl,'*')): shutil.copy2(f,dl)
    except:
        print(f'‚ö†Ô∏è {proj} not accessible')
if CLUBHEAD_DATA:
    print(f'üìä Club head: {len(glob.glob(os.path.join(CLUBHEAD_DATA,"train","images","*")))} train')
else:
    print('‚ùå No club head data')

In [None]:
# Train the club-head detector only when a club-head dataset was downloaded;
# results go to /content/runs/clubhead_det.
if CLUBHEAD_DATA:
    print('üèãÔ∏è Training Club Head Detector...')
    ch_model = YOLO('yolov8m.pt')
    ch_model.train(
        data=os.path.join(CLUBHEAD_DATA, 'data.yaml'),
        epochs=100,
        imgsz=640,
        batch=32,
        device=0,
        project='/content/runs',
        name='clubhead_det',
        patience=15,
        augment=True,
        mosaic=1.0,
        workers=4,
    )
    print('‚úÖ Club head done!')

In [None]:
print('üì• Downloading phase dataset...')
PHASE_DATA = None
try:
    rf.workspace('container-number-dectection').project('golf_swing_phases_8-mrk0i').version(1).download('folder', location='/content/phase_ds')
    PHASE_DATA = '/content/phase_ds'
except:
    print('‚ö†Ô∏è Phase dataset not accessible')
if PHASE_DATA:
    print(f'üìä Phase dataset ready')

In [None]:
# Train the swing-phase classifier only when the phase dataset was
# downloaded; results go to /content/runs/phase_cls.
if PHASE_DATA:
    print('üèãÔ∏è Training Phase Classifier...')
    p_model = YOLO('yolov8m-cls.pt')
    p_model.train(
        data=PHASE_DATA,
        epochs=100,
        imgsz=224,
        batch=64,
        device=0,
        project='/content/runs',
        name='phase_cls',
        patience=15,
        workers=4,
    )
    print('‚úÖ Phase classifier done!')

In [None]:
# Export to CoreML + package
models = {'shaft_seg':'/content/runs/shaft_seg/weights/best.pt','clubhead_det':'/content/runs/clubhead_det/weights/best.pt','phase_cls':'/content/runs/phase_cls/weights/best.pt'}
!mkdir -p /content/swingai_models
for name, path in models.items():
    if os.path.exists(path):
        print(f'üì± Exporting {name}...')
        m = YOLO(path)
        m.export(format='coreml', nms=True)
        shutil.copy2(path, f'/content/swingai_models/{name}.pt')
        mlpkg = path.replace('.pt','.mlpackage')
        if os.path.exists(mlpkg):
            shutil.copytree(mlpkg, f'/content/swingai_models/{name}.mlpackage', dirs_exist_ok=True)
        print(f'  ‚úÖ {name}')
!cd /content && zip -r swingai_models.zip swingai_models/
!ls -lh /content/swingai_models/
print('\nüéâ Done! Download swingai_models.zip')
try:
    from google.colab import files
    files.download('/content/swingai_models.zip')
except:
    pass