Running PyTorch on MPS requires macOS 12.3+ *and* an ARM (arm64) build of Python. We can check the macOS version with:

In [None]:
import platform; platform.mac_ver()

The first item is our macOS version. It must be *12.3* or later; if it is not, update your Mac!

The final item is the architecture our current Python environment is running on, which should be *arm64*. An alternative check is to print the `platform`:

In [None]:
platform.platform()

If the above displays something like `macOS-12.4-x86_64-i386-64bit` (i.e. it contains `x86`), we have the wrong version of Python installed and must install the correct (ARM) version. If using Anaconda, a new ARM environment can be set up like so:

```bash
CONDA_SUBDIR=osx-arm64 conda create -n ml python=3.9 -c conda-forge
```

Here we set the `CONDA_SUBDIR` environment variable so that conda installs ARM packages. We then `create` a new `conda` environment with the name (`-n`) `ml`. We use Python `3.9` and add `conda-forge` as a channel (`-c`), from which these ARM builds can be downloaded.

Next we activate the environment with `conda activate ml` and set the `CONDA_SUBDIR` variable for this environment permanently to `osx-arm64`; otherwise it may default back to an incorrect *x86* version during future installs.

```bash
conda env config vars set CONDA_SUBDIR=osx-arm64
```

You may see a message asking you to reactivate the environment for these changes to take effect. If so, switch back to the *base* environment and then back into the `ml` environment with:

```bash
conda activate
conda activate ml
```
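
Once the environment is reactivated, the two requirements from the start of this section (macOS 12.3+ and an arm64 Python) can be checked together from Python. The following is only a minimal sketch: the helper name `meets_mps_requirements` is our own and not part of any library, and it assumes the version string has the usual `major.minor[.patch]` form.

```python
import platform

def meets_mps_requirements(mac_version: str, machine: str) -> bool:
    """Return True if macOS is at least 12.3 and the interpreter is arm64."""
    if not mac_version:
        return False  # empty string means we are not on macOS
    # keep only major and minor, padding a bare "13" to (13, 0)
    parts = [int(p) for p in mac_version.split('.')[:2]]
    while len(parts) < 2:
        parts.append(0)
    return tuple(parts) >= (12, 3) and machine == 'arm64'

# platform.mac_ver() returns (release, versioninfo, machine)
release, _, machine = platform.mac_ver()
print(release, machine, meets_mps_requirements(release, machine))
```

If the last printed value is `False`, either macOS needs updating or the active Python is not an ARM build.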

Now we're ready to install the latest PyTorch version (v1.12 or higher). At the time of writing, this requires installing the PyTorch nightly preview with:

```bash
pip install -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

During the download you should see something like *Downloading torch-1.1x.x.---**arm64.whl***. The final *arm64.whl* part is important: it tells us we are downloading the correct version.

For the examples in this notebook we will also use the Hugging Face *transformers* and *datasets* libraries:

```bash
pip install transformers datasets
```

---
***Note**: The transformers library uses tokenizers built in Rust (which makes them faster). Because we are using this new ARM64 environment, we may get **ERROR: Failed building wheel for tokenizers**. If so, we [install Rust](https://huggingface.co/docs/tokenizers/python/v0.9.4/installation/main.html#installation-from-sources) (in the same environment) with:*

```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```

*And then `pip install transformers datasets` again.*

---

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

We can check that the MPS device is available with:

In [None]:
# True if this PyTorch build and the current macOS version support MPS
torch.backends.mps.is_available()
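
To keep the rest of a script working on machines without MPS, a common pattern is to fall back to the CPU. Below is a small sketch of that idea. Note that `pick_device` is a hypothetical helper (not a PyTorch function): it takes the imported `torch` module as an argument and also tolerates older builds that have no `torch.backends.mps` attribute.

```python
def pick_device(torch_module) -> str:
    """Return 'mps' when the given torch module reports MPS support, else 'cpu'."""
    # older PyTorch builds may lack `backends.mps`, so look it up defensively
    backends = getattr(torch_module, 'backends', None)
    mps = getattr(backends, 'mps', None)
    if mps is not None and mps.is_available():
        return 'mps'
    return 'cpu'
```

With PyTorch imported, this could be used as `device = torch.device(pick_device(torch))`.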

Awesome! Let's pull some data to test the new MPS-enabled PyTorch.

In [None]:
# load the first 1K rows of the TREC dataset
trec = load_dataset('trec', split='train[:1000]')
trec

In [None]:
trec[0]

Now let's load a BERT model. We'll use it and our TREC data to compare inference time on CPU vs MPS.

In [None]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

First we test inference time using a `batch_size` of `64` on the *CPU*.

In [None]:
# take the first 64 rows of the trec data
text = trec['text'][:64]
# tokenize text using the BERT tokenizer
tokens = tokenizer(
    text, max_length=512,
    truncation=True, padding=True,
    return_tensors='pt'
)

In [None]:
%%timeit
model(**tokens)

Now move the model and tokens to MPS and try again.

In [None]:
device = torch.device('mps')
model.to(device)
tokens.to(device)
device

In [None]:
%%timeit
model(**tokens)

Not bad, although not as good as the release benchmarks would suggest. We can try with a few different batch sizes and models.

In [None]:
text = trec['text'][:256]
tokens = tokenizer(
    text, max_length=512,
    truncation=True, padding=True,
    return_tensors='pt'
)

device = torch.device('cpu')
model.to(device)
device

In [None]:
%%timeit
model(**tokens)

In [None]:
device = torch.device('mps')
model.to(device)
tokens.to(device)
device

In [None]:
%%timeit
model(**tokens)

Let's try small batches...

In [None]:
text = trec['text'][:8]
tokens = tokenizer(
    text, max_length=512,
    truncation=True, padding=True,
    return_tensors='pt'
)

device = torch.device('cpu')
model.to(device)
device

In [None]:
%%timeit
model(**tokens)

In [None]:
device = torch.device('mps')
model.to(device)
tokens.to(device)
device

In [None]:
%%timeit
model(**tokens)

---

Let's try this with varying batch sizes and tokens.

In [None]:
from time import time

reruns = 6
b = 10

# start with CPU test
device = torch.device('cpu')
model.to(device)

cpu_times = []

for i in range(b):
    text = trec['text'][:2**i]
    tokens = tokenizer(
        text, max_length=512,
        truncation=True, padding=True,
        return_tensors='pt'
    )
    tot_time = 0
    for _ in range(reruns):
        t0 = time()
        model(**tokens)
        tot_time += time()-t0
    cpu_times.append(tot_time/reruns)

# then MPS (GPU) test
device = torch.device('mps')
model.to(device)

mps_times = []

for i in range(b):
    text = trec['text'][:2**i]
    tokens = tokenizer(
        text, max_length=512,
        truncation=True, padding=True,
        return_tensors='pt'
    ).to(device)
    tot_time = 0
    for _ in range(reruns):
        t0 = time()
        model(**tokens)
        tot_time += time()-t0
    mps_times.append(tot_time/reruns)
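
One caveat with the manual timing loops above: MPS executes work asynchronously, so stopping the clock right after the forward call may not capture all of the GPU work, and the first call on a device often includes one-off setup cost. The sketch below is one way to account for both, under our own assumptions: `time_fn` is a hypothetical helper, and the optional `sync` callback is where a device synchronization call would go (for example `torch.mps.synchronize`, which to our knowledge is available from PyTorch 2.0 onwards).

```python
from time import perf_counter

def time_fn(fn, reruns=6, warmup=1, sync=None):
    """Average wall-clock seconds per call of `fn` over `reruns` timed runs."""
    # untimed warm-up runs absorb one-off costs such as kernel compilation
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    t0 = perf_counter()
    for _ in range(reruns):
        fn()
    if sync is not None:
        sync()  # wait for queued device work before stopping the clock
    return (perf_counter() - t0) / reruns

# stand-in workload so the sketch runs without a model
avg = time_fn(lambda: sum(range(10_000)))
print(f'{avg:.6f} s per call')
```

In the loops above, the inner timing could then become something like `time_fn(lambda: model(**tokens), reruns=reruns, sync=torch.mps.synchronize)` for the MPS case, and the same call without `sync` for the CPU case.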

In [None]:
import seaborn as sns

sns.set_style('whitegrid')

sns.lineplot(
    x=[2**i for i in range(b)]*2,
    y=cpu_times+mps_times,
    hue=['cpu']*b + ['mps']*b
)

---

In [None]:
cpu_times

In [None]:
mps_times

In [None]:
import numpy
# fraction of CPU time saved by MPS at each batch size (positive means MPS is faster)
(numpy.array(cpu_times) - numpy.array(mps_times))/numpy.array(cpu_times)
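
The final cell computes, for each batch size, the fraction of CPU time saved by MPS: `(cpu - mps) / cpu`. A positive value means MPS was faster and a negative value means it was slower. As a plain-Python sketch of the same arithmetic (the helper name `relative_saving` is ours, and the numbers are made-up placeholders, not measurements from this notebook):

```python
def relative_saving(cpu_times, mps_times):
    """Fraction of CPU time saved by MPS for each pair of timings."""
    return [(c - m) / c for c, m in zip(cpu_times, mps_times)]

# placeholder timings in seconds, for illustration only
print(relative_saving([2.0, 1.0], [1.0, 1.5]))  # → [0.5, -0.5]
```

Here the first pair reads as "MPS took half the CPU time", while the second reads as "MPS took 50% longer than the CPU".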