
Commit

Update on "nn.Linear with BSR inputs: spare the user from explicit Tr…
Browse files Browse the repository at this point in the history
…iton kernel registrations"




### <samp>🤖 Generated by Copilot at 08f7a6a</samp>

This pull request adds support for Triton kernels in `torch` and `torch/cuda`, refactors and tests the existing Triton kernel for BSR matrix multiplication, and adds a test case to ensure that importing `torch` does not implicitly import `triton`.
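A minimal sketch of what such an import-isolation check can look like (hypothetical code, not the literal test added here; the real test lives in the PyTorch test suite):

```python
import subprocess
import sys

# Run a fresh interpreter, import torch, and verify that triton was not
# pulled in as a side effect of the import.
code = "import sys, torch; assert 'triton' not in sys.modules"
result = subprocess.run([sys.executable, "-c", code])
assert result.returncode == 0, "importing torch implicitly imported triton"
```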

cc alexsamardzic pearu cpuhrsch amjames bhosmer albanD mruberry jbschlosser walterddr mikaylagawarecki

[ghstack-poisoned]
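For context, a hedged sketch of the workflow the PR title refers to: running a linear layer whose weight is stored in BSR format, with the Triton kernel dispatched automatically instead of requiring explicit registration by the user. The blocksize, shapes, and dtype are illustrative assumptions, and running it requires a CUDA build with Triton available:

```python
import torch

# Dense weight converted to block-sparse (BSR) storage; after this PR the
# Triton BSR matmul is used without any user-side kernel registration.
w = torch.randn(64, 64, device="cuda", dtype=torch.float16)
w_bsr = w.to_sparse_bsr(blocksize=(16, 16))

x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
y = torch.nn.functional.linear(x, w_bsr)
print(y.shape)  # torch.Size([8, 64])
```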
nikitaved committed May 31, 2023
2 parents 6f9af1d + 460c982 commit cec539a
Showing 39 changed files with 616 additions and 225 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-7d1a95b04654ff9c216afe08a454ad0822f05370
+9820899b3845e461d9031dba66062efade65d420
34 changes: 34 additions & 0 deletions .devcontainer/Dockerfile
@@ -0,0 +1,34 @@
FROM mcr.microsoft.com/vscode/devcontainers/miniconda:0-3

# I am surprised this is needed
RUN conda init

# Copy environment.yml (if found) to a temp location so we can update the environment. Also
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
COPY .devcontainer/cuda/environment.yml .devcontainer/noop.txt /tmp/conda-tmp/
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
&& sudo rm -rf /tmp/conda-tmp

# Tools needed for llvm
RUN sudo apt-get -y update
RUN sudo apt install -y lsb-release wget software-properties-common gnupg

# Install CLANG if version is specified
ARG CLANG_VERSION
RUN if [ -n "$CLANG_VERSION" ]; then \
sudo wget https://apt.llvm.org/llvm.sh; \
chmod +x llvm.sh; \
sudo ./llvm.sh "${CLANG_VERSION}"; \
echo 'export CC=clang' >> ~/.bashrc; \
echo 'export CXX=clang++' >> ~/.bashrc; \
sudo apt update; \
sudo apt install -y clang; \
sudo apt install -y libomp-dev; \
fi


# Install cuda if version is specified
ARG CUDA_VERSION
RUN if [ -n "$CUDA_VERSION" ]; then \
conda install cuda -c "nvidia/label/cuda-${CUDA_VERSION}"; \
fi
37 changes: 37 additions & 0 deletions .devcontainer/cpu/devcontainer.json
@@ -0,0 +1,37 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{
    "name": "PyTorch - CPU",
    "build": {
        "context": "../..",
        "dockerfile": "../Dockerfile",
        "args": {
            "USERNAME": "vscode",
            "BUILDKIT_INLINE_CACHE": "0",
            "CLANG_VERSION": ""
        }
    },

    // Features to add to the dev container. More info: https://containers.dev/features.
    "features": {
        // This is needed for lintrunner
        "ghcr.io/devcontainers/features/rust:1" : {}
    },

    // Use 'forwardPorts' to make a list of ports inside the container available locally.
    // "forwardPorts": [],

    // Use 'postCreateCommand' to run commands after the container is created.
    "postCreateCommand": "bash .devcontainer/scripts/install-dev-tools.sh",

    // Configure tool-specific properties.
    // "customizations": {},
    "customizations": {
        "vscode": {
            "extensions": ["streetsidesoftware.code-spell-checker"]
        }
    }

    // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
    // "remoteUser": "root"
}
6 changes: 6 additions & 0 deletions .devcontainer/cpu/environment.yml
@@ -0,0 +1,6 @@
# This environment is specific to Debian
name: PyTorch
dependencies:
- cmake
- ninja
- libopenblas
37 changes: 37 additions & 0 deletions .devcontainer/cuda/devcontainer.json
@@ -0,0 +1,37 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{
    "name": "PyTorch - CUDA",
    "build": {
        "context": "../..",
        "dockerfile": "../Dockerfile",
        "args": {
            "USERNAME": "vscode",
            "BUILDKIT_INLINE_CACHE": "0",
            "CUDA_VERSION": "11.8.0",
            "CLANG_VERSION": ""
        }
    },
    "runArgs": ["--gpus", "all"],
    // Use 'forwardPorts' to make a list of ports inside the container available locally.
    // "forwardPorts": [],

    // Use 'postCreateCommand' to run commands after the container is created.
    "postCreateCommand": "bash .devcontainer/scripts/install-dev-tools.sh",

    // Configure tool-specific properties.
    // "customizations": {},
    "customizations": {
        "vscode": {
            "extensions": ["streetsidesoftware.code-spell-checker"]
        }
    },

    // Features to add to the dev container. More info: https://containers.dev/features.
    "features": {
        // This is needed for lintrunner
        "ghcr.io/devcontainers/features/rust:1" : {}
    }
    // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
    // "remoteUser": "root"
}
6 changes: 6 additions & 0 deletions .devcontainer/cuda/environment.yml
@@ -0,0 +1,6 @@
# This environment is specific to Debian
name: PyTorch
dependencies:
- cmake
- ninja
- libopenblas
3 changes: 3 additions & 0 deletions .devcontainer/noop.txt
@@ -0,0 +1,3 @@
This file is copied into the container along with environment.yml* from the parent
folder. It is included to prevent the Dockerfile COPY instruction from
failing if no environment.yml is found.
11 changes: 11 additions & 0 deletions .devcontainer/scripts/install-dev-tools.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Run this from the PyTorch directory after cloning the source code (see the "Get the PyTorch Source" section of the README)
pip install -r requirements.txt
git submodule sync
git submodule update --init --recursive

# This takes some time
make setup_lint

# Add CMAKE_PREFIX_PATH to bashrc
echo 'export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}' >> ~/.bashrc
38 changes: 38 additions & 0 deletions .devcontainer/scripts/update_alternatives_clang.sh
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# update_alternatives_clang.sh
# chmod u+x update_alternatives_clang.sh
#

update_alternatives() {
    local version=${1}
    local priority=${2}
    local master=${3}   # master command name (e.g. llvm-config or clang)
    local slaves=${4}
    local path=${5}
    local cmdln

    cmdln="--verbose --install ${path}${master} ${master} ${path}${master}-${version} ${priority}"
    for slave in ${slaves}; do
        cmdln="${cmdln} --slave ${path}${slave} ${slave} ${path}${slave}-${version}"
    done
    sudo update-alternatives ${cmdln}
}

if [[ ${#} -ne 2 ]]; then
    echo "usage: ${0} clang_version priority"
    exit 1
fi

version=${1}
priority=${2}
path="/usr/bin/"

master="llvm-config"
slaves="llvm-addr2line llvm-ar llvm-as llvm-bcanalyzer llvm-bitcode-strip llvm-cat llvm-cfi-verify llvm-cov llvm-c-test llvm-cvtres llvm-cxxdump llvm-cxxfilt llvm-cxxmap llvm-debuginfod llvm-debuginfod-find llvm-diff llvm-dis llvm-dlltool llvm-dwarfdump llvm-dwarfutil llvm-dwp llvm-exegesis llvm-extract llvm-gsymutil llvm-ifs llvm-install-name-tool llvm-jitlink llvm-jitlink-executor llvm-lib llvm-libtool-darwin llvm-link llvm-lipo llvm-lto llvm-lto2 llvm-mc llvm-mca llvm-ml llvm-modextract llvm-mt llvm-nm llvm-objcopy llvm-objdump llvm-omp-device-info llvm-opt-report llvm-otool llvm-pdbutil llvm-PerfectShuffle llvm-profdata llvm-profgen llvm-ranlib llvm-rc llvm-readelf llvm-readobj llvm-reduce llvm-remark-size-diff llvm-rtdyld llvm-sim llvm-size llvm-split llvm-stress llvm-strings llvm-strip llvm-symbolizer llvm-tapi-diff llvm-tblgen llvm-tli-checker llvm-undname llvm-windres llvm-xray"

update_alternatives "${version}" "${priority}" "${master}" "${slaves}" "${path}"

master="clang"
slaves="analyze-build asan_symbolize bugpoint c-index-test clang++ clang-apply-replacements clang-change-namespace clang-check clang-cl clang-cpp clangd clang-doc clang-extdef-mapping clang-format clang-format-diff clang-include-fixer clang-linker-wrapper clang-move clang-nvlink-wrapper clang-offload-bundler clang-offload-packager clang-offload-wrapper clang-pseudo clang-query clang-refactor clang-rename clang-reorder-fields clang-repl clang-scan-deps clang-tidy count diagtool dsymutil FileCheck find-all-symbols git-clang-format hmaptool hwasan_symbolize intercept-build ld64.lld ld.lld llc lld lldb lldb-argdumper lldb-instr lldb-server lldb-vscode lld-link lli lli-child-target modularize not obj2yaml opt pp-trace run-clang-tidy sancov sanstats scan-build scan-build-py scan-view split-file UnicodeNameMappingGenerator verify-uselistorder wasm-ld yaml2obj yaml-bench"

update_alternatives "${version}" "${priority}" "${master}" "${slaves}" "${path}"
2 changes: 1 addition & 1 deletion aten/src/ATen/SparseCsrTensorImpl.h
@@ -54,7 +54,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
const Tensor& values() const {
return values_;
}
-int nnz() {
+int64_t nnz() {
return col_indices_.size(-1);
}

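The `int` → `int64_t` widening above (and the matching changes in `SparseCsrTensor.cpp` below) is not cosmetic: an nnz count is a tensor dimension and can exceed the 32-bit range. A tiny illustration of the limit involved, with made-up numbers:

```python
INT32_MAX = 2**31 - 1    # largest value a 32-bit int can hold
nnz = 3 * 10**9          # a large sparse tensor can have this many nonzeros
assert nnz > INT32_MAX   # returning this through `int` would overflow
```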
4 changes: 2 additions & 2 deletions aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -1554,11 +1554,11 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {

static Vectorized<T> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
-return Vectorized<T>{vinner_type::loadu(ptr)};
+return Vectorized<T>{vinner_type::loadu(ptr, count)};
}

void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
-_vec.store(ptr);
+_vec.store(ptr, count);
}

Vectorized<T> relu(Vectorized<T> zero_point) const {
13 changes: 6 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -3089,12 +3089,12 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
TORCH_CHECK(ind > 0, "Expected a strictly positive integer for 'ind', but got ", ind);

// self[ind:]
-std::vector<int64_t> shape_ind_end = self.sizes().slice(ind).vec();
+std::vector<c10::SymInt> shape_ind_end = self.sym_sizes().slice(ind).vec();
// self[:ind]
-std::vector<int64_t> shape_start_ind = self.sizes().slice(0, ind).vec();
+std::vector<c10::SymInt> shape_start_ind = self.sym_sizes().slice(0, ind).vec();

-int64_t prod_ind_end = c10::multiply_integers(shape_ind_end.cbegin(), shape_ind_end.cend());
-int64_t prod_start_ind = c10::multiply_integers(shape_start_ind.cbegin(), shape_start_ind.cend());
+c10::SymInt prod_ind_end = c10::multiply_integers(shape_ind_end.cbegin(), shape_ind_end.cend());
+c10::SymInt prod_start_ind = c10::multiply_integers(shape_start_ind.cbegin(), shape_start_ind.cend());

// Check whether the self tensor can be reshaped to the 2D square matrix
TORCH_CHECK(prod_ind_end == prod_start_ind,
@@ -3106,11 +3106,10 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend());

// If the reshaped self is not invertible catch this error
-Tensor result, info;
-std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
+auto [result, info] = at::linalg_inv_ex(self.reshape_symint({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
at::_linalg_check_errors(info, "inv", /*is_matrix*/true);

-return result.reshape(shape_ind_end);
+return result.reshape_symint(shape_ind_end);
}

// TODO: implement _out variant avoiding copy and using already allocated storage directly
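As a reminder of the contract this function enforces, a small usage sketch of `torch.linalg.tensorinv` (shapes chosen for illustration):

```python
import torch

# tensorinv flattens a.shape[:ind] and a.shape[ind:] into a square matrix,
# inverts it, and reshapes back; the two shape products must match.
a = torch.randn(4, 6, 8, 3)               # 4*6 == 8*3 == 24
a_inv = torch.linalg.tensorinv(a, ind=2)
print(a_inv.shape)                        # torch.Size([8, 3, 4, 6])
```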
8 changes: 4 additions & 4 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -213,7 +213,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
-const int values_nnz = values.size(batch_ndim);
+const int64_t values_nnz = values.size(batch_ndim);
DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
TORCH_CHECK(
@@ -229,9 +229,9 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
") must be divisible with blocksize[", i, "] (=", blocksize[i],
") as defined by values shape");
}
-const int nrows = size[batch_ndim] / blocksize[0];
-const int ncols = size[batch_ndim + 1] / blocksize[1];
-int compressed_dim_size, plain_dim_size;
+const int64_t nrows = size[batch_ndim] / blocksize[0];
+const int64_t ncols = size[batch_ndim + 1] / blocksize[1];
+int64_t compressed_dim_size, plain_dim_size;
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
[&] { return std::make_tuple(nrows, ncols); },
[&] { return std::make_tuple(ncols, nrows); });
1 change: 1 addition & 0 deletions benchmarks/dynamo/common.py
@@ -235,6 +235,7 @@ class CI(NamedTuple):
"hf_T5_base", # accuracy
"mobilenet_v3_large", # accuracy
"resnet50_quantized_qat", # Eager model failed to run
"AlbertForQuestionAnswering", # accuracy
"crossvit_9_240", # fails to run on timm 0.8.22 with cudagraphs, mempools
"deit_base_distilled_patch16_224", # fails to run in timm 0.8.22, cudagraphs
"mobilevit_s",
44 changes: 24 additions & 20 deletions test/distributed/_tensor/test_device_mesh.py
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
-import sys

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
@@ -11,21 +10,29 @@
    get_global_rank,
    get_process_group_ranks,
    get_world_size,
+    init_process_group,
    is_initialized,
+    is_nccl_available,
    ProcessGroup,
)
-from torch.testing._internal.common_distributed import TEST_SKIPS
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
+from torch.testing._internal.distributed.fake_pg import FakeStore


-def _get_device_type_and_backend():
-    device_type = "cuda" if torch.cuda.is_available() else "cpu"
-    backend = "nccl" if device_type == "cuda" else "gloo"
-    return device_type, backend
+def _get_device_type(world_size):
+    if (
+        torch.cuda.is_available()
+        and torch.cuda.device_count() >= world_size
+        and is_nccl_available()
+    ):
+        device_type = "cuda"
+    else:
+        device_type = "cpu"
+    return device_type


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
@@ -40,20 +47,12 @@ class DeviceMeshTest(DTensorTestBase):
    def world_size(self):
        return 4

-    @with_comms
-    def test_eligible_default_pg_for_mesh(self):
-        mesh_tensor = torch.arange(self.world_size).reshape(2, -1)
-        mesh = DeviceMesh(self.device_type, mesh_tensor)
-
    def test_init_process_group(self):
-        device_type, backend = _get_device_type_and_backend()
-        # skip the test if not enough GPUs
-        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertTrue(not is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
-        mesh = DeviceMesh(device_type, mesh_tensor)
+        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

@@ -89,10 +88,15 @@ def test_lazy_init_device_mesh(self):
        with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"):
            mesh.get_dim_groups()

-        if self.rank == 1:
-            assert mesh.get_coordinate() is not None
-        else:
-            assert mesh.get_coordinate() is None
+    def test_fake_pg_device_mesh(self):
+        fake_store = FakeStore()
+        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
+        device_type = "cuda" if torch.cuda.is_available() else "cpu"
+        mesh = DeviceMesh(device_type, torch.arange(self.world_size))
+
+        local_tensor = torch.randn(2, 8)
+        global_tensor = mesh.all_gather(local_tensor)
+        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))

    @with_comms
    def test_validate_device_mesh(self):
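The new `test_fake_pg_device_mesh` above relies on the "fake" process-group backend. A minimal standalone sketch of the same idea (it leans on internal testing APIs, which may change between releases):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.testing._internal.distributed.fake_pg import FakeStore

# A fake process group simulates collectives, so DeviceMesh code paths can be
# exercised in a single process without GPUs or a real NCCL/Gloo setup.
dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=4)
mesh = DeviceMesh("cpu", torch.arange(4))
print(mesh)
dist.destroy_process_group()
```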
