Merge pull request #49 from teddykoker/non-default-cuda
Fix incorrect results when running on non-default GPU
teddykoker committed Feb 28, 2022
2 parents 87f7715 + 551eaa8 commit 7ab8008
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -51,7 +51,7 @@ def ext_modules():

setup(
name="torchsort",
version="0.1.8",
version="0.1.9",
description="Differentiable sorting and ranking in PyTorch",
author="Teddy Koker",
url="https://github.com/teddykoker/torchsort",
6 changes: 5 additions & 1 deletion tests/test_ops.py
@@ -15,8 +15,9 @@
REGULARIZATION = ["l2", "kl"]
REGULARIZATION_STRENGTH = [1e-1, 1e0, 1e1]

# use CPU, and up to two CUDA devices
DEVICES = [torch.device("cpu")] + (
[torch.device("cuda")] if torch.cuda.is_available() else []
[torch.device(f"cuda:{d}") for d in range(min(torch.cuda.device_count(), 2))]
)

torch.manual_seed(0)
@@ -67,6 +68,7 @@ def test_vs_original(funcs, regularization, regularization_strength, device):
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to test fp16")
def test_half(function, regularization, regularization_strength, device):
# check half precision
x = torch.randn(BATCH_SIZE, SEQ_LEN, requires_grad=True).cuda().half()
f = partial(
function,
@@ -75,3 +77,5 @@ def test_half(function, regularization, regularization_strength, device):
)
# don't think there's a better way of testing, tolerance must be pretty high
assert torch.allclose(f(x), f(x.float()).half(), atol=1e-1)


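With the reworked DEVICES list, the suite now runs every parametrized test on the CPU and on up to two CUDA devices, so the non-default-GPU path that this PR fixes is actually exercised. As a rough end-to-end illustration of what those tests check (a sketch only, not repository code; it assumes torchsort is installed with its CUDA extension, at least two GPUs are visible, and the tolerance is arbitrary):

    import torch
    import torchsort

    x = torch.randn(8, 100)
    reference = torchsort.soft_sort(x)  # CPU result as the ground truth
    if torch.cuda.device_count() >= 2:
        # Before this fix, running on a non-default GPU could return incorrect values.
        on_gpu1 = torchsort.soft_sort(x.to("cuda:1")).cpu()
        assert torch.allclose(on_gpu1, reference, atol=1e-4)
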
8 changes: 5 additions & 3 deletions torchsort/isotonic_cuda.cu
@@ -30,9 +30,7 @@


#include <torch/extension.h>
// #include <cuda.h>
// #include <cuda_runtime.h>
// #include <iostream>
#include <c10/cuda/CUDAGuard.h>

// Copied from fast-soft-sort (https://bit.ly/3r0gOav) with the following modifications:
// - replace numpy functions with torch equivalents
@@ -306,6 +304,7 @@ __global__ void isotonic_kl_backward_kernel(
// Solves an isotonic regression problem using PAV.
// Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.
torch::Tensor isotonic_l2(torch::Tensor y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(y));
auto batch = y.size(0);
auto n = y.size(1);
auto sol = torch::zeros_like(y);
@@ -332,6 +331,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
// Solves isotonic optimization with KL divergence using PAV.
// Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y-v}, 1> + <e^w, v>.
torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(y));
auto batch = y.size(0);
auto n = y.size(1);
auto sol = torch::zeros_like(y);
@@ -357,6 +357,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
}

torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Tensor grad_input) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(s));
auto batch = sol.size(0);
auto n = sol.size(1);
auto ret = torch::zeros_like(sol);
@@ -379,6 +380,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Tensor grad_input) {
}

torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Tensor grad_input) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(s));
auto batch = sol.size(0);
auto n = sol.size(1);
auto ret = torch::zeros_like(sol);
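Each exported function now opens with at::cuda::OptionalCUDAGuard device_guard(device_of(...)), which makes the input tensor's GPU the current CUDA device for the remainder of the call, so the kernel launches above run on the GPU that actually holds the data rather than on the default device. A rough Python-side analogy of that "current device" notion (an illustrative sketch, not repository code; it assumes at least two visible GPUs) is the torch.cuda.device context manager:

    import torch

    if torch.cuda.device_count() >= 2:
        with torch.cuda.device(1):              # analogous to the guard: current device becomes cuda:1
            t = torch.empty(3, device="cuda")   # a bare "cuda" resolves to the current device
        assert t.device == torch.device("cuda", 1)
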
