In [1]:
# Report the version of the interpreter this kernel is actually running.
# A `!python --version` shell escape runs whichever `python` is first on the
# shell PATH, which is not necessarily the kernel's interpreter.
import platform

print("Python", platform.python_version())

Python 3.10.14


In [2]:
# Use the %conda magic instead of a `!conda` shell escape so the listing is
# taken from the environment of the running kernel, not from whichever conda
# environment the spawned shell happens to resolve.
%conda list

# packages in environment at /opt/homebrew/Caskroom/miniforge/base/envs/ai_dev:
#
# Name                    Version                   Build  Channel
appnope                   0.1.4              pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
blas                      1.0                    openblas    defaults
brotli-python             1.0.9           py310h313beb8_8    defaults
bzip2                     1.0.8                h80987f9_6    defaults
ca-certificates           2024.8.30            hf0a4a13_0    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd3eb1b0_0    defaults
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
debugpy                   1.6.7           py310h313beb8_0    defaults
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_0    co

In [3]:
import torch
# Print, one per line: the installed PyTorch version, whether the MPS backend
# is available on this machine, and whether this build includes MPS support.
print(
    torch.__version__,
    torch.backends.mps.is_available(),
    torch.backends.mps.is_built(),
    sep="\n",
)

2.6.0.dev20240927
True
True


In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# This notebook requires the MPS backend: select it as the device when it is
# available, otherwise stop immediately with an error.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    raise RuntimeError("MPS is not available. Please check your setup.")

# Define a simple fully connected neural network


class SimpleNN(nn.Module):
    """Fully connected classifier: flatten, then three linear layers with ReLU.

    Args:
        in_features: Number of values per sample after flattening.
            Defaults to ``28 * 28``.
        hidden_features: Width of the two hidden layers. Defaults to 512.
        num_classes: Number of output logits per sample. Defaults to 10.

    The defaults reproduce the original fixed-size network, and the submodule
    names (``flatten``, ``linear_relu_stack``) are unchanged, so state-dict
    keys stay the same.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw logits for ``x``; no softmax is applied here."""
        # nn.Flatten keeps dim 0 and collapses the remaining dims into one.
        return self.linear_relu_stack(self.flatten(x))


# Seed PyTorch's global random number generator before anything random
# happens, so the initial weights and the DataLoader shuffle order are the
# same each time the notebook is restarted and run from the top.
SEED = 0
torch.manual_seed(SEED)

# Create the model and move it to the MPS device
model = SimpleNN().to(device)

# Define the loss function and the optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# Prepare the training dataset and its data loader
# (download=True fetches MNIST into ./mnist_data when it is not there yet)
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='mnist_data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Training function


def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: Iterable of ``(X, y)`` batches. It must expose a sized
            ``.dataset`` attribute, which is used only for progress output.
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable invoked as ``loss_fn(pred, y)``; its result must
            support ``.backward()`` and ``.item()``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        device: Device each batch is moved to. When omitted, the device of
            the model's first parameter is used, so the function does not
            depend on a notebook-level ``device`` variable. If the model has
            no parameters, batches are left on whatever device they arrive on.

    Returns:
        None. The loss is printed every 100 batches.
    """
    size = len(dataloader.dataset)

    if device is None:
        # Follow the model instead of reading a global: inputs must live on
        # the same device as the weights.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Move the batch to the target device
        if device is not None:
            X, y = X.to(device), y.to(device)

        # Compute the prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, compute new ones, update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_val = loss.item()
            # Batch index times batch length, i.e. the count does not include
            # the batch that was just processed.
            current = batch * len(X)
            print(f"loss: {loss_val:>7f}  [{current:>5d}/{size:>5d}]")


# Run a single training pass over the training data
train(
    dataloader=train_loader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%


Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

loss: 2.300000  [    0/60000]
loss: 2.291041  [ 6400/60000]
loss: 2.290750  [12800/60000]
loss: 2.283026  [19200/60000]
loss: 2.280827  [25600/60000]
loss: 2.275137  [32000/60000]
loss: 2.273357  [38400/60000]
loss: 2.268939  [44800/60000]
loss: 2.259243  [51200/60000]
loss: 2.251018  [57600/60000]
