585 changes: 0 additions & 585 deletions aten/src/ATen/native/QuantizedLinear.cpp

This file was deleted.

203 changes: 0 additions & 203 deletions aten/src/ATen/native/RNN.cpp
@@ -32,19 +32,12 @@
#include <ATen/ops/cat.h>
#include <ATen/ops/cudnn_is_acceptable.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#include <ATen/ops/gru_cell_native.h>
#include <ATen/ops/gru_native.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/lstm_cell_native.h>
#include <ATen/ops/lstm_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/quantized_gru_cell_native.h>
#include <ATen/ops/quantized_lstm_cell_native.h>
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/rnn_relu_cell_native.h>
#include <ATen/ops/rnn_relu_native.h>
@@ -208,158 +201,6 @@ struct CellParams : public CellParamsBase {
}
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor bias_ih,
at::Tensor bias_hh);

struct QuantizedCellParams : public CellParamsBase {
QuantizedCellParams(
Tensor _w_ih,
Tensor _w_hh,
Tensor _b_ih,
Tensor _b_hh,
Tensor _packed_ih,
Tensor _packed_hh,
Tensor _col_offsets_ih,
Tensor _col_offsets_hh,
Scalar _scale_ih,
Scalar _scale_hh,
Scalar _zero_point_ih,
Scalar _zero_point_hh)
: w_ih(std::move(_w_ih)),
w_hh(std::move(_w_hh)),
b_ih_(std::move(_b_ih)),
b_hh_(std::move(_b_hh)),
packed_ih(std::move(_packed_ih)),
packed_hh(std::move(_packed_hh)),
col_offsets_ih(std::move(_col_offsets_ih)),
col_offsets_hh(std::move(_col_offsets_hh)),
scale_ih(std::move(_scale_ih)),
scale_hh(std::move(_scale_hh)),
zero_point_ih(std::move(_zero_point_ih)),
zero_point_hh(std::move(_zero_point_hh)) {}

const Tensor w_ih;
const Tensor w_hh;
const Tensor b_ih_;
const Tensor b_hh_;
const Tensor packed_ih;
const Tensor packed_hh;
const Tensor col_offsets_ih;
const Tensor col_offsets_hh;
const Scalar scale_ih;
const Scalar scale_hh;
const Scalar zero_point_ih;
const Scalar zero_point_hh;

Tensor matmul_ih(const Tensor& input) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor matmul_hh(const Tensor& h) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor linear_ih(const Tensor& input) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_);
}
Tensor linear_hh(const Tensor& h) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_);
}
const Tensor& b_ih() const override {
return b_ih_;
}
const Tensor& b_hh() const override {
return b_hh_;
}
CellParamsSerializationType __getstate__() const override {
std::vector<at::Tensor> tensors_to_serialize = {
w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh};
std::vector<double> doubles_to_serialize = {scale_ih.toDouble(),
scale_hh.toDouble()};
std::vector<int64_t> longs_to_serialize = {zero_point_ih.toLong(),
zero_point_hh.toLong()};
return CellParamsSerializationType(
"quantized",
std::move(tensors_to_serialize),
std::move(doubles_to_serialize),
std::move(longs_to_serialize),
{});
}
static c10::intrusive_ptr<CellParamsBase> __setstate__(
CellParamsSerializationType state) {
std::vector<at::Tensor> tensors;
std::vector<double> doubles;
std::vector<int64_t> longs;
std::tie(std::ignore, tensors, doubles, longs, std::ignore) =
std::move(state);
TORCH_INTERNAL_ASSERT(tensors.size() == 6);
TORCH_INTERNAL_ASSERT(doubles.size() == 2);
TORCH_INTERNAL_ASSERT(longs.size() == 2);

at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]),
b_ih = std::move(tensors[2]), b_hh = std::move(tensors[3]),
col_offsets_ih = std::move(tensors[4]),
col_offsets_hh = std::move(tensors[5]);
double scale_ih = doubles[0], scale_hh = doubles[1];
int64_t zero_point_ih = longs[0], zero_point_hh = longs[1];

at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih);
at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh);

