# Interactive PyTorch on Apple Silicon (MPS)

This notebook guides you through installing and using PyTorch with Apple's Metal Performance Shaders (MPS) backend on Apple Silicon (M1/M2/M3). It includes environment setup, verification, and interactive examples.

## Table of Contents
1. [Prerequisites](#prereqs)
2. [Create a Conda Environment](#env)
3. [Install PyTorch with MPS Support](#install)
4. [Verify MPS Availability](#verify)
5. [Basic MPS Operations](#basic)
6. [CPU vs MPS Performance (Quick Check)](#perf)
7. [Troubleshooting & Tips](#troubleshooting)

## 1. Prerequisites <a name="prereqs"></a>

- macOS 12.3 (Monterey) or newer.
- Apple Silicon (M1/M2/M3) machine.
- Xcode Command Line Tools: `xcode-select --install` (run in Terminal, if not already installed).
- Conda or Mamba (recommended for clean envs). If not installed, install [Miniforge](https://github.com/conda-forge/miniforge) for native arm64 builds.
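You can confirm the first two prerequisites from any Python on the machine before creating the environment. The sketch below uses only the standard library. `check_prereqs` and `parse_version` are hypothetical helper names, and the 12.3 threshold comes from the list above. An x86_64 Python running under Rosetta reports `x86_64` rather than `arm64`, which is exactly the case worth catching early.

```python
import platform

MIN_MACOS = (12, 3)  # minimum macOS version from the prerequisites list


def parse_version(text):
    """Turn a version string like '13.4.1' into a tuple of ints, e.g. (13, 4, 1)."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def check_prereqs(mac_version=None, machine=None):
    """Return a list of problems; an empty list means both checks passed.

    Defaults to what the running interpreter reports. Under Rosetta, an
    x86_64 Python reports 'x86_64' even on Apple Silicon hardware.
    """
    mac_version = platform.mac_ver()[0] if mac_version is None else mac_version
    machine = platform.machine() if machine is None else machine
    problems = []
    if not mac_version:
        problems.append("Not running on macOS (platform.mac_ver() is empty).")
    elif parse_version(mac_version) < MIN_MACOS:
        problems.append(f"macOS {mac_version} is older than 12.3.")
    if machine != "arm64":
        problems.append(f"Python reports machine '{machine}', expected 'arm64'.")
    return problems


print(check_prereqs() or "Prerequisites look OK.")
```

Some older Python builds report `10.16` on macOS 11 and later; if you see that, check the real version with `sw_vers` in Terminal instead.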

## 2. Create a Conda Environment <a name="env"></a>

Use an arm64-native conda installation (Miniforge, which also includes `mamba`) to create an environment:

- Option A (mamba):
```bash
mamba create -n torch-mps python=3.11 -y
mamba activate torch-mps
```
- Option B (conda):
```bash
conda create -n torch-mps python=3.11 -y
conda activate torch-mps
```
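After activating, make the environment available to Jupyter: install the kernel package with `conda install -c conda-forge ipykernel -y` (or `pip install ipykernel`), register it with `python -m ipykernel install --user --name torch-mps --display-name "Python (torch-mps)"`, then reopen this notebook and select the "Python (torch-mps)" kernel.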
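To confirm that the notebook is really running inside the new environment rather than the base install, you can inspect `sys.prefix`, which for a conda environment points at the environment's own directory. This sketch assumes conda's default layout, where named environments live under `<conda root>/envs/<name>`; `kernel_env_name` is a hypothetical helper, and `conda info --envs` remains the authoritative list.

```python
import sys
from pathlib import Path


def kernel_env_name(prefix=sys.prefix):
    """Return the conda env name for a prefix like .../envs/torch-mps, else None.

    Assumes conda's default layout (<root>/envs/<name>); the base env and
    non-conda Pythons return None.
    """
    path = Path(prefix)
    return path.name if path.parent.name == "envs" else None


print(f"Kernel prefix: {sys.prefix}")
print(f"Conda env:     {kernel_env_name()}")
```

If the second line prints `None` or a name other than `torch-mps`, switch the notebook's kernel before continuing.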

## 3. Install PyTorch with MPS Support <a name="install"></a>

Recent stable PyTorch wheels include MPS by default on macOS arm64. You can install with `pip` (recommended on macOS) or `conda` (conda-forge).

- Option A (pip):
```bash
pip install --upgrade pip
pip install torch torchvision torchaudio
```
- Option B (conda, conda-forge):
```bash
conda install -c conda-forge pytorch torchvision torchaudio -y
```
Note: no CUDA toolkit is needed; MPS (Metal Performance Shaders) is PyTorch's backend for Apple GPUs.
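To see what the install step actually placed in the environment without importing PyTorch, you can read the package metadata with the standard library. This assumes the distributions are named `torch`, `torchvision`, and `torchaudio`, as they are for the pip wheels; `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torchvision", "torchaudio")):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name:12} {ver or 'not installed'}")
```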

## 4. Verify MPS Availability <a name="verify"></a>

```python
import platform
import sys

import torch

print(f'PyTorch version: {torch.__version__}')
print(f'Python: {sys.version.split()[0]} on {platform.system()} {platform.release()} ({platform.machine()})')

has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
built_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_built()
print(f'MPS available: {has_mps}')
print(f'MPS built:     {built_mps}')

if has_mps:
    print('✅ MPS is available. Your Apple GPU can accelerate PyTorch.')
else:
    print('❌ MPS not available. Falling back to CPU. See Troubleshooting below.')
```
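The two flags printed above point to different fixes, following the Troubleshooting notes at the end of this notebook. A small sketch that turns them into a next step; `explain_mps_status` is a hypothetical helper, and the advice strings paraphrase this notebook rather than any official diagnostic.

```python
def explain_mps_status(built, available):
    """Translate the (MPS built, MPS available) flags into a suggested next step."""
    if available:
        return "MPS is ready: use torch.device('mps')."
    if not built:
        return ("This PyTorch build has no MPS support: reinstall torch in an "
                "arm64-native Python (see Install PyTorch with MPS Support).")
    return ("MPS is compiled in but unavailable at runtime: check macOS >= 12.3, "
            "Apple Silicon hardware, and that Python is arm64 rather than x86_64 under Rosetta.")
```

After running the verification cell, call `print(explain_mps_status(built_mps, has_mps))`.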


## 5. Basic MPS Operations <a name="basic"></a>

This demonstrates selecting the best available device and doing basic tensor ops.

```python
def get_best_device():
    """Prefer Apple's MPS backend, then CUDA, then the CPU."""
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    elif torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

device = get_best_device()
print(f'Selected device: {device}')
```


```python
import time

def _sync(device):
    """Block until queued work on `device` has finished, so timings are accurate."""
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'mps':
        try:
            torch.mps.synchronize()
        except AttributeError:
            pass  # older PyTorch without torch.mps.synchronize

def basic_ops(device):
    print(f'Using device: {device}')
    a = torch.randn(2048, 2048, device=device)
    b = torch.randn(2048, 2048, device=device)

    # Warm-up (especially important for MPS, which initializes its compute pipelines on first use)
    _ = a @ b
    _sync(device)

    t0 = time.perf_counter()  # high-resolution clock intended for benchmarking
    c = a @ b
    _sync(device)
    t1 = time.perf_counter()

    print(f'Matmul completed in {t1 - t0:.4f}s on {device}')
    return c

_ = basic_ops(device)
```
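A raw time is easier to compare across matrix sizes once it is converted to throughput. Each of the n² entries of an n x n matrix product needs n multiplications and about n additions, so the standard estimate is 2n³ floating-point operations. This is the textbook count for the naive algorithm, not a measurement of what the MPS kernel executes; `matmul_gflops` is a hypothetical helper.

```python
def matmul_gflops(n, seconds):
    """Approximate throughput (GFLOP/s) of an (n x n) @ (n x n) matmul that took `seconds`.

    Uses the usual 2 * n**3 operation count (one multiply and one add per
    inner-product term).
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return 2 * n**3 / seconds / 1e9


# A 2048 x 2048 matmul is about 17.2 GFLOP of work.
print(f"{matmul_gflops(2048, 1.0):.1f} GFLOP/s if it took one second")
```

Inside `basic_ops`, `matmul_gflops(2048, t1 - t0)` gives the measured rate.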


## 6. CPU vs MPS Performance (Quick Check) <a name="perf"></a>

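This is a rough comparison of a single float32 matrix multiplication at a few moderate sizes, timed once after a warm-up. Real speedups vary with the model, batch size, and data type, and very small workloads can run faster on the CPU because of per-operation overhead.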

```python
def time_matmul(size, device):
    """Seconds for one size x size float32 matmul on `device`, measured after a warm-up."""
    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)
    _ = a @ b  # warm-up
    _sync(device)
    t0 = time.perf_counter()
    _ = a @ b
    _sync(device)
    return time.perf_counter() - t0

sizes = [1024, 2048, 3072]
cpu_times, mps_times = [], []

cpu = torch.device('cpu')
for s in sizes:
    cpu_times.append(time_matmul(s, cpu))

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    mps = torch.device('mps')
    for s in sizes:
        mps_times.append(time_matmul(s, mps))

print('Size\tCPU (s)\tMPS (s)\tSpeedup (CPU/MPS)')
print('-' * 50)
for i, s in enumerate(sizes):
    # NaN when MPS is unavailable; NaN > 0 is False, so the speedup becomes NaN too
    mps_t = mps_times[i] if mps_times else float('nan')
    speedup = cpu_times[i] / mps_t if mps_t > 0 else float('nan')
    print(f'{s}x{s}\t{cpu_times[i]:.4f}\t{mps_t:.4f}\t{speedup:.2f}x')
```
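A single timed run is sensitive to background load. A sturdier measurement runs a few warm-ups, then times several repetitions (synchronizing after each) and reports the median. The sketch below is written against plain callables so it can wrap the cells above, e.g. `time_repeated(lambda: a @ b, lambda: _sync(device))`; `time_repeated` and its default counts are illustrative choices, not a PyTorch API.

```python
import statistics
import time


def time_repeated(fn, sync=lambda: None, warmup=3, repeats=10):
    """Median wall-clock seconds per call of `fn`, after `warmup` untimed calls.

    `sync` should block until queued device work has finished (for example
    torch.mps.synchronize). It runs after the warm-up and after every timed
    call, so asynchronous kernels are included in each sample.
    """
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

The median is less affected than the mean by an occasional slow sample caused by other processes.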


## 7. Troubleshooting & Tips <a name="troubleshooting"></a>

- If `MPS available: False` but `MPS built: True`, ensure you're on macOS ≥ 12.3 and using an arm64-native Python (Miniforge).
- If you installed Intel x86_64 Python under Rosetta, reinstall Miniforge (arm64) and recreate the env.
- Update PyTorch to the latest stable: `pip install -U torch torchvision torchaudio`.
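- Some ops are slower on MPS, or not yet implemented, depending on the PyTorch version. Try smaller batch sizes, or fall back to the CPU selectively for the affected operations.
- MPS runs kernels asynchronously, so when timing, do a few warm-up runs and call `torch.mps.synchronize()` before reading the clock; repeating the measurement gives more stable numbers.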
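One way to fall back to the CPU without editing model code is PyTorch's `PYTORCH_ENABLE_MPS_FALLBACK` environment variable: when set to `1`, ops that MPS does not implement run on the CPU instead of raising an error (often much more slowly). As far as we know, the variable is read once when PyTorch loads, so it has to be set before `import torch`. `enable_mps_fallback` is a hypothetical helper that guards against setting it too late.

```python
import os
import sys


def enable_mps_fallback():
    """Let ops that MPS does not implement run on the CPU instead of raising.

    Must run before `import torch`; in a notebook, restart the kernel and
    call this in the first cell.
    """
    if "torch" in sys.modules:
        raise RuntimeError(
            "torch is already imported; restart the kernel and call this first."
        )
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
```

For scripts, setting it on the command line is equivalent: `PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py`.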

### Selecting device in your training scripts
```python
import torch

device = (
    torch.device('mps') if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else torch.device('cuda') if torch.cuda.is_available()
    else torch.device('cpu')
)
```
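Once `device` is chosen, the usual pattern is to move the model once and each batch as it arrives, since an op whose tensors live on different devices generally raises a `RuntimeError`. A minimal sketch of one training step; `train_step` is a hypothetical helper, and the model, optimizer, and loss function are whatever your script already defines.

```python
def train_step(model, inputs, targets, optimizer, loss_fn, device):
    """Run one optimization step with the batch moved to `device`.

    Assumes `model` was already moved with model.to(device); `inputs` and
    `targets` may still be on the CPU (e.g. straight from a DataLoader).
    """
    model.train()
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()  # copies the scalar loss back to the host as a Python float
```

For example: `model = torch.nn.Linear(4, 1).to(device)`, `opt = torch.optim.SGD(model.parameters(), lr=0.1)`, then `train_step(model, torch.randn(8, 4), torch.randn(8, 1), opt, torch.nn.MSELoss(), device)`. Because `.item()` waits for the GPU to finish, long runs may prefer to read the loss only every few steps.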

### References
- PyTorch MPS docs: https://pytorch.org/docs/stable/notes/mps.html
- Miniforge (conda-forge arm64): https://github.com/conda-forge/miniforge