Update on "[quant][graphmode][fx] Scope support for call_method in Qu…
Browse files Browse the repository at this point in the history
…antizationTracer"


Summary:
Previously we did not set the qconfig for call_method nodes correctly, because doing so requires knowing the scope of the node, i.e. the path and type of the module whose forward graph contains it. This PR modifies the QuantizationTracer to record that scope information and build a map from each call_method Node to (module_path, module_type), which is then used when constructing the qconfig_map.
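
To make the mechanism concrete, here is a minimal sketch of the idea in torch.fx terms; the class and bookkeeping names are illustrative, not the actual QuantizationTracer implementation. The tracer keeps a stack of (module_path, module_type) scopes while tracing and tags each call_method node with the scope that was active when the node was created.

import torch.fx

class ScopeRecordingTracer(torch.fx.Tracer):
    # Sketch only: record which module's forward created each call_method node.
    def __init__(self):
        super().__init__()
        self.scope_stack = []    # (module_path, module_type) of modules being traced
        self.node_to_scope = {}  # call_method node name -> (module_path, module_type)

    def call_module(self, m, forward, args, kwargs):
        # Nodes created while tracing m's forward belong to m's scope.
        self.scope_stack.append((self.path_of_module(m), type(m)))
        try:
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.scope_stack.pop()

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        if kind == "call_method" and self.scope_stack:
            self.node_to_scope[node.name] = self.scope_stack[-1]
        return node

A node_to_scope map of this shape is what lets the qconfig_map construction match a call_method node against qconfig_dict entries keyed by module path or module type.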

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method
Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25818132](https://our.internmc.facebook.com/intern/diff/D25818132)

[ghstack-poisoned]
jerryzh168 committed Jan 9, 2021
2 parents 46d4585 + 160b4be commit de3d47a
Showing 56 changed files with 1,198 additions and 744 deletions.
8 changes: 5 additions & 3 deletions .circleci/scripts/windows_cuda_install.sh
@@ -1,11 +1,13 @@
#!/bin/bash
set -eux -o pipefail

if [[ "$CUDA_VERSION" =~ ^10.* ]]; then
cuda_major_version=${CUDA_VERSION%.*}

if [[ "$cuda_major_version" == "10" ]]; then
cuda_installer_name="cuda_10.1.243_426.00_win10"
msbuild_project_dir="CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions"
cuda_install_packages="nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1"
elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then
elif [[ "$cuda_major_version" == "11" ]]; then
cuda_installer_name="cuda_11.1.0_456.43_win10"
msbuild_project_dir="visual_studio_integration/CUDAVisualStudioIntegration/extras/visual_studio_integration/MSBuildExtensions"
cuda_install_packages="nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1"
@@ -14,7 +16,7 @@ else
exit 1
fi

if [[ "$CUDA_VERSION" =~ ^11.* && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then
if [[ "$cuda_major_version" == "11" && "${JOB_EXECUTOR}" == "windows-with-nvidia-gpu" ]]; then
cuda_install_packages="${cuda_install_packages} Display.Driver"
fi

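The key change in both installer scripts is computing the major version once: ${CUDA_VERSION%.*} is bash suffix removal, stripping the shortest trailing ".<component>", so "10.1" becomes "10". A rough Python equivalent of the new dispatch, with assumed values, purely for illustration:

# ${CUDA_VERSION%.*} strips the shortest trailing ".<something>".
cuda_version = "10.1"  # stand-in for the CI's $CUDA_VERSION
cuda_major_version = cuda_version.rsplit(".", 1)[0]  # -> "10"

if cuda_major_version == "10":
    cuda_installer_name = "cuda_10.1.243_426.00_win10"
elif cuda_major_version == "11":
    cuda_installer_name = "cuda_11.1.0_456.43_win10"
else:
    raise SystemExit(f"CUDA_VERSION {cuda_version} is not supported yet")
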
6 changes: 4 additions & 2 deletions .circleci/scripts/windows_cudnn_install.sh
@@ -1,9 +1,11 @@
#!/bin/bash
set -eux -o pipefail

if [[ "$CUDA_VERSION" =~ ^10.* ]]; then
cuda_major_version=${CUDA_VERSION%.*}

if [[ "$cuda_major_version" == "10" ]]; then
cudnn_installer_name="cudnn-${CUDA_VERSION}-windows10-x64-v7.6.4.38"
elif [[ "$CUDA_VERSION" =~ ^11.* ]]; then
elif [[ "$cuda_major_version" == "11" ]]; then
cudnn_installer_name="cudnn-${CUDA_VERSION}-windows-x64-v8.0.5.39"
else
echo "CUDNN for CUDA_VERSION $CUDA_VERSION is not supported yet"
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@

.coverage
coverage.xml
+.dmypy.json
.gradle
.hypothesis
.mypy_cache
160 changes: 160 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.cpp
@@ -1,11 +1,49 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/C++17.h>
#include <algorithm>

namespace at {

namespace detail {

/**
* CPUGeneratorImplStateLegacy is a POD class needed for memcpys
* in torch.get_rng_state() and torch.set_rng_state().
* It is a legacy class and even though it is replaced with
* at::CPUGeneratorImpl, we need this class and some of its fields
* to support backward compatibility on loading checkpoints.
*/
struct CPUGeneratorImplStateLegacy {
/* The initial seed. */
uint64_t the_initial_seed;
int left; /* = 1; */
int seeded; /* = 0; */
uint64_t next;
uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */

/********************************/

/* For normal distribution */
double normal_x;
double normal_y;
double normal_rho;
int normal_is_valid; /* = 0; */
};

/**
* CPUGeneratorImplState is a POD class containing
* new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
* as a helper for torch.get_rng_state() and torch.set_rng_state()
* functions.
*/
struct CPUGeneratorImplState {
CPUGeneratorImplStateLegacy legacy_pod;
float next_float_normal_sample;
bool is_next_float_normal_sample_valid;
};

/**
* PyTorch maintains a collection of default generators that get
* initialized once. The purpose of these default generators is to
@@ -75,6 +113,128 @@ uint64_t CPUGeneratorImpl::seed() {
return random;
}

/**
* Sets the internal state of CPUGeneratorImpl. The new internal state
* must be a strided CPU byte tensor and of the same size as either
* CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
* CPUGeneratorImplState (for new state).
*
* FIXME: Remove support of the legacy state in the future?
*/
void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
using detail::CPUGeneratorImplState;
using detail::CPUGeneratorImplStateLegacy;

static_assert(std::is_pod<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");

static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
static const size_t size_current = sizeof(CPUGeneratorImplState);
static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");

detail::check_rng_state(new_state);

at::mt19937 engine;
auto float_normal_sample = c10::optional<float>();
auto double_normal_sample = c10::optional<double>();

// Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
CPUGeneratorImplStateLegacy* legacy_pod;
auto new_state_size = new_state.numel();
if (new_state_size == size_legacy) {
legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
// Note that in CPUGeneratorImplStateLegacy, we didn't have float version
// of normal sample and hence we leave the c10::optional<float> as is

// Update next_double_normal_sample.
// Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
// and a rho value (normal_rho). These three values were redundant and in the new
// DistributionsHelper.h, we store the actual extra normal sample, rather than three
// intermediate values.
if (legacy_pod->normal_is_valid) {
auto r = legacy_pod->normal_rho;
auto theta = 2.0 * M_PI * legacy_pod->normal_x;
// we return the sin version of the normal sample when in caching mode
double_normal_sample = c10::optional<double>(r * ::sin(theta));
}
} else if (new_state_size == size_current) {
auto rng_state = (CPUGeneratorImplState*)new_state.data();
legacy_pod = &rng_state->legacy_pod;
// update next_float_normal_sample
if (rng_state->is_next_float_normal_sample_valid) {
float_normal_sample = c10::optional<float>(rng_state->next_float_normal_sample);
}

// Update next_double_normal_sample.
// Note that in getRNGState, we now return the actual normal sample in normal_y
// and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
// are squashed to 0.0.
if (legacy_pod->normal_is_valid) {
double_normal_sample = c10::optional<double>(legacy_pod->normal_y);
}
} else {
AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
" or a CPUGeneratorImplState of size ", size_current,
" but found the input RNG state size to be ", new_state_size);
}

// construct engine_
// Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
// redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
// doing a std::copy.
at::mt19937_data_pod rng_data;
std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
rng_data.seed_ = legacy_pod->the_initial_seed;
rng_data.left_ = legacy_pod->left;
rng_data.seeded_ = legacy_pod->seeded;
rng_data.next_ = static_cast<uint32_t>(legacy_pod->next);
engine.set_data(rng_data);
TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
this->engine_ = engine;
this->next_float_normal_sample_ = float_normal_sample;
this->next_double_normal_sample_ = double_normal_sample;
}

/**
* Gets the current internal state of CPUGeneratorImpl. The internal
* state is returned as a CPU byte tensor.
*/
c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
using detail::CPUGeneratorImplState;

static const size_t size = sizeof(CPUGeneratorImplState);
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");

auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr();

// accumulate generator data to be copied into byte tensor
auto accum_state = std::make_unique<CPUGeneratorImplState>();
auto rng_data = this->engine_.data();
accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
accum_state->legacy_pod.left = rng_data.left_;
accum_state->legacy_pod.seeded = rng_data.seeded_;
accum_state->legacy_pod.next = rng_data.next_;
std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
accum_state->legacy_pod.normal_is_valid = false;
accum_state->legacy_pod.normal_y = 0.0;
accum_state->next_float_normal_sample = 0.0f;
accum_state->is_next_float_normal_sample_valid = false;
if (this->next_double_normal_sample_) {
accum_state->legacy_pod.normal_is_valid = true;
accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
}
if (this->next_float_normal_sample_) {
accum_state->is_next_float_normal_sample_valid = true;
accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
}

memcpy(rng_state, accum_state.get(), size);
return state_tensor.getIntrusivePtr();
}

/**
* Gets the DeviceType of CPUGeneratorImpl.
* Used for type checking during run time.
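
These implementations back the public torch.get_rng_state() and torch.set_rng_state() functions. A short round-trip sketch using only that public API (the assertion reflects the intended semantics, not a test from this commit):

import torch

state = torch.get_rng_state()  # CPU uint8 tensor holding the full generator state
a = torch.rand(3)
torch.set_rng_state(state)     # restore the saved state
b = torch.rand(3)
assert torch.equal(a, b)       # same state, same random stream
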
2 changes: 2 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.h
@@ -17,6 +17,8 @@ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+void set_state(const c10::TensorImpl& new_state) override;
+c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
static DeviceType device_type();
uint32_t random();
uint64_t random64();
4 changes: 3 additions & 1 deletion aten/src/ATen/CUDAGeneratorImpl.h
@@ -129,8 +129,10 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
+void set_state(const c10::TensorImpl& new_state) override;
+c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
-uint64_t philox_offset_per_thread();
+uint64_t philox_offset_per_thread() const;
void capture_prologue(int64_t* offset_extragraph);
uint64_t capture_epilogue();
PhiloxCudaState philox_cuda_state(uint64_t increment);
16 changes: 16 additions & 0 deletions aten/src/ATen/core/Generator.cpp
@@ -0,0 +1,16 @@
#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

namespace at {

void Generator::set_state(const at::Tensor& new_state) {
TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed");
this->impl_->set_state(*new_state.unsafeGetTensorImpl());
}

at::Tensor Generator::get_state() const {
return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
}

} // namespace at
28 changes: 28 additions & 0 deletions aten/src/ATen/core/Generator.h
@@ -56,6 +56,8 @@

namespace at {

+class Tensor;

struct TORCH_API Generator {
Generator() {}

@@ -96,6 +98,12 @@ struct TORCH_API Generator {

uint64_t seed() { return impl_->seed(); }

+// Implementation not inlined to prevent cycle reference between
+// `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+void set_state(const at::Tensor& new_state);
+
+at::Tensor get_state() const;

std::mutex& mutex() {
return impl_->mutex_;
}
@@ -130,4 +138,24 @@ Generator make_generator(Args&&... args) {
return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
}

+namespace detail {
+
+/**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+static inline void check_rng_state(const c10::TensorImpl& new_state) {
+TORCH_CHECK_TYPE(
+new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+"RNG state must be a torch.ByteTensor"
+);
+
+TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+}
+
+} // namespace detail
+
} // namespace at
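
check_rng_state enforces that any incoming state is a contiguous CPU torch.ByteTensor, and the same contract is observable through the Python Generator API. A small sketch (the exact exception type raised is an assumption based on the TORCH_CHECK_TYPE above):

import torch

g = torch.Generator()
state = g.get_state()              # contiguous CPU uint8 tensor
assert state.dtype == torch.uint8
g.set_state(state)                 # round-trips cleanly

try:
    g.set_state(state.to(torch.float32))  # wrong dtype should be rejected
except (TypeError, RuntimeError) as e:
    print(e)  # e.g. "RNG state must be a torch.ByteTensor"
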
17 changes: 13 additions & 4 deletions aten/src/ATen/core/TransformationHelper.h
@@ -49,16 +49,16 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {

/**
* A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
+ * in this overloaded version
*/
template <typename T, typename V>
-C10_HOST_DEVICE inline T uniform_int(V val) {
+C10_HOST_DEVICE inline typename std::enable_if<!(std::is_floating_point<T>::value), T>::type uniform_int(V val) {
if (std::is_same<T, bool>::value) {
return static_cast<bool>(val & 1);
-} else if (std::is_same<T, double>::value) {
-return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
} else if (std::is_same<T, int64_t>::value) {
return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
-} else if (std::is_floating_point<T>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) {
+} else if (std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) {
return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
} else if (std::is_integral<T>::value) {
return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
@@ -68,6 +68,15 @@ C10_HOST_DEVICE inline T uniform_int(V val) {
}
}

+/**
+ * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
+ * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
+ */
+template<typename T, typename V>
+C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type uniform_int(V val) {
+return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+}

template <typename T, typename V>
C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
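
Both overloads reduce the raw engine value modulo 2^digits + 1 for the relevant type, which is why Tensor.random_() without from/to fills a float32 tensor with whole numbers in [0, 2^24] (24 being the float32 mantissa width). A quick sketch of that documented behavior:

import torch

t = torch.empty(1000, dtype=torch.float32).random_()
assert torch.equal(t, t.trunc())                       # every sampled value is a whole number
assert t.min().item() >= 0 and t.max().item() <= 2**24 # bounded by 2^mantissa_digits
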
26 changes: 26 additions & 0 deletions aten/src/ATen/core/jit_type_base.h
@@ -152,6 +152,20 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
return nullptr;
}
+template <typename T>
+T* castRaw() {
+if (T::Kind == kind()) {
+return static_cast<T*>(this);
+}
+return nullptr;
+}
+template <typename T>
+const T* castRaw() const {
+if (T::Kind == kind()) {
+return static_cast<T*>(this);
+}
+return nullptr;
+}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
@@ -163,6 +177,18 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
AT_ASSERT(r);
return r;
}
+template <typename T>
+T& expectRef() {
+auto* r = castRaw<T>();
+AT_ASSERT(r);
+return *r;
+}
+template <typename T>
+const T& expectRef() const {
+auto* r = castRaw<const T>();
+AT_ASSERT(r);
+return *r;
+}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
