Skip to content

Commit

Permalink
Add support for device type in random data getters (#132)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #132

Reviewed By: JKSenthil

Differential Revision: D43756530

fbshipit-source-id: 6c6709d37a48949ddd0ab5de95e3e5ddff8851fc
  • Loading branch information
bobakfb authored and facebook-github-bot committed Mar 7, 2023
1 parent 53b16ad commit b99b753
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
27 changes: 27 additions & 0 deletions tests/utils/test_random_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class RandomDataTest(unittest.TestCase):
cuda_avail: bool = torch.cuda.is_available()

def test_get_rand_data_binary(self) -> None:
input, targets = get_rand_data_binary(num_updates=2, num_tasks=5, batch_size=10)
self.assertEqual(input.size(), targets.size())
Expand All @@ -23,3 +25,28 @@ def test_get_rand_data_multiclass(self) -> None:
)
self.assertEqual(input.size(), torch.Size([2, 10, 5]))
self.assertTrue(torch.all(torch.lt(targets, 5)))

@unittest.skipUnless(
condition=cuda_avail, reason="This test needs a GPU host to run."
)
def test_get_rand_data_binary_GPU(self) -> None:
device = torch.device("cuda")
input, targets = get_rand_data_binary(
num_updates=2, num_tasks=5, batch_size=10, device=device
)
self.assertEqual(input.size(), targets.size())
self.assertTrue(input.is_cuda)
self.assertTrue(targets.is_cuda)

@unittest.skipUnless(
condition=cuda_avail, reason="This test needs a GPU host to run."
)
def test_get_rand_data_multiclass_GPU(self) -> None:
device = torch.device("cuda")
input, targets = get_rand_data_multiclass(
num_updates=2, num_classes=5, batch_size=10, device=device
)
self.assertEqual(input.size(), torch.Size([2, 10, 5]))
self.assertTrue(torch.all(torch.lt(targets, 5)))
self.assertTrue(input.is_cuda)
self.assertTrue(targets.is_cuda)
28 changes: 21 additions & 7 deletions torcheval/utils/random_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple
from typing import Optional, Tuple

import torch


def get_rand_data_binary(
num_updates: int, num_tasks: int, batch_size: int
num_updates: int,
num_tasks: int,
batch_size: int,
device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generates a random binary dataset.
Expand All @@ -24,13 +27,20 @@ def get_rand_data_binary(
torch.Tensor: random feature data
torch.Tensor: random targets
"""
input = torch.rand(size=[num_updates, num_tasks, batch_size])
targets = torch.randint(low=0, high=2, size=[num_updates, num_tasks, batch_size])
if device is None:
device = torch.device("cpu")
input = torch.rand(size=[num_updates, num_tasks, batch_size]).to(device)
targets = torch.randint(
low=0, high=2, size=[num_updates, num_tasks, batch_size]
).to(device)
return input, targets


def get_rand_data_multiclass(
num_updates: int, num_classes: int, batch_size: int
num_updates: int,
num_classes: int,
batch_size: int,
device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generates a random multiclass dataset.
Expand All @@ -44,6 +54,10 @@ def get_rand_data_multiclass(
torch.Tensor: random feature data
torch.Tensor: random targets
"""
input = torch.rand(size=[num_updates, batch_size, num_classes])
targets = torch.randint(low=0, high=num_classes, size=[num_updates, batch_size])
if device is None:
device = torch.device("cpu")
input = torch.rand(size=[num_updates, batch_size, num_classes]).to(device)
targets = torch.randint(low=0, high=num_classes, size=[num_updates, batch_size]).to(
device
)
return input, targets

0 comments on commit b99b753

Please sign in to comment.