# 🏌️ SwingAI Model Test Suite
Upload a golf swing video → see all 3 models in action

In [None]:
!pip install -q ultralytics opencv-python-headless yt-dlp
from ultralytics import YOLO
import cv2, os, numpy as np, time
from IPython.display import display, Image, HTML
from google.colab import files
import matplotlib.pyplot as plt
print('‚úÖ Ready')

In [None]:
# Load models - if from the same session they're in /content/runs.
# Otherwise upload the .pt files when prompted.
# Candidate weight files per model, checked in order: the training-run output
# first, then a manually uploaded copy.
MODEL_PATHS = {
    'shaft': ['/content/runs/shaft_seg/weights/best.pt', '/content/shaft_seg.pt'],
    'clubhead': ['/content/runs/clubhead_det/weights/best.pt', '/content/clubhead_det.pt'],
    'phase': ['/content/runs/phase_cls/weights/best.pt', '/content/phase_cls.pt'],
}


def _load_first_available(paths):
    """Return a YOLO model built from the first existing path in *paths*, or None."""
    for path in paths:
        if os.path.exists(path):
            return YOLO(path)
    return None


shaft_model = _load_first_available(MODEL_PATHS['shaft'])
clubhead_model = _load_first_available(MODEL_PATHS['clubhead'])
phase_model = _load_first_available(MODEL_PATHS['phase'])

# The upload prompt is only triggered by a missing shaft model (unchanged).
if shaft_model is None:
    print('‚ö†Ô∏è Upload your .pt files:')
    uploaded = files.upload()
    for name, data in uploaded.items():
        with open(f'/content/{name}', 'wb') as f:
            f.write(data)
    # Only fill in models that are still missing: a model that was already
    # loaded from /content/runs must not be reset to None just because no
    # uploaded copy of it exists.
    if shaft_model is None:
        shaft_model = _load_first_available(MODEL_PATHS['shaft'])
    if clubhead_model is None:
        clubhead_model = _load_first_available(MODEL_PATHS['clubhead'])
    if phase_model is None:
        phase_model = _load_first_available(MODEL_PATHS['phase'])

print(f'Shaft: {"‚úÖ" if shaft_model else "‚ùå"} | Club Head: {"‚úÖ" if clubhead_model else "‚ùå"} | Phase: {"‚úÖ" if phase_model else "‚ùå"}')

In [None]:
# Download sample golf swing video
!yt-dlp -f 'bestvideo[height<=720]+bestaudio/best[height<=720]' --merge-output-format mp4 -o '/content/test_swing.mp4' 'https://www.youtube.com/shorts/dJlGBTsyVzg' --no-playlist 2>&1 | tail -3
VIDEO_PATH = '/content/test_swing.mp4'
if not os.path.exists(VIDEO_PATH): print('‚ö†Ô∏è Download failed - upload your own video in next cell')

In [None]:
# OR upload your own video (uncomment):
# uploaded = files.upload()
# VIDEO_PATH = '/content/' + list(uploaded.keys())[0]

## 🔍 Single Frame Analysis

In [None]:
# Open the clip and pull up to 8 evenly spaced frames for the single-frame
# model checks. `frames` ends up as a list of (frame_index, BGR image) pairs.
cap = cv2.VideoCapture(VIDEO_PATH)
frames = []
try:
    # Fail early with a clear message instead of silently extracting nothing.
    if not cap.isOpened():
        raise RuntimeError(f'Could not open video: {VIDEO_PATH}')
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps_vid = cap.get(cv2.CAP_PROP_FPS)
    if total_frames <= 0:
        # linspace(0, -1, ...) would otherwise produce negative seek positions.
        raise RuntimeError(f'Video reports no frames: {VIDEO_PATH}')
    print(f'Video: {total_frames} frames, {fps_vid:.0f} fps, {total_frames/max(fps_vid,1):.1f}s')
    # np.unique drops the repeated indices that linspace yields for clips
    # shorter than 8 frames; the indices stay in ascending order.
    sample_indices = np.unique(np.linspace(0, total_frames - 1, 8, dtype=int))
    for idx in sample_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        # A failed read is skipped rather than treated as an error.
        if ret:
            frames.append((int(idx), frame))
finally:
    # Release the capture even if one of the checks above raised.
    cap.release()
print(f'Extracted {len(frames)} frames')

