RuntimeError: Placeholder storage has not been allocated on MPS device! #90440

Open
collindbell opened this issue Dec 8, 2022 · 5 comments
collindbell commented Dec 8, 2022

🐛 Describe the bug

I get an error every time I attempt to use MPS to train a model on my M1 Mac. The error occurs at the first training step (i.e., the first call of model(x)). MRE:

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import pandas as pd
import numpy as np

device = torch.device('mps')

class MyLSTM(nn.Module):
    def __init__(self, hidden_size, num_layers, output_size, input_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_dim = input_dim

        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)

        out, _ = self.lstm(x, (h0, c0))

        out = self.fc(out[:, -1, :])
        return out

def train_step(model, criterion, optimizer, x, y):
    model.train()
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=100):
    train_losses = []
    for epoch in range(epochs):
        print("Epoch", epoch)
        train_loss = 0
        for x, y in train_loader:
            train_loss += train_step(model, criterion, optimizer, x, y)
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        print("Train loss:", train_loss)
    return train_losses

class MyDataset(Dataset):
    def __init__(self, df, window_size):
        self.df = df
        self.window_size = window_size
        self.data = []
        self.labels = []
        for i in range(len(df) - window_size):
            x = torch.tensor(df.iloc[i:i+window_size].values, dtype=torch.float, device=device)
            y = torch.tensor(df.iloc[i+window_size].values, dtype=torch.float, device=device)
            self.data.append(x)
            self.labels.append(y)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class MyDataLoader(DataLoader):
    def __init__(self, dataset, window_size, batch_size, shuffle=True):
        self.dataset = dataset
        super().__init__(self.dataset, batch_size=batch_size, shuffle=shuffle)

df = pd.DataFrame(np.random.randint(0,100,size=(100, 1)))

model = MyLSTM(1, 1, 1, 1)
model.to(device)

train_data = MyDataset(df, 5)

train_loader = MyDataLoader(train_data, 5, 16)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

train_losses = train_model(model, criterion, optimizer, train_loader, None, epochs=10)

I receive the following traceback:

Traceback (most recent call last):
  File "min_mps.py", line 83, in <module>
    train_losses = train_model(model, criterion, optimizer, train_loader, None, epochs=10)
  File "min_mps.py", line 44, in train_model
    train_loss += train_step(model, criterion, optimizer, x, y)
  File "min_mps.py", line 32, in train_step
    y_pred = model(x)
  File "~/miniconda3/envs/jaxenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "min_mps.py", line 24, in forward
    out, _ = self.lstm(x, (h0, c0))
  File "~/miniconda3/envs/jaxenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/miniconda3/envs/jaxenv/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 776, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Placeholder storage has not been allocated on MPS device!

Versions

Python version: 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==1.14.0.dev20221207
[pip3] torchaudio==0.14.0.dev20221207
[pip3] torchvision==0.15.0.dev20221207
[conda] numpy                     1.22.4                   pypi_0    pypi
[conda] torch                     1.14.0.dev20221207          pypi_0    pypi
[conda] torchaudio                0.14.0.dev20221207          pypi_0    pypi
[conda] torchvision               0.15.0.dev20221207          pypi_0    pypi

Also note, if relevant, that I'm running macOS 13.0. I have also tried this on the 1.13 stable release with the same issue.

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

@malfet added the module: mps and triaged labels on Dec 8, 2022
pudepiedj commented Apr 12, 2023

I don't think this is a bug in PyTorch! You haven't allocated your torch.zeros to the device in the forward pass. If you do that, it runs, at least for me.
def forward(self, x):
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=device)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=device)
    out, _ = self.lstm(x, (h0, c0))
    out = self.fc(out[:, -1, :])
    return out
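
A variant of the same fix that avoids relying on the module-level device variable is to take the device from the input itself; a minimal sketch, assuming x has already been moved to the target device:

def forward(self, x):
    # create the initial hidden and cell states on the same device as the input,
    # so the module works unchanged on cpu, cuda, or mps
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
    out, _ = self.lstm(x, (h0, c0))
    out = self.fc(out[:, -1, :])
    return out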

kulinseth (Collaborator) commented

That's indeed correct: we see these errors when not all of the tensors are mapped to the device. There were also some bugs in the LSTM layer which were fixed in the 2.0 release. I would recommend @collindbell try the latest release on macOS 13.3.
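
For reference, a minimal sketch of the device mapping described above, assuming PyTorch 1.13 or later; falling back to CPU keeps the script runnable on machines without Metal support:

import torch

# prefer the MPS backend when it is built and available, otherwise fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# every tensor that takes part in the forward pass has to live on this device,
# including any hidden-state tensors created inside forward()
x = torch.randn(16, 5, 1, device=device)
print(x.device)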

pudepiedj commented Apr 13, 2023 via email

LTsommer commented

I fixed it by specifying the device when creating each layer of the network:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
self.lstm = nn.LSTM(..., device=device)
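
On Apple Silicon the same pattern applies with the mps backend rather than cuda; a hedged sketch with made-up layer sizes:

import torch
from torch import nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# constructing the layer with device= places its weights on the target device up front,
# so they match inputs created there
lstm = nn.LSTM(input_size=1, hidden_size=8, num_layers=1, batch_first=True, device=device)
x = torch.randn(16, 5, 1, device=device)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([16, 5, 8])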

edvard-bjarnason commented

I ran into a similar issue and got the same error message "RuntimeError: Placeholder storage has not been allocated on MPS device!" when using an LSTM model. All tensors and the model were correctly mapped to the device in the code.

However, my code worked fine when I updated torch to version 2.2.1 (I was using version 2.0.1).

I have an M1 Mac and was using the "mps" device. Before I updated torch, I tried running on the "cpu" device, and the output tensor from the forward pass contained NaN. I didn't look into it further since the issue is fixed in the latest versions of torch :)
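
To confirm which build is active and that the MPS backend is usable before retrying, a quick check along these lines is enough (a generic sketch, not specific to this model):

import torch

print(torch.__version__)                  # installed torch build
print(torch.backends.mps.is_built())      # this build was compiled with MPS support
print(torch.backends.mps.is_available())  # a Metal-capable device is actually usable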
