Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ def test_read_file(self):
RuntimeError, "No such file or directory: 'tst'"):
read_file('tst')

def test_read_file_non_ascii(self):
with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
with open(fpath, 'wb') as f:
f.write(content)

data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
self.assertTrue(data.equal(expected))
os.unlink(fpath)

def test_write_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
Expand All @@ -233,6 +245,18 @@ def test_write_file(self):
self.assertEqual(content, saved_content)
os.unlink(fpath)

def test_write_file_non_ascii(self):
with get_tmp_dir() as d:
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)

with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)


if __name__ == '__main__':
unittest.main()
65 changes: 52 additions & 13 deletions torchvision/csrc/cpu/image/read_write_file_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
#include "read_write_file_cpu.h"

// According to
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
// we should use _stat64 for 64-bit file size on Windows.
#ifdef _WIN32
#define VISION_STAT _stat64
#else
#define VISION_STAT stat
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>

std::wstring utf8_decode(const std::string& str) {
if (str.empty()) {
return std::wstring();
}
int size_needed = MultiByteToWideChar(
CP_UTF8, 0, str.c_str(), static_cast<int>(str.size()), NULL, 0);
TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode");
std::wstring wstrTo(size_needed, 0);
MultiByteToWideChar(
CP_UTF8,
0,
str.c_str(),
static_cast<int>(str.size()),
&wstrTo[0],
size_needed);
return wstrTo;
}
#endif

torch::Tensor read_file(std::string filename) {
struct VISION_STAT stat_buf;
int rc = VISION_STAT(filename.c_str(), &stat_buf);
torch::Tensor read_file(const std::string& filename) {
#ifdef _WIN32
// According to
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
// we should use struct __stat64 and _wstat64 for 64-bit file size on Windows.
struct __stat64 stat_buf;
auto fileW = utf8_decode(filename);
int rc = _wstat64(fileW.c_str(), &stat_buf);
#else
struct stat stat_buf;
int rc = stat(filename.c_str(), &stat_buf);
#endif
// errno is a variable defined in errno.h
TORCH_CHECK(
rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'");
Expand All @@ -21,9 +44,20 @@ torch::Tensor read_file(std::string filename) {
TORCH_CHECK(size > 0, "Expected a non empty file");

#ifdef _WIN32
auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8)
.clone();
// TODO: Once torch::from_file handles UTF-8 paths correctly, we should move
// back to use the following implementation since it uses file mapping.
// auto data =
// torch::from_file(filename, /*shared=*/false, /*size=*/size,
// torch::kU8).clone()
FILE* infile = _wfopen(fileW.c_str(), L"rb");

TORCH_CHECK(infile != nullptr, "Error opening input file");

auto data = torch::empty({size}, torch::kU8);
auto dataBytes = data.data_ptr<uint8_t>();

fread(dataBytes, sizeof(uint8_t), size, infile);
fclose(infile);
#else
auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8);
Expand All @@ -32,7 +66,7 @@ torch::Tensor read_file(std::string filename) {
return data;
}

void write_file(std::string filename, torch::Tensor& data) {
void write_file(const std::string& filename, torch::Tensor& data) {
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

Expand All @@ -44,7 +78,12 @@ void write_file(std::string filename, torch::Tensor& data) {

auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
#ifdef _WIN32
auto fileW = utf8_decode(filename);
FILE* outfile = _wfopen(fileW.c_str(), L"wb");
#else
FILE* outfile = fopen(fileCStr, "wb");
#endif

TORCH_CHECK(outfile != NULL, "Error opening output file");

Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/cpu/image/read_write_file_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
#include <sys/stat.h>
#include <torch/torch.h>

C10_EXPORT torch::Tensor read_file(std::string filename);
C10_EXPORT torch::Tensor read_file(const std::string& filename);

C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);