diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index ce38c27c68f..8a36c1b8d04 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -1,6 +1,5 @@ import itertools import os -import re from abc import ABC, abstractmethod from glob import glob from pathlib import Path @@ -10,7 +9,7 @@ from PIL import Image from ..io.image import _read_png_16 -from .utils import verify_str_arg +from .utils import verify_str_arg, _read_pfm from .vision import VisionDataset @@ -472,31 +471,3 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): # For consistency with other datasets, we convert to numpy return flow.numpy(), valid_flow_mask.numpy() - - -def _read_pfm(file_name): - """Read flow in .pfm format""" - - with open(file_name, "rb") as f: - header = f.readline().rstrip() - if header != b"PF": - raise ValueError("Invalid PFM file") - - dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) - if not dim_match: - raise Exception("Malformed PFM header.") - w, h = (int(dim) for dim in dim_match.groups()) - - scale = float(f.readline().rstrip()) - if scale < 0: # little-endian - endian = "<" - scale = -scale - else: - endian = ">" # big-endian - - data = np.fromfile(f, dtype=endian + "f") - - data = data.reshape(h, w, 3).transpose(2, 0, 1) - data = np.flip(data, axis=1) # flip on h dimension - data = data[:2, :, :] - return data.astype(np.float32) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index af6e1a972c2..b14f25e986b 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -18,6 +18,7 @@ from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator from urllib.parse import urlparse +import numpy as np import requests import torch from torch.utils.model_zoo import tqdm @@ -483,3 +484,39 @@ def verify_str_arg( raise ValueError(msg) return value + + +def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: + """Read file in .pfm format. Might contain either 1 or 3 channels of data. + + Args: + file_name (str): Path to the file. + slice_channels (int): Number of channels to slice out of the file. + Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. + """ + + with open(file_name, "rb") as f: + header = f.readline().rstrip() + if header not in [b"PF", b"Pf"]: + raise ValueError("Invalid PFM file") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) + if not dim_match: + raise Exception("Malformed PFM header.") + w, h = (int(dim) for dim in dim_match.groups()) + + scale = float(f.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(f, dtype=endian + "f") + + pfm_channels = 3 if header == b"PF" else 1 + + data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) + data = np.flip(data, axis=1) # flip on h dimension + data = data[:slice_channels, :, :] + return data.astype(np.float32)