Skip to content

Commit

Permalink
S390x big endian fixes (#8149)
Browse files Browse the repository at this point in the history
Fixes for multiple tests on s390x
  • Loading branch information
AlekseiNikiforovIBM committed Dec 6, 2023
1 parent 526ec93 commit e12d200
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
9 changes: 5 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import re
import sys
from functools import partial

import numpy as np
Expand Down Expand Up @@ -614,7 +615,7 @@ def _get_1_channel_tensor_various_types():

img_data_short = torch.ShortTensor(1, 4, 4).random_()
expected_output = img_data_short.numpy()
yield img_data_short, expected_output, "I;16"
yield img_data_short, expected_output, "I;16" if sys.byteorder == "little" else "I;16B"

img_data_int = torch.IntTensor(1, 4, 4).random_()
expected_output = img_data_int.numpy()
Expand All @@ -631,7 +632,7 @@ def _get_2d_tensor_various_types():

img_data_short = torch.ShortTensor(4, 4).random_()
expected_output = img_data_short.numpy()
yield img_data_short, expected_output, "I;16"
yield img_data_short, expected_output, "I;16" if sys.byteorder == "little" else "I;16B"

img_data_int = torch.IntTensor(4, 4).random_()
expected_output = img_data_int.numpy()
Expand Down Expand Up @@ -662,7 +663,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
[
(torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16" if sys.byteorder == "little" else "I;16B"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
],
)
Expand Down Expand Up @@ -744,7 +745,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
[
(torch.Tensor(4, 4).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16" if sys.byteorder == "little" else "I;16B"),
(torch.IntTensor(4, 4).random_().numpy(), "I"),
],
)
Expand Down
16 changes: 13 additions & 3 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,25 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
# read
with open(path, "rb") as f:
data = f.read()

# parse
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
if sys.byteorder == "little":
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
else:
nd = get_int(data[0:1])
ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256

assert 1 <= nd <= 3
assert 8 <= ty <= 14
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

if sys.byteorder == "big":
for i in range(len(s)):
s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)

parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))

# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import numbers
import sys
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
Expand Down Expand Up @@ -162,7 +163,7 @@ def to_tensor(pic) -> Tensor:
return torch.from_numpy(nppic).to(dtype=default_float_dtype)

# handle PIL Image
mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32}
img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

if pic.mode == "1":
Expand Down Expand Up @@ -285,7 +286,7 @@ def to_pil_image(pic, mode=None):
if npimg.dtype == np.uint8:
expected_mode = "L"
elif npimg.dtype == np.int16:
expected_mode = "I;16"
expected_mode = "I;16" if sys.byteorder == "little" else "I;16B"
elif npimg.dtype == np.int32:
expected_mode = "I"
elif npimg.dtype == np.float32:
Expand Down

0 comments on commit e12d200

Please sign in to comment.