From 179d2b288c4398d4bd1d9187b36ae978bc308cc9 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 28 Oct 2020 11:19:59 -0700 Subject: [PATCH] Fix interval midpoint calculation in vulkan (#46839) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46839 Interval midpoint calculations can overflow (integers). This fixes such an instance. Test Plan: Standard test rig. Reviewed By: drdarshan Differential Revision: D24392545 fbshipit-source-id: 84c81802165bb8084e2d54c9f3755f39143a5b00 --- aten/src/ATen/native/vulkan/api/vk_mem_alloc.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h index fdeadf9cdbfa..b468a1c05c6d 100644 --- a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -4594,7 +4594,7 @@ static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, co size_t down = 0, up = (end - beg); while(down < up) { - const size_t mid = (down + up) / 2; + const size_t mid = down + (up - down) / 2; //Overflow-safe midpoint calculation if(cmp(*(beg+mid), key)) { down = mid + 1;