Skip to content

Commit

Permalink
Merge branch 'pytorch:master' into debug_positive_definite_constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
nonconvexopt committed Nov 30, 2021
2 parents ee046f4 + 0cdeb58 commit 67b3448
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 7 deletions.
1 change: 1 addition & 0 deletions test/cpp/lazy/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(LAZY_TEST_SRCS
${LAZY_TEST_ROOT}/test_misc.cpp
${LAZY_TEST_ROOT}/test_permutation_util.cpp
${LAZY_TEST_ROOT}/test_shape.cpp
${LAZY_TEST_ROOT}/test_util.cpp
)

add_executable(test_lazy
Expand Down
14 changes: 7 additions & 7 deletions test/cpp/lazy/test_misc.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include <gtest/gtest.h>
#include <string>

#include "torch/csrc/lazy/core/hash.h"
#include "c10/util/int128.h"
#include <torch/csrc/lazy/core/hash.h>
#include <c10/util/int128.h>

namespace torch {
namespace lazy {

template <typename T>
void test_hash_repeatable_sensitive(T example_a, T example_b) {
void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) {
// repeatable
EXPECT_EQ(Hash(example_a), Hash(example_a));
EXPECT_EQ(MHash(example_a), MHash(example_a));
Expand Down Expand Up @@ -67,10 +67,10 @@ TEST(HashTest, Sanity) {
c10::optional<std::string>(c10::nullopt));

// Containers
test_hash_repeatable_sensitive(
std::vector<int32_t>({0, 1, 1, 2, 3, 5, 8}),
std::vector<int32_t>({1, 1, 2, 3, 5, 8, 12}));

auto a = std::vector<int32_t>({0, 1, 1, 2, 3, 5, 8});
auto b = std::vector<int32_t>({1, 1, 2, 3, 5, 8, 12});
test_hash_repeatable_sensitive(a, b);
test_hash_repeatable_sensitive(c10::ArrayRef<int32_t>(a), c10::ArrayRef<int32_t>(b));
}

} // namespace lazy
Expand Down
73 changes: 73 additions & 0 deletions test/cpp/lazy/test_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <gtest/gtest.h>

#include <exception>

#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

TEST(UtilTest, ExceptionCleanup) {
std::exception_ptr exception;
EXPECT_EQ(exception, nullptr);

{
ExceptionCleanup cleanup([&](std::exception_ptr&& e) {
exception = std::move(e);
});

cleanup.SetStatus(std::make_exception_ptr(std::runtime_error("Oops!")));
}
EXPECT_NE(exception, nullptr);

try {
std::rethrow_exception(exception);
} catch(const std::exception& e) {
EXPECT_STREQ(e.what(), "Oops!");
}

exception = nullptr;
{
ExceptionCleanup cleanup([&](std::exception_ptr&& e) {
exception = std::move(e);
});

cleanup.SetStatus(std::make_exception_ptr(std::runtime_error("")));
cleanup.Release();
}
EXPECT_EQ(exception, nullptr);
}

TEST(UtilTest, MaybeRef) {
std::string storage("String storage");
MaybeRef<std::string> refStorage(storage);
EXPECT_FALSE(refStorage.IsStored());
EXPECT_EQ(*refStorage, storage);

MaybeRef<std::string> effStorage(std::string("Vanishing"));
EXPECT_TRUE(effStorage.IsStored());
EXPECT_EQ(*effStorage, "Vanishing");
}

TEST(UtilTest, Iota) {
auto result = Iota<int>(0);
EXPECT_TRUE(result.empty());

result = Iota<int>(1);
EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result[0], 0);

result = Iota<int>(2);
EXPECT_EQ(result.size(), 2);
EXPECT_EQ(result[0], 0);
EXPECT_EQ(result[1], 1);

result = Iota<int>(3, 1, 3);
EXPECT_EQ(result.size(), 3);
EXPECT_EQ(result[0], 1);
EXPECT_EQ(result[1], 4);
EXPECT_EQ(result[2], 7);
}

} // namespace lazy
} // namespace torch
5 changes: 5 additions & 0 deletions torch/csrc/lazy/core/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ static inline hash_t Hash(const hash_t& value) {
return value;
}

template <typename T>
torch::lazy::hash_t Hash(c10::ArrayRef<T> values) {
return torch::lazy::ContainerHash(values);
}

template <typename T>
hash_t ContainerHash(const T& values) {
hash_t h(static_cast<uint64_t>(0x85ebca77c2b2ae63));
Expand Down
109 changes: 109 additions & 0 deletions torch/csrc/lazy/core/util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/**
* Most of the utils in this file is adapted from PyTorch/XLA
* https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h
*/

#pragma once

#include <exception>
#include <functional>
#include <vector>

#include <c10/util/Optional.h>

namespace torch {
namespace lazy {

// Similar to c10::scope_exit but with a status.
// TODO(alanwaketan): Consolidate it with c10::scope_exit.
template <typename T>
class Cleanup {
public:
using StatusType = T;

explicit Cleanup(std::function<void(StatusType&&)>&& func)
: func_(std::move(func)) {}
Cleanup(Cleanup&& ref) noexcept
: func_(std::move(ref.func_)), status_(std::move(ref.status_)) {}
Cleanup(const Cleanup&) = delete;

~Cleanup() {
if (func_ != nullptr) {
func_(std::move(status_));
}
}

Cleanup& operator=(const Cleanup&) = delete;

Cleanup& operator=(Cleanup&& ref) {
if (this != &ref) {
func_ = std::move(ref.func_);
status_ = std::move(ref.status_);
}
return *this;
}

void Release() { func_ = nullptr; }

void SetStatus(StatusType&& status) { status_ = std::move(status); }

const StatusType& GetStatus() const { return status_; }

private:
std::function<void(StatusType&&)> func_;
StatusType status_;
};

using ExceptionCleanup = Cleanup<std::exception_ptr>;

// Allows APIs which might return const references and values, to not be forced
// to return values in the signature.
// TODO(alanwaketan): This is clever, but is there really no std or c10 supports?
// Needs more investigations.
template <typename T>
class MaybeRef {
public:
/* implicit */ MaybeRef(const T& ref) : ref_(ref) {}
/* implicit */ MaybeRef(T&& value) : storage_(std::move(value)), ref_(*storage_) {}

const T& Get() const { return ref_; }
const T& operator*() const { return Get(); }
operator const T&() const { return Get(); }

bool IsStored() const { return storage_.has_value(); }

private:
c10::optional<T> storage_;
const T& ref_;
};

template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
std::vector<T> result(size);
T value = init;
for (size_t i = 0; i < size; ++i, value += incr) {
result[i] = value;
}
return result;
}

template <typename T, typename S>
std::vector<T> ToVector(const S& input) {
return std::vector<T>(input.begin(), input.end());
}

template<typename T>
c10::optional<std::vector<T>> ToOptionalVector(c10::optional<c10::ArrayRef<T>> arrayRef) {
if (arrayRef) {
return arrayRef->vec();
}
return c10::nullopt;
}

template <typename T>
typename std::underlying_type<T>::type GetEnumValue(T value) {
return static_cast<typename std::underlying_type<T>::type>(value);
}

} // namespace lazy
} // namespace torch

0 comments on commit 67b3448

Please sign in to comment.