return c10::make_intrusive<QuantizedCellParams>(
/*w_ih=*/std::move(qw_ih),
/*w_hh=*/std::move(qw_hh),
/*b_ih_=*/std::move(b_ih),
/*b_hh_=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/scale_ih,
/*scale_hh=*/scale_hh,
/*zero_point_ih=*/zero_point_ih,
/*zero_point_hh=*/zero_point_hh);
}
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor b_ih,
at::Tensor b_hh) {
auto make_vals = [&](const at::Tensor& W) {
auto params = at::native::fbgemm_linear_quantize_weight(W);
at::Tensor packed_weight =
at::native::fbgemm_pack_quantized_matrix(std::get<0>(params));
return std::tuple_cat(
std::make_tuple(std::move(packed_weight)), std::move(params));
};

at::Tensor qw_ih, qw_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh;
at::Scalar scale_ih, scale_hh, zero_point_ih, zero_point_hh;

std::tie(packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih) =
make_vals(w_ih);
std::tie(packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh) =
make_vals(w_hh);

return c10::make_intrusive<QuantizedCellParams>(
/*qw_ih=*/std::move(qw_ih),
/*qw_hh=*/std::move(qw_hh),
/*b_ih=*/std::move(b_ih),
/*b_hh=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/std::move(scale_ih),
/*scale_hh=*/std::move(scale_hh),
/*zero_point_ih=*/std::move(zero_point_ih),
/*zero_point_hh=*/std::move(zero_point_hh));
}

// QuantizedCellParams vs. QuantizedCellParamsDynamic
//
@@ -536,7 +377,6 @@ static std::unordered_map<
std::string,
c10::intrusive_ptr<CellParamsBase> (*)(CellParamsSerializationType)>
cell_params_deserializers = {
{"quantized", &QuantizedCellParams::__setstate__},
{"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__},
{"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}};

@@ -1841,38 +1681,6 @@ static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
"using the newer definitions in torch.jit.quantized");
}

#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
return_type name( \
const Tensor& input, \
hx_type hx, \
const Tensor& w_ih, \
const Tensor& w_hh, \
const Tensor& b_ih, \
const Tensor& b_hh, \
const Tensor& packed_ih, \
const Tensor& packed_hh, \
const Tensor& col_offsets_ih, \
const Tensor& col_offsets_hh, \
const Scalar& scale_ih, \
const Scalar& scale_hh, \
const Scalar& zero_point_ih, \
const Scalar& zero_point_hh) { \
QuantizedCellParams params( \
w_ih, \
w_hh, \
b_ih, \
b_hh, \
packed_ih, \
packed_hh, \
col_offsets_ih, \
col_offsets_hh, \
scale_ih, \
scale_hh, \
zero_point_ih, \
zero_point_hh); \
return cell_type{}( \
input, prepare_hx_fn(hx), params); \
}
// Set reduced range to be True for all RNN Cells by default. This flag is used only for FBGEMM kernels
// QNNPACK does not reduce range for activations
#define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \
@@ -1895,7 +1703,6 @@ return_type name( \
}

// Quantized LSTM cell
using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
return std::make_tuple(hx[0], hx[1]);
@@ -1904,7 +1711,6 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
// Quantized LSTM cell
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

@@ -1915,22 +1721,15 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
}

// Quantized GRU cell
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);

@@ -1972,7 +1771,6 @@ TORCH_LIBRARY_FRAGMENT(aten, m) {
TORCH_LIBRARY_FRAGMENT(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
@@ -1992,7 +1790,6 @@ TORCH_LIBRARY_IMPL(aten, CPU, m) {

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic));
25 changes: 0 additions & 25 deletions aten/src/ATen/native/native_functions.yaml
@@ -3277,22 +3277,6 @@
dispatch:
CUDA: _mixed_dtypes_linear

- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor

- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor

- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)

- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor

- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor

- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor

- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor
variants: function, method

@@ -7603,15 +7587,6 @@
# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
#

# Quantized RNN cells
- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)

- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

# PackedSequence utilities
- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
dispatch:
1 change: 0 additions & 1 deletion build_variables.bzl
@@ -1325,7 +1325,6 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/PointwiseOps.cpp",
"aten/src/ATen/native/Pooling.cpp",
"aten/src/ATen/native/Pow.cpp",
"aten/src/ATen/native/QuantizedLinear.cpp",
"aten/src/ATen/native/RNN.cpp",
"aten/src/ATen/native/RangeFactories.cpp",
"aten/src/ATen/native/ReduceAllOps.cpp",
1 change: 0 additions & 1 deletion caffe2/CMakeLists.txt
@@ -1131,7 +1131,6 @@ set(ATen_CPU_INCLUDE
${CMAKE_BINARY_DIR}/aten/src)

if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/QuantizedLinear.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/RNN.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
@@ -310,6 +310,18 @@
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_quantize_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_gemm_matrix_fp16", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_quantized_matrix", datetime.date(2023, 12, 31)),
("aten::quantized_lstm_cell", datetime.date(2023, 12, 31)),
("aten::quantized_gru_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)),
("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)),
]

ALLOW_LIST_COMPILED = [
13 changes: 0 additions & 13 deletions test/jit/test_models.py
@@ -16,7 +16,6 @@
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -407,12 +406,6 @@ class Config:
def test_snli(self):
self._test_snli(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_snli_quantized(self):
self._test_snli(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
@@ -553,12 +546,6 @@ def forward(self, x):
def test_vae(self):
self._test_vae(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_vae_quantized(self):
self._test_vae(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)