
Memory usage and epoch iteration time increases indefinitely on M1 pro MPS #77753

Closed
alper111 opened this issue May 18, 2022 · 21 comments
Labels
module: mps Related to Apple Metal Performance Shaders framework · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@alper111 commented May 18, 2022

First of all, thank you for the MPS backend!

I was trying out some basic examples to see the speed. Below is my code sample (convolutional autoencoder on MNIST).

import time

import torch
import torchvision
import torchvision.transforms as transforms


device = "cpu"

train_set = torchvision.datasets.MNIST(root="./data", download=True, train=True, transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, 4, 2, 1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, 4, 2, 1),
    torch.nn.ReLU(),
    torch.nn.ConvTranspose2d(64, 32, 4, 2, 1),
    torch.nn.ReLU(),
    torch.nn.ConvTranspose2d(32, 1, 4, 2, 1)
)
model.to(device)

optimizer = torch.optim.Adam(lr=0.001, params=model.parameters())
criterion = torch.nn.MSELoss()

for e in range(30):
    avg_loss = 0.0
    start = time.time()
    for i, (x, _) in enumerate(loader):
        x = x.to(device)
        x_bar = model(x)
        loss = criterion(x_bar, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
    avg_loss /= (i+1)
    end = time.time()

    print(f"epoch={e+1}, loss={avg_loss:.5f}, time={end-start:.3f}")

Output:

epoch=1, loss=0.00718, time=19.496
epoch=2, loss=0.00063, time=19.429
epoch=3, loss=0.00036, time=19.420
epoch=4, loss=0.00027, time=19.483
epoch=5, loss=0.00020, time=19.355

This process takes around 150 MB of memory (and ~19 s per epoch) when the device is set to cpu. However, when I set it to mps, the memory usage (as I see from Activity Monitor) starts at 1 GB and increases up to 7.33 GB (around the 8th epoch). Also, the epoch time increases from 8.45 s (1st epoch) to 45 s (8th epoch). Is there any other practice I should follow (other than moving tensors and models) if I want to use mps? Also, the same pipeline with the TensorFlow Metal counterpart takes 1.9 s per epoch:

import time

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, _), _ = mnist.load_data()
x_train = (x_train / 255.0).reshape(-1, 28, 28, 1)

# Add a channels dimension
print(x_train.shape)
train_ds = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(256)

model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(28, 28, 1,)))
model.add(tf.keras.layers.Conv2D(32, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(tf.keras.layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding="same", activation="relu"))
model.add(tf.keras.layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding="same"))

mse = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')


@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        x_bar = model(x)
        loss = mse(x, x_bar)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)


print(model.summary())
for e in range(30):
    start = time.time()
    train_loss.reset_states()
    for x in train_ds:
        train_step(x)
    end = time.time()
    print(f"epoch={e+1}, loss={train_loss.result():.5f}, time={end-start:.3f}")

I am only posting this to understand the differences. Thanks again for the support 🚀🔥
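
To put numbers on the growth instead of eyeballing Activity Monitor, newer PyTorch builds expose allocator counters under torch.mps. A minimal logging sketch, assuming a build where these functions exist (the 1.12-era nightlies may not have them yet):

import torch

def log_mps_memory(tag):
    # current_allocated_memory()/driver_allocated_memory() were added to
    # torch.mps in later releases; guard so this degrades gracefully.
    if hasattr(torch, "mps") and hasattr(torch.mps, "current_allocated_memory"):
        alloc = torch.mps.current_allocated_memory() / 2**20
        driver = torch.mps.driver_allocated_memory() / 2**20
        print(f"{tag}: allocated={alloc:.1f} MiB, driver={driver:.1f} MiB")

Calling log_mps_memory(f"epoch {e+1}") once per epoch in the training loop above would show whether the allocator itself is holding the growing memory.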

Versions

PyTorch version: 1.12.0.dev20220518
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.3.1 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.22.3
Libc version: N/A

Python version: 3.8.11 (default, Jul 29 2021, 14:57:32) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.12.0.dev20220518
[pip3] torchaudio==0.11.0
[pip3] torchvision==0.12.0
[conda] numpy 1.22.3 pypi_0 pypi
[conda] numpy-base 1.21.2 py38h6269429_0
[conda] torch 1.12.0.dev20220518 pypi_0 pypi
[conda] torchaudio 0.11.0 pypi_0 pypi
[conda] torchvision 0.12.0 pypi_0 pypi

@alper111 changed the title from "Memory usage and epoch iteration time increases indefinitely" to "Memory usage and epoch iteration time increases indefinitely on M1 pro MPS" on May 18, 2022
@anjali411 added the triaged and module: mps labels on May 18, 2022
@thipokKub commented:

I tested by changing the batch size from 256 to 4096: the epoch time for the cpu backend stayed around 21 seconds (unchanged), but the mps backend dropped to around 2.5 s. So I think it comes down to the per-batch data overhead.

Device: cpu
epoch=1, loss=0.06228, time=21.042
epoch=2, loss=0.02014, time=21.345
epoch=3, loss=0.01109, time=21.838

Device: mps
epoch=1, loss=0.07601, time=2.567
epoch=2, loss=0.03462, time=2.541
epoch=3, loss=0.01890, time=2.524

@alper111 (Author) commented May 18, 2022

Changing the batch size from 256 to 4096 indeed changed both the memory usage (it now fluctuates between ~1.8 GB and ~2.8 GB, still not constant, versus ~160 MB on cpu) and the epoch time (around 3.4 s, constant across epochs). However, the time it takes for an epoch to complete should be roughly the same regardless of batch size, right? At least it is on the cpu.

Also, unlike NVIDIA GPUs, why should there be any data-transfer overhead when everything resides in the same memory? I am not an expert, but wasn't that the main point of Apple's chips? I also tried moving all the data to MPS before starting the loop (see the sketch below), and the behavior is the same with batch size 256.
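
For reference, "moving all the data to MPS before starting the loop" looked roughly like the sketch below (variable names are illustrative; MNIST's train_set.data is a uint8 tensor of shape (60000, 28, 28)):

import torch

device = torch.device("mps")
# one upfront copy of the whole dataset, instead of a .to(device) per batch
images = train_set.data.unsqueeze(1).float().div(255.0).to(device)

for e in range(30):
    perm = torch.randperm(images.shape[0]).to(device)
    for i in range(0, images.shape[0], 256):
        x = images[perm[i:i + 256]]
        # forward/backward as before; no host-to-device copy in the loop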

@thipokKub commented May 18, 2022

As far as I understand, even Apple's M1 with unified memory still needs to convert data between the GPU and CPU address spaces, which has some overhead, but it does not require the data to be transferred (because it is already there).

I tested with the CIFAR-10 dataset, and the average step time did indeed go up (and the loss did not decrease; I attributed that to a too-large batch size and an lr that needed adjusting, but the loss did not decrease with batch size 64 either, whereas switching the device to cpu did make it decrease).

TL;DR

  • Memory consumption grows on mps, which is not seen on cpu
  • The time interval between steps grows on mps (probably due to the memory problem)
  • The loss curve does not converge as expected

Code

import time
import torch
import torchvision
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.models.resnet import ResNet, BasicBlock

BATCH_SZ = 512
MAX_EPOCHS = 10
PRINT_ITER = 5

# str_device = "mps" if torch.backends.mps.is_available() else "cpu"

loss_hist = {}
interval_hist = {}

for str_device in ["cpu", "mps"]:
    loss_hist[str_device] = []
    interval_hist[str_device] = []
    
    device = torch.device(str_device)
    print(device)

    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(64),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    train_ds = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=train_transforms)
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SZ, shuffle=True, num_workers=0, pin_memory=True)

    model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=10).to(device)
    model.train()
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    time_intervals = []
    step_index = []

    step_count = 0
    for epoch in range(MAX_EPOCHS):  # loop over the dataset multiple times

        running_loss = 0.0
        avg_loss = []
        start = time.time()
        prev_time = time.time()
        with tqdm(total=len(train_loader), leave=False) as pbar:
            for i, data in enumerate(train_loader, 0):
                pbar.update()
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                avg_loss.append(loss.item())
                loss_hist[str_device].append(loss.item())
                del inputs, labels
                step_count += 1
                if i % PRINT_ITER == PRINT_ITER - 1:
                    pbar.set_description(f"[{epoch + 1}/{MAX_EPOCHS}]: loss {running_loss/PRINT_ITER:.5f}")
                    pbar.refresh()
                    curr_time = time.time()
                    time_intervals.append((curr_time - prev_time)/PRINT_ITER)
                    step_index.append(step_count)
                    interval_hist[str_device].append((step_index[-1], time_intervals[-1]))
                    prev_time = time.time()
                    running_loss = 0.0
        end = time.time()
        print(f"epoch={epoch + 1}, loss={np.mean(avg_loss):.5f}, time={end - start:.3f}s")

plt.style.use('ggplot')
plt.figure(figsize=(10, 10))

ax1 = plt.subplot(2, 1, 1)
for str_device in ["cpu", "mps"]:
    x_ticks = [x[0] for x in interval_hist[str_device]]
    y_ticks = [x[1] for x in interval_hist[str_device]]
    ax1.plot(x_ticks, y_ticks, label=str_device)
ax1.legend(loc="best")
ax1.set_title("Average step time interval (s)")

ax2 = plt.subplot(2, 1, 2)
for str_device in ["cpu", "mps"]:
    ax2.plot(loss_hist[str_device], label=str_device)
ax2.legend(loc="best")
ax2.set_title("Loss curve")

plt.tight_layout()
plt.show()

output

cpu
Files already downloaded and verified
epoch=1, loss=1.24948, time=112.687s
epoch=2, loss=0.85023, time=111.165s
epoch=3, loss=0.67690, time=111.629s
epoch=4, loss=0.56228, time=112.630s
epoch=5, loss=0.46136, time=115.914s
epoch=6, loss=0.39533, time=115.883s
epoch=7, loss=0.32071, time=111.163s
epoch=8, loss=0.26879, time=111.458s
epoch=9, loss=0.21876, time=111.202s
epoch=10, loss=0.19126, time=111.203s
mps
Files already downloaded and verified
epoch=1, loss=2.27960, time=32.032s
epoch=2, loss=2.29142, time=35.018s
epoch=3, loss=2.27697, time=46.444s
epoch=4, loss=2.27316, time=50.107s
epoch=5, loss=2.25316, time=58.493s
epoch=6, loss=2.24296, time=75.861s
epoch=7, loss=2.25106, time=81.992s
epoch=8, loss=2.23350, time=87.402s
epoch=9, loss=2.23919, time=92.495s
epoch=10, loss=2.23568, time=100.300s

[Plot: average step time interval and loss curve for cpu vs. mps]

Update 1

Interestingly, the increase in step time seems to happen only with the Adam optimizer. But the loss did not decrease with any of the optimizers.

[Plot: optimizer time intervals]

I compared the loss calculation between cpu and mps and found no difference. So it seems the culprit for the loss not decreasing is the optimizer implementation.

@alper111 (Author) commented May 18, 2022

Update: an epoch takes around 2.4 s if I disable the backward computation and wrap the forward pass in with torch.no_grad(). The memory usage still increases (albeit slowly), from ~300 MB to ~700 MB after 100 epochs.
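
For concreteness, the forward-only variant mentioned above is just the inference-style loop below (a sketch reusing the names from the code sample at the top of the thread):

model.eval()
with torch.no_grad():  # no autograd graph is built, so no graph buffers accumulate
    for x, _ in loader:
        x = x.to(device)
        x_bar = model(x)
        loss = criterion(x_bar, x)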

@tcapelle commented:
Shameless plug:
https://github.com/tcapelle/apple_m1_pro_python
I am having issues getting good performance...

@philipturner commented May 19, 2022

This appears to be a memory leak, where PyTorch isn't releasing its references to the MTLBuffer objects. If it did release them, ARC would automatically delete the MTLBuffer (or MPSGraphTensorData) objects and free their memory.

Alternatively, you could be continuously retaining a reference to the Python tensor object without releasing it. I didn't read through this thread and the code samples thoroughly, so I may be off base with this hypothesis. Could you narrow your code sample down to a reproducer that demonstrates this bug in far fewer lines of Python?
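
Something in this spirit would do: no dataset, no optimizer, just a repeated forward/backward on fixed random data while watching the resident size in Activity Monitor (a sketch, not a confirmed reproducer):

import torch

device = "mps"
model = torch.nn.Conv2d(1, 32, 4, 2, 1).to(device)
x = torch.randn(256, 1, 28, 28, device=device)

for step in range(10_000):
    y = model(x)
    y.sum().backward()
    model.zero_grad(set_to_none=True)
    if step % 500 == 0:
        print(step)  # resident memory should stay flat if nothing leaks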

@thipokKub commented May 22, 2022

Multiple optimizers are now decreasing the loss, and the step time remains constant for most tested optimizers except Adam, where the time between steps is still increasing (RAM usage is increasing as well).

pytorch-nightly build 1.13.0.dev20220521

[Image]

@alper111 (Author) commented May 22, 2022

I think this change will be in the 20220522 nightly: the commit was merged 15 hours ago, but 20220521 was uploaded 19 hours ago.

@thipokKub commented:
Okay, I updated the pytorch-nightly build again (1.13.0.dev20220522), and the issue seems to be fixed now.

These are the results:
[Image: final result]

facebook-github-bot pushed a commit that referenced this issue May 24, 2022
…#78006) (#78006)

Summary:
Fixes #77753

Pull Request resolved: #78006
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/cbdb694f158b8471d71822873c3ac130203cc218

Reviewed By: seemethere

Differential Revision: D36603013

Pulled By: seemethere

fbshipit-source-id: 983df9b73575a0c752490097951932052a96425f
@kulinseth (Collaborator) commented:
Thanks @thipokKub for the detailed notes and the repro case. This looks great.

@perone (Contributor) commented Jun 1, 2022

I'm using PyTorch on M1, and torch-1.13.0.dev20220601 is giving me a memory leak while torch-1.13.0.dev20220522 is not, so there might be a regression. I don't have time to show how to replicate it, but I'm basically using a transformer-based model. With CPU there is no leak. cc @kulinseth

atalman pushed a commit to atalman/pytorch that referenced this issue Jun 6, 2022
…pytorch#78006) (pytorch#78006)

Summary:
Fixes pytorch#77753

Pull Request resolved: pytorch#78006
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/cbdb694f158b8471d71822873c3ac130203cc218

Reviewed By: seemethere

Differential Revision: D36603013

Pulled By: seemethere

fbshipit-source-id: 983df9b73575a0c752490097951932052a96425f
malfet pushed a commit that referenced this issue Jun 7, 2022
@dbl001 commented Jun 13, 2022

I'm experiencing the same behavior (i.e. increasing memory usage and iteration time) on macOS 12.4 with 'MPS' on an AMD GPU. Curiously, the GPU meter is not displaying any activity.

print(torch.__version__)
1.13.0a0+gitb637810

Do I need to open a separate issue?

[Screenshot: 2022-06-13 12:45 PM]

@dbl001 commented Jun 26, 2022

I reproduced the timing and memory issue with this simple utils.DataLoader example.
The CPU is ~100 times faster than 'MPS'.
My GPU is never active during this test, according to the GPU activity monitor.
Any idea what's going on here?

import numpy as np
import torch
import torch.utils.data as utils
import cProfile, pstats

bs = 2048

n_variables = np.loadtxt("../example_data/example1.txt_train", dtype='str').shape[1]-1
variables = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(0,))

#epochs = 200*n_variables
epochs = n_variables
print(epochs)

for j in range(1,n_variables):
    v = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(j,))
    variables = np.column_stack((variables,v))

f_dependent = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(n_variables,))
f_dependent = np.reshape(f_dependent,(len(f_dependent),1))

factors = torch.from_numpy(variables)
factors = factors.to('mps')
factors = factors.float()

product = torch.from_numpy(f_dependent)
product = product.to('mps')
product = product.float()

my_dataset = utils.TensorDataset(factors,product) # create your dataset
my_dataloader = utils.DataLoader(my_dataset, batch_size=bs, shuffle=False) # create your dataloader

profiler = cProfile.Profile()
profiler.enable()

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(my_dataloader):
        print(i)
        fct = data[0].float().to('mps')
        prd = data[1].float().to('mps')
        
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats.print_stats()

You can get a copy of the training data file here:

https://github.com/SJ001/AI-Feynman/blob/master/example_data/example1.txt
Running a subset of the input training file (e.g. 10,000 rows), my cProfile results are:

181397 function calls (181331 primitive calls) in 1339.246 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000 1339.246  669.623 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3333(run_code)
        2    0.000    0.000 1339.246  669.623 {built-in method builtins.exec}
        1    0.000    0.000 1339.246 1339.246 /var/folders/3n/56fpv14n4wj0c1l1sb106pzw0000gn/T/ipykernel_35078/3788985938.py:37(<cell line: 37>)
       18    0.000    0.000 1339.242   74.402 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:647(__next__)
       18    0.027    0.002 1339.239   74.402 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:688(_next_data)
       15    0.001    0.000 1339.204   89.280 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:47(fetch)
    45/15    0.044    0.001  843.316   56.221 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:84(default_collate)
       15    0.000    0.000  843.263   56.218 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:175(<listcomp>)
       30  843.263   28.109  843.263   28.109 {built-in method torch.stack}
       15    0.070    0.005  495.887   33.059 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:49(<listcomp>)
    30003    0.133    0.000  495.818    0.017 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:189(__getitem__)
    90009  495.685    0.006  495.685    0.006 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:190(<genexpr>)
       15    0.002    0.000    0.008    0.001 {built-in method builtins.all}
       18    0.000    0.000    0.008    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:641(_next_index)
       39    0.000    0.000    0.008    0.000 {built-in method builtins.next}
       18    0.008    0.000    0.008    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/sampler.py:240(__iter__)
    30003    0.005    0.000    0.006    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:170(<genexpr>)
    30047    0.001    0.000    0.001    0.000 {built-in method builtins.len}
       18    0.000    0.000    0.001    0.000 {built-in method builtins.print}
       18    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:436(__init__)
       36    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:518(write)
       18    0.001    0.000    0.001    0.000 {built-in method torch.zeros}
       30    0.001    0.000    0.001    0.000 {method 'float' of 'torch._C._TensorBase' objects}
       36    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/_ops.py:143(__call__)
       36    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:448(_schedule_flush)
       15    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:202(schedule)
       18    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:445(__enter__)
       18    0.001    0.000    0.001    0.000 {built-in method torch._ops.profiler._record_function_enter}
       15    0.001    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/zmq/sugar/socket.py:543(send)
       18    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:449(__exit__)
       30    0.000    0.000    0.000    0.000 {method 'to' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:425(__iter__)
      192    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:379(_get_iterator)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:680(__init__)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:592(__init__)
       18    0.000    0.000    0.000    0.000 {built-in method torch._ops.profiler._record_function_exit}
       36    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/abc.py:117(__instancecheck__)
       36    0.000    0.000    0.000    0.000 {built-in method _abc._abc_instancecheck}
    34/16    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/abc.py:121(__subclasscheck__)
    34/16    0.000    0.000    0.000    0.000 {built-in method _abc._abc_subclasscheck}
       15    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:1126(is_alive)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/codeop.py:142(__call__)
        3    0.000    0.000    0.000    0.000 {method 'random_' of 'torch._C._TensorBase' objects}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.compile}
       36    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:429(_is_master_process)
        3    0.000    0.000    0.000    0.000 {built-in method torch.empty}
       15    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:1059(_wait_for_tstate_lock)
       15    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:90(_event_pipe)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3186(_update_code_co_name)
        3    0.000    0.000    0.000    0.000 {method 'item' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/sampler.py:75(__iter__)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:71(create_fetcher)
       15    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
       36    0.000    0.000    0.000    0.000 {built-in method posix.getpid}
       36    0.000    0.000    0.000    0.000 {method 'write' of '_io.StringIO' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:562(_get_shared_seed)
       30    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py:83(get_worker_info)
        4    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/dis.py:449(findlinestarts)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:44(__init__)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:192(__len__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:261(helper)
       21    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       36    0.000    0.000    0.000    0.000 {method '__exit__' of '_thread.RLock' objects}
       15    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/_collections_abc.py:315(__subclasshook__)
       15    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:529(is_set)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:123(__exit__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:114(__enter__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:86(__init__)
       17    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        3    0.000    0.000    0.000    0.000 {method 'format' of 'str' objects}
       15    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        1    0.000    0.000    0.000    0.000 /var/folders/3n/56fpv14n4wj0c1l1sb106pzw0000gn/T/ipykernel_35078/3788985938.py:44(<cell line: 44>)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/traitlets/traitlets.py:566(__get__)
        4    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/compilerop.py:174(extra_flags)
        3    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:8(__init__)
        2    0.000    0.000    0.000    0.000 {method 'replace' of 'code' objects}
       19    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/_collections_abc.py:409(__subclasshook__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3284(compare)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:1181(user_global_ns)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/traitlets/traitlets.py:535(get)
        6    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:440(_auto_collation)
        4    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:444(_index_sampler)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


Running on the CPU, my cProfile results are:

  1734737 function calls (1734455 primitive calls) in 1.755 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    1.754    0.877 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3333(run_code)
        2    0.000    0.000    1.754    0.877 {built-in method builtins.exec}
        1    0.002    0.002    1.754    1.754 /var/folders/3n/56fpv14n4wj0c1l1sb106pzw0000gn/T/ipykernel_35078/1615822328.py:35(<cell line: 35>)
      144    0.002    0.000    1.752    0.012 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:647(__next__)
      144    0.141    0.001    1.745    0.012 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:688(_next_data)
      141    0.002    0.000    1.548    0.011 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:47(fetch)
      141    0.090    0.001    1.145    0.008 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:49(<listcomp>)
   288000    0.277    0.000    1.056    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:189(__getitem__)
   864000    0.779    0.000    0.779    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:190(<genexpr>)
  423/141    0.238    0.001    0.401    0.003 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:84(default_collate)
      141    0.000    0.000    0.091    0.001 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:175(<listcomp>)
      282    0.090    0.000    0.090    0.000 {built-in method torch.stack}
      141    0.018    0.000    0.071    0.001 {built-in method builtins.all}
      144    0.000    0.000    0.056    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:641(_next_index)
      291    0.000    0.000    0.056    0.000 {built-in method builtins.next}
      144    0.056    0.000    0.056    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/sampler.py:240(__iter__)
   288000    0.039    0.000    0.053    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:170(<genexpr>)
   288014    0.014    0.000    0.014    0.000 {built-in method builtins.len}
      288    0.000    0.000    0.002    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/_ops.py:143(__call__)
      144    0.000    0.000    0.002    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:445(__enter__)
      144    0.001    0.000    0.002    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:449(__exit__)
     1422    0.001    0.000    0.001    0.000 {built-in method builtins.isinstance}
      144    0.001    0.000    0.001    0.000 {built-in method torch._ops.profiler._record_function_enter}
      144    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/autograd/profiler.py:436(__init__)
      144    0.001    0.000    0.001    0.000 {built-in method torch.zeros}
      288    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/abc.py:117(__instancecheck__)
      144    0.001    0.000    0.001    0.000 {built-in method torch._ops.profiler._record_function_exit}
      288    0.000    0.000    0.001    0.000 {built-in method _abc._abc_instancecheck}
        3    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:425(__iter__)
        3    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:379(_get_iterator)
        3    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:680(__init__)
        3    0.000    0.000    0.001    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:592(__init__)
      141    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/abc.py:121(__subclasscheck__)
      141    0.000    0.000    0.000    0.000 {built-in method _abc._abc_subclasscheck}
      282    0.000    0.000    0.000    0.000 {method 'float' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        6    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:518(write)
        3    0.000    0.000    0.000    0.000 {built-in method torch.empty}
        3    0.000    0.000    0.000    0.000 {method 'random_' of 'torch._C._TensorBase' objects}
        6    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:448(_schedule_flush)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:202(schedule)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/zmq/sugar/socket.py:543(send)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/codeop.py:142(__call__)
      141    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/_collections_abc.py:315(__subclasshook__)
        2    0.000    0.000    0.000    0.000 {built-in method builtins.compile}
      282    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py:83(get_worker_info)
      147    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
      143    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3186(_update_code_co_name)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/sampler.py:75(__iter__)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:562(_get_shared_seed)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:71(create_fetcher)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:1126(is_alive)
        3    0.000    0.000    0.000    0.000 {method 'item' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataset.py:192(__len__)
        6    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:429(_is_master_process)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:44(__init__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:261(helper)
        4    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/dis.py:449(findlinestarts)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:1059(_wait_for_tstate_lock)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:86(__init__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:114(__enter__)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/ipykernel/iostream.py:90(_event_pipe)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/traitlets/traitlets.py:566(__get__)
        3    0.000    0.000    0.000    0.000 {method 'format' of 'str' objects}
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/contextlib.py:123(__exit__)
        4    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/compilerop.py:174(extra_flags)
        3    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 /var/folders/3n/56fpv14n4wj0c1l1sb106pzw0000gn/T/ipykernel_35078/1615822328.py:42(<cell line: 42>)
        3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:8(__init__)
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/traitlets/traitlets.py:535(get)
        2    0.000    0.000    0.000    0.000 {method 'replace' of 'code' objects}
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:1181(user_global_ns)
        6    0.000    0.000    0.000    0.000 {built-in method posix.getpid}
        6    0.000    0.000    0.000    0.000 {method 'write' of '_io.StringIO' objects}
        6    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:440(_auto_collation)
        4    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        2    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3284(compare)
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:444(_index_sampler)
        6    0.000    0.000    0.000    0.000 {method '__exit__' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        3    0.000    0.000    0.000    0.000 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/threading.py:529(is_set)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

Why does {built-in method torch.stack} take so long on the GPU?


       18    0.000    0.000 1339.242   74.402 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:647(__next__)
       18    0.027    0.002 1339.239   74.402 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/dataloader.py:688(_next_data)
       15    0.001    0.000 1339.204   89.280 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:47(fetch)
    45/15    0.044    0.001  843.316   56.221 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:84(default_collate)
       15    0.000    0.000  843.263   56.218 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py:175(<listcomp>)
       30  843.263   28.109  843.263   28.109 {built-in method torch.stack}
       15    0.070    0.005  495.887   33.059 /Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:49(<listcomp>)

@philipturner commented:
Does torch.stack create a Python array of tensors, then concatenate them into one larger tensor? If so, check the dimensions you are stacking. If it's something like 1000 tensor objects, most of the overhead is in allocating and managing the underlying MTLBuffer objects and encoding the command(s) to combine them. Even with PyTorch's optimized heap allocator, driver-side latency may still accumulate. Or, if allocation is not the bottleneck, multiply the number of stacked tensors by 20 microseconds (command-encoding overhead): is the resulting time span large? If you're using Torch in eager mode, this latency can't be optimized away.
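
One quick way to test that hypothesis is to time torch.stack over many small tensors on each backend. A sketch (the 20 µs per command figure above is an estimate, and torch.mps.synchronize only exists in newer builds, hence the guard):

import time
import torch

def time_stack(device, n=1000):
    tensors = [torch.randn(3, device=device) for _ in range(n)]
    start = time.perf_counter()
    out = torch.stack(tensors)  # n buffers to gather into one new allocation
    if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
        torch.mps.synchronize()  # wait for the GPU so the timing is honest
    return time.perf_counter() - start

for dev in ["cpu", "mps"]:
    print(dev, f"{time_stack(dev) * 1e3:.2f} ms")

If the mps number scales with n while the cpu number barely moves, the overhead is per-tensor dispatch/allocation rather than the copy itself.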

@dbl001 commented Jun 28, 2022

torch.stack is called in the default_collate() method with automatic batching (the default).
My training file is 96000x3.
I tried changing the batch size (e.g. 2048, 128, as well as None). With automatic batching, the process always stalls once it grows to ~48 GB.
How are these tensors garbage collected? At the end of every batch? Every epoch? Under GPU memory pressure? Never?
Are there any methods to display GPU memory usage?

[Screenshot: 2022-06-28 6:58 AM]

@kulinseth (Collaborator) commented:
Hi @dbl001, we use recommendedMaxWorkingSetSize in Metal to cap the allocation limit on the MTLHeap. You can set the PYTORCH_DEBUG_MPS_ALLOCATOR environment variable to enable debug prints of the MTLHeap allocations in PyTorch (sketched below).
Also, for memory allocations you can use tools like heap and footprint to track usage. For GPU memory you can use IOAccelMemory --pid <PID>.
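
In practice the variable has to be set before PyTorch initializes the MPS backend, so either export it in the shell or set it at the very top of the script. A sketch:

import os

# must happen before the first MPS allocation; before importing torch is safest
os.environ["PYTORCH_DEBUG_MPS_ALLOCATOR"] = "1"

import torch

x = torch.ones(1024, 1024, device="mps")  # should trigger allocator debug prints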

@kulinseth (Collaborator) commented:
How are these tensors garbage collected? At the end of every batch? Every epoch? Under GPU memory pressure? Never? Are there any methods to display GPU memory usage?

For tensor garbage collection, we follow the tensor's life cycle and release the memory back when it goes out of use.
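
In user code that means the memory is returned once the last Python reference is dropped; newer builds also expose torch.mps.empty_cache() to hand cached blocks back to the system. A sketch, assuming a build where it exists:

import torch

x = torch.randn(4096, 4096, device="mps")
y = x @ x                 # buffers stay alive while x and y are referenced
del x, y                  # last references gone: buffers return to the cache
if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
    torch.mps.empty_cache()  # release cached blocks back to the OS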

@dbl001 commented Jun 28, 2022

Any clues as to why the computation keeps slowing down and the process memory keeps growing until it stalls?
The GPU activity meter does not show any activity.
It happens with 'mps' but not with 'cpu', for the same data and the same parameters.

@dbl001 commented Jun 28, 2022

Please find attached the output from

% IOAccelMemory --pid 80028 > output.txt

output.txt

@philipturner commented Jun 28, 2022

@dbl001 would you mind using a details element in Markdown to shorten the raw output you just posted? It's difficult to navigate the thread with such a large comment. The following HTML tags turn it into a dropdown:

<details>
<summary>File contents</summary>
<!-- Leave a blank line here; delete this HTML comment -->
```
4 KB
4 KB
...
4 KB
```
</details>

@dbl001 commented Jun 28, 2022

Here's the output of

$ IOAccelMemory --pid 86577 > output1.txt

output1.txt

Running only the DataLoader on 'mps'; the process is currently at 11 GB.

import numpy as np
import torch
import torch.utils.data as utils
import cProfile, pstats

bs = 2048

n_variables = np.loadtxt("../example_data/example1.txt_train", dtype='str').shape[1]-1
variables = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(0,))

#epochs = 200*n_variables
epochs = n_variables
print(epochs)

for j in range(1,n_variables):
    v = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(j,))
    variables = np.column_stack((variables,v))

f_dependent = np.loadtxt("../example_data/example1.txt_train", dtype = np.float32, usecols=(n_variables,))
f_dependent = np.reshape(f_dependent,(len(f_dependent),1))

factors = torch.from_numpy(variables)
factors = factors.to('mps')
factors = factors.float()

product = torch.from_numpy(f_dependent)
product = product.to('mps')
product = product.float()

my_dataset = utils.TensorDataset(factors,product) # create your dataset
my_dataloader = utils.DataLoader(my_dataset, batch_size=bs, shuffle=False) # create your dataloader

profiler = cProfile.Profile()
profiler.enable()

for epoch in range(epochs):
    print(epoch)
    for i, data in enumerate(my_dataloader):
        #print(i)
        fct = data[0].float().to('mps')
        prd = data[1].float().to('mps')
        
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats.print_stats()

I'll run IOAccelMemory again when it gets to ~48 GB.