In [None]:
def _hide_ticks(ax):
    """Hide ticks and spines of *ax* while leaving its axis labels drawable.

    `ax.axis('off')` would also suppress the y-label, which is used here as
    the row caption.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


# Grid of results: one column per sampled frame, one row per model
# (shaft segmentation, club-head detection, swing-phase classification).
if not frames:
    raise RuntimeError('No frames available - run the frame extraction cell first')

# squeeze=False keeps `axes` two-dimensional even with a single column, so
# axes[row][col] indexing works for any number of frames.
fig, axes = plt.subplots(3, len(frames), figsize=(4*len(frames), 12), squeeze=False)
fig.suptitle('SwingAI Model Results', fontsize=16, fontweight='bold')

for i, (idx, frame) in enumerate(frames):
    # Frames are BGR (OpenCV); matplotlib expects RGB.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    shaft_ax, head_ax, phase_ax = axes[0][i], axes[1][i], axes[2][i]

    # Row 0: shaft segmentation overlay (plain frame when the model is missing).
    if shaft_model:
        r = shaft_model.predict(frame, imgsz=640, conf=0.25, verbose=False)[0]
        shaft_ax.imshow(cv2.cvtColor(r.plot(), cv2.COLOR_BGR2RGB))
        shaft_ax.set_title(f'F{idx}', fontsize=9)
    else:
        shaft_ax.imshow(rgb)

    # Row 1: club-head detections, titled with the number of boxes found.
    if clubhead_model:
        r = clubhead_model.predict(frame, imgsz=640, conf=0.25, verbose=False)[0]
        head_ax.imshow(cv2.cvtColor(r.plot(), cv2.COLOR_BGR2RGB))
        head_ax.set_title(f'{len(r.boxes)} det', fontsize=9)
    else:
        head_ax.imshow(rgb)

    # Row 2: top-1 swing phase; the title is green above 70% confidence,
    # orange otherwise.
    phase_ax.imshow(rgb)
    if phase_model:
        r = phase_model.predict(frame, imgsz=224, verbose=False)[0]
        pname = r.names[r.probs.top1]
        conf = r.probs.top1conf.item()
        phase_ax.set_title(f'{pname}\n{conf:.0%}', fontsize=9, color='green' if conf>0.7 else 'orange')

    for ax in (shaft_ax, head_ax, phase_ax):
        _hide_ticks(ax)

    # Row captions go on the first column only.
    if i == 0:
        shaft_ax.set_ylabel('Shaft Seg', fontsize=12, fontweight='bold')
        head_ax.set_ylabel('Club Head', fontsize=12, fontweight='bold')
        phase_ax.set_ylabel('Phase', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('/content/model_test_grid.png', dpi=150)
plt.show()
print('üìä Saved: /content/model_test_grid.png')

## ⚡ Speed Benchmark

In [None]:
# Time repeated single-frame inference for each loaded model.
N = 50       # timed iterations per model
WARMUP = 3   # untimed iterations run first so one-off setup cost is excluded


def _speed_rating(ms):
    """Map a per-frame latency in milliseconds to a frame-rate tier label."""
    if ms < 8:
        return 'üü¢ 120fps'
    if ms < 16:
        return 'üü° 60fps'
    if ms < 33:
        return 'üü† 30fps'
    return 'üî¥ slow'


# Benchmark on the middle one of the sampled frames.
test_frame = frames[len(frames)//2][1]
print(f'‚ö° Speed Benchmark ({N} iterations):')
print('='*55)
for name, model, imgsz in [('Shaft Seg', shaft_model, 640), ('Club Head', clubhead_model, 640), ('Phase Cls', phase_model, 224)]:
    if not model:
        print(f'{name}: ‚ùå not loaded')
        continue
    for _ in range(WARMUP):
        model.predict(test_frame, imgsz=imgsz, verbose=False)
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(N):
        model.predict(test_frame, imgsz=imgsz, verbose=False)
    elapsed = time.perf_counter() - start
    ms = (elapsed/N)*1000
    fps = N/elapsed
    print(f'{name:12s}: {ms:6.1f}ms | {fps:5.1f} FPS | {_speed_rating(ms)}')
print('\n‚ö†Ô∏è CoreML on iPhone Neural Engine is typically 2-5x faster than PyTorch on GPU')

## 🎬 Full Video Analysis + Annotated Output

In [None]:
# Run every loaded model on every frame, draw the results onto a copy of the
# frame, and write the annotated frames to /content/analyzed_swing.mp4.
# Per-frame results are collected for the plotting cell:
#   shaft_angles - fitted shaft angle in degrees, or None when no usable mask
#   ch_positions - club-head centre as (x/width, y/height), or None
#   phases       - top-1 phase name
# Each list is only appended to when its model is loaded.
cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    raise RuntimeError(f'Could not open video: {VIDEO_PATH}')
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps_v = cap.get(cv2.CAP_PROP_FPS)
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if not fps_v > 0:
    # Some containers report 0/NaN FPS, which the writer cannot use.
    # 30 is an assumed fallback, not a measured value.
    fps_v = 30.0
out = cv2.VideoWriter('/content/analyzed_swing.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps_v, (w, h))
shaft_angles = []
ch_positions = []
phases = []
fidx = 0
print(f'Processing {total} frames...')
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        ann = frame.copy()

        if shaft_model:
            r = shaft_model.predict(frame, imgsz=640, conf=0.25, verbose=False)[0]
            angle = None
            # Skip empty polygons so they are never handed to the drawing calls.
            polys = [] if r.masks is None else [m for m in r.masks.xy if len(m) > 0]
            if polys:
                for m in polys:
                    pts = m.astype(np.int32)
                    cv2.fillPoly(ann, [pts], (0, 255, 0))
                    cv2.polylines(ann, [pts], True, (0, 255, 0), 2)
                all_pts = np.vstack(polys)
                if len(all_pts) > 5:
                    line = cv2.fitLine(all_pts.astype(np.float32), cv2.DIST_L2, 0, 0.01, 0.01)
                    # Flatten before converting: float() on a 1-element array
                    # is deprecated in recent NumPy releases.
                    vx, vy = (float(v) for v in line.ravel()[:2])
                    # arctan2(vx, vy) measures the direction from the image
                    # y-axis, so 0 means a vertical shaft.
                    # NOTE(review): the sign of fitLine's direction vector is
                    # not fixed, so the angle may flip by 180 degrees between
                    # frames - confirm this convention is intended.
                    angle = np.degrees(np.arctan2(vx, vy))
                    cv2.putText(ann, f'Shaft:{angle:.0f}deg', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
            shaft_angles.append(angle)

        if clubhead_model:
            r = clubhead_model.predict(frame, imgsz=640, conf=0.25, verbose=False)[0]
            if len(r.boxes) > 0:
                # Only the first box is used - presumably the most confident
                # detection; verify against the detector's output ordering.
                x1, y1, x2, y2 = r.boxes[0].xyxy[0].cpu().numpy().astype(int).tolist()
                cx, cy = (x1 + x2)//2, (y1 + y2)//2
                cv2.rectangle(ann, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.circle(ann, (cx, cy), 5, (0, 0, 255), -1)
                # Stored normalized by the frame size.
                ch_positions.append((cx/w, cy/h))
                cv2.putText(ann, f'Club:({cx},{cy})', (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            else:
                ch_positions.append(None)

        if phase_model:
            r = phase_model.predict(frame, imgsz=224, verbose=False)[0]
            pn = r.names[r.probs.top1]
            pc = r.probs.top1conf.item()
            phases.append(pn)
            cv2.putText(ann, f'{pn} ({pc:.0%})', (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)

        out.write(ann)
        fidx += 1
        # Progress line, overwritten in place every 30 frames.
        if fidx % 30 == 0:
            print(f'  {fidx}/{total}', end='\r')
finally:
    # Always close the reader and finalize the output file, even on error.
    cap.release()
    out.release()
print(f'\n‚úÖ Saved: /content/analyzed_swing.mp4 ({fidx} frames)')

In [None]:
# Plot analysis graphs: shaft angle, club height and club speed per frame.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 10), sharex=True)
fig.suptitle('Swing Metrics Over Time', fontsize=14, fontweight='bold')

# Compare against None explicitly: a valid 0.0-degree angle is falsy and a
# truthiness test would turn it into a gap in the plot.
angles = [np.nan if a is None else a for a in shaft_angles]
ax1.plot(angles, color='green', lw=1.5)
ax1.set_ylabel('Shaft Angle (¬∞)')
ax1.grid(True, alpha=0.3)

if ch_positions:
    # Entries are normalized (x, y) pairs or None; 1 - y flips the y value so
    # that larger plotted values correspond to smaller image y coordinates.
    ys = [np.nan if p is None else 1 - p[1] for p in ch_positions]
    ax2.plot(ys, color='red', lw=1.5)
    ax2.set_ylabel('Club Height')
    ax2.grid(True, alpha=0.3)

    # Speed = distance between consecutive normalized positions times the
    # frame rate. The first frame has no predecessor and is fixed at 0.
    speeds = [0]
    for prev, cur in zip(ch_positions, ch_positions[1:]):
        if prev is None or cur is None:
            speeds.append(np.nan)
        else:
            speeds.append(np.hypot(cur[0] - prev[0], cur[1] - prev[1]) * fps_v)
    ax3.plot(speeds, color='blue', lw=1.5)
    ax3.set_ylabel('Club Speed')
    ax3.grid(True, alpha=0.3)
    # speeds[0] is always 0, so nanargmax never sees an all-NaN list.
    pk = np.nanargmax(speeds)
    ax3.axvline(x=pk, color='red', ls='--', alpha=0.7, label=f'Peak@F{pk}')
    ax3.legend()

ax3.set_xlabel('Frame')
plt.tight_layout()
plt.savefig('/content/swing_graphs.png', dpi=150)
plt.show()
print('üìä Saved: /content/swing_graphs.png')

In [None]:
# Download all results.
# Each file is attempted on its own so that one missing or failed download
# does not prevent the others.
any_failed = False
for result_path in (
    '/content/analyzed_swing.mp4',
    '/content/model_test_grid.png',
    '/content/swing_graphs.png',
):
    try:
        files.download(result_path)
    except Exception:
        # Best-effort: the file browser remains as a fallback. Catching
        # Exception (not a bare except) lets KeyboardInterrupt through.
        any_failed = True
if any_failed:
    print('Download from file browser ‚Üí')