Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG Report using backend:cudaMallocAsync #124351

Open
hawkheimmer opened this issue Apr 18, 2024 · 0 comments
Open

BUG Report using backend:cudaMallocAsync #124351

hawkheimmer opened this issue Apr 18, 2024 · 0 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: CUDACachingAllocator triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@hawkheimmer
Copy link

hawkheimmer commented Apr 18, 2024

🐛 Describe the bug

When I use os.environ['PYTORCH_CUDA_ALLOC_CONF']="backend:native", things went well; but when I switch to os.environ['PYTORCH_CUDA_ALLOC_CONF']="backend:cudaMallocAsync", things went wrong.

Full code:

import math
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from sklearn.metrics import f1_score, confusion_matrix
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn import TransformerEncoderLayer, TransformerEncoder
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF']="backend:cudaMallocAsync"
print(os.environ.get('PYTORCH_CUDA_ALLOC_CONF'))

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

copper = pd.read_csv('copper.csv', header=None)
cotton = pd.read_csv('cotton.csv', header=None)
glass = pd.read_csv('glass.csv', header=None)
leather = pd.read_csv('leather.csv', header=None)
paper = pd.read_csv('paper.csv', header=None)
ptfe = pd.read_csv('ptfe.csv', header=None)
pu = pd.read_csv('pu.csv', header=None)
wood = pd.read_csv('wood.csv', header=None)
wool = pd.read_csv('wool.csv', header=None)

cotton = cotton.iloc[:, :-2]
leather = leather.iloc[:, :-3]
paper = paper.iloc[:, :-1]
pu = pu.iloc[:-1, :]
wood = wood.iloc[:, :-7]
wool = wool.iloc[:, :-3]

repeat_count = 500

labels = []
for i in range(0, 9):
    labels.extend([i] * repeat_count)

data = [copper, cotton, glass, leather, paper, ptfe, pu, wood, wool]

tensors = [torch.tensor(df.values) for df in data]
combined_tensor = torch.cat(tensors, dim=1)
combined_tensor = combined_tensor.unsqueeze(1)

result_tensor = combined_tensor.permute(2, 1, 0)

mean_val = torch.mean(result_tensor)
std_val = torch.std(result_tensor)

standardized_tensor = (result_tensor - mean_val) / std_val

labels = torch.tensor(labels)

dataset = TensorDataset(standardized_tensor, labels)

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, 4, 2)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 128, 4, 2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, 4)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool = nn.MaxPool1d(4, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        return x


class PositionalEncoding(nn.Module):

    def __init__(self, d_model=256, dropout=0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x.permute(2, 0, 1)
        x = x + self.pe[:x.size(0)]
        # print(x.shape)
        return self.dropout(x)


class TransformerClassifier(nn.Module):
    def __init__(self, embedding_dim=256, num_heads=32, num_classes=9, num_layers=4,
                 dim_feedforward=256, dropout=0.1, hidden_dim=256):
        super(TransformerClassifier, self).__init__()
        self.encoder_layer = TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                                     dim_feedforward=dim_feedforward, dropout=dropout,
                                                     batch_first=False)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.input_layer = nn.Linear(embedding_dim, hidden_dim)
        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, num_classes)
        self.conv = ConvNet()
        self.positional_encoding = PositionalEncoding()

    def forward(self, x):
        x = x.float()
        x = self.conv(x)
        x = self.positional_encoding(x)
        x = self.transformer_encoder(x)
        x = x.permute(1, 2, 0)
        x = F.avg_pool1d(x, kernel_size=x.size(2))
        x = x.squeeze(2)
        x = F.relu(self.input_layer(x))
        x = F.relu(self.hidden_layer(x))
        x = self.output_layer(x)
        return x


model = TransformerClassifier().to(torch.device('cuda'))

Traceback:

Traceback (most recent call last):
  File "/home/user/Document/AI_sensor/data.py", line 165, in <module>
    model = TransformerClassifier().to(torch.device('cuda'))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1152, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1150, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/cuda/__init__.py", line 302, in _lazy_init
    torch._C._cuda_init()
RuntimeError: config[i] == get()->name() INTERNAL ASSERT FAILED at "../c10/cuda/CUDAAllocatorConfig.cpp":225, please report a bug to PyTorch. Allocator backend parsed at runtime != allocator backend parsed at load time

Versions

Collecting environment information...
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3050 Laptop GPU
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
架构: x86_64
CPU 运行模式: 32-bit, 64-bit
字节序: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU: 12
在线 CPU 列表: 0-11
每个核的线程数: 2
每个座的核数: 6
座: 1
NUMA 节点: 1
厂商 ID: AuthenticAMD
CPU 系列: 25
型号: 68
型号名称: AMD Ryzen 5 6600H with Radeon Graphics
步进: 1
Frequency boost: enabled
CPU MHz: 1835.786
CPU 最大 MHz: 4563.2808
CPU 最小 MHz: 1600.0000
BogoMIPS: 6587.55
虚拟化: AMD-V
L1d 缓存: 192 KiB
L1i 缓存: 192 KiB
L2 缓存: 3 MiB
L3 缓存: 16 MiB
NUMA 节点0 CPU: 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
标记: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.2
[pip3] torchaudio==2.2.2
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.17.2
[pip3] triton==2.2.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.8 py311h5eee18b_0
[conda] mkl_random 1.2.4 py311hdb19cb5_0
[conda] numpy 1.26.4 py311h08b1b3b_0
[conda] numpy-base 1.26.4 py311hf175353_0
[conda] torch 2.2.2 pypi_0 pypi
[conda] torchaudio 2.2.2 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.17.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @ptrblck

@jbschlosser jbschlosser added module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: CUDACachingAllocator labels Apr 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: CUDACachingAllocator triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

No branches or pull requests

2 participants