diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index fef9bc6dc9e..d987db6ddeb 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -30,7 +30,7 @@ class TestVideoGPUDecoder: ) def test_frame_reading(self, video_file): full_path = os.path.join(VIDEO_DIR, video_file) - decoder = VideoReader(full_path, device="cuda:0") + decoder = VideoReader(full_path, device="cuda") with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) @@ -54,7 +54,7 @@ def test_frame_reading(self, video_file): ], ) def test_seek_reading(self, keyframes, full_path, duration): - decoder = VideoReader(full_path, device="cuda:0") + decoder = VideoReader(full_path, device="cuda") time = duration / 2 decoder.seek(time, keyframes_only=keyframes) with av.open(full_path) as container: @@ -80,7 +80,7 @@ def test_seek_reading(self, keyframes, full_path, duration): ) def test_metadata(self, video_file): full_path = os.path.join(VIDEO_DIR, video_file) - decoder = VideoReader(full_path, device="cuda:0") + decoder = VideoReader(full_path, device="cuda") video_metadata = decoder.get_metadata()["video"] with av.open(full_path) as container: video = container.streams.video[0] diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 215e0ede9c8..d26e80da2d8 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -3,9 +3,10 @@ /* Set cuda device, create cuda context and initialise the demuxer and decoder. */ -GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) - : demuxer(src_file.c_str()), device(dev) { - at::cuda::CUDAGuard device_guard(device); +GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev) + : demuxer(src_file.c_str()) { + at::cuda::CUDAGuard device_guard(dev); + device = device_guard.current_device().index(); check_for_cuda_errors( cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__); decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); @@ -58,7 +59,7 @@ c10::Dict> GPUDecoder:: TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) + .def(torch::init()) .def("seek", &GPUDecoder::seek) .def("get_metadata", &GPUDecoder::get_metadata) .def("next", &GPUDecoder::decode); diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index 95599f646ed..22bf680a982 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -5,7 +5,7 @@ class GPUDecoder : public torch::CustomClassHolder { public: - GPUDecoder(std::string, int64_t); + GPUDecoder(std::string, torch::Device); ~GPUDecoder(); torch::Tensor decode(); void seek(double, bool); diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index 545b5712e93..afd7fdf4be6 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -84,6 +84,7 @@ class VideoReader: will depend on the version of FFMPEG codecs supported. device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. + To use GPU decoding, pass ``device="cuda"``. """ @@ -95,9 +96,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, devic if not _HAS_GPU_VIDEO_DECODER: raise RuntimeError("Not compiled with GPU decoder support.") self.is_cuda = True - if device.index is None: - raise RuntimeError("Invalid cuda device!") - self._c = torch.classes.torchvision.GPUDecoder(path, device.index) + self._c = torch.classes.torchvision.GPUDecoder(path, device) return if not _has_video_opt(): raise RuntimeError